Concat Compute API
concat
Description
Concatenates multiple input tensors along a specified axis.
raw_tensors is the list of input tensors, all of which must have the same data type.
If raw_tensors[i].shape = [D0, D1, ..., Daxis(i), ..., Dn], the shape of the output after concatenation along axis is [D0, D1, ..., Raxis, ..., Dn],
where Raxis = sum(Daxis(i)), that is, the sum of the sizes of dimension axis over all input tensors.
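The shape rule above can be sketched in plain Python. concat_output_shape is a hypothetical helper, not part of te.lang.cce; it assumes a non-negative axis and only mirrors the documented rule together with the restriction that all other dimensions must match.

```python
from typing import List, Sequence


def concat_output_shape(shapes: Sequence[Sequence[int]], axis: int) -> List[int]:
    """Return [D0, D1, ..., Raxis, ..., Dn], where Raxis = sum(Daxis(i)).

    Hypothetical helper; assumes 0 <= axis < number of dimensions.
    """
    if not shapes:
        raise ValueError("at least one input shape is required")
    ndim = len(shapes[0])
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for {ndim}-D inputs")
    for shape in shapes:
        if len(shape) != ndim:
            raise ValueError("all inputs must have the same number of dimensions")
        # Every dimension except `axis` must match the first input.
        for dim in range(ndim):
            if dim != axis and shape[dim] != shapes[0][dim]:
                raise ValueError(f"dimension {dim} differs between inputs")
    out = list(shapes[0])
    out[axis] = sum(shape[axis] for shape in shapes)  # Raxis
    return out


print(concat_output_shape([(2, 3), (2, 3)], 0))  # [4, 3]
print(concat_output_shape([(2, 3), (2, 3)], 1))  # [2, 6]
```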
For example:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
concat([t1, t2], 0)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
concat([t1, t2], 1)  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]

# The shape of tensor t1 is [2, 3].
# The shape of tensor t2 is [2, 3].
concat([t1, t2], 0).shape  # [4, 3]
concat([t1, t2], 1).shape  # [2, 6]
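To make the example reproducible without an Ascend environment, the following sketch implements the same semantics on nested Python lists. The functions concat and shape_of here are hypothetical stand-ins written for illustration; they are not te.lang.cce.concat, which operates on tvm.tensor objects, and they assume a non-negative axis.

```python
def concat(tensors, axis):
    """Concatenate nested Python lists along a non-negative axis (reference sketch)."""
    if axis == 0:
        # Along the outermost axis, concatenation simply appends the rows in order.
        result = []
        for t in tensors:
            result.extend(t)
        return result
    # For a deeper axis, pair up the i-th sub-lists of every input and
    # concatenate those pairs one level further down.
    return [concat(list(parts), axis - 1) for parts in zip(*tensors)]


def shape_of(t):
    """Return the shape of a rectangular nested list, e.g. [2, 3]."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0] if t else None
    return dims


t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
print(concat([t1, t2], 0))  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
print(concat([t1, t2], 1))  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
print(shape_of(concat([t1, t2], 0)))  # [4, 3]
print(shape_of(concat([t1, t2], 1)))  # [2, 6]
```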
The axis parameter can also be negative, in which case it counts from the last dimension: a negative axis is equivalent to axis + len(shape).
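The negative-axis rule is simple arithmetic. The sketch below shows it together with the valid range [-d, d - 1] listed under Parameters; normalize_axis is a hypothetical helper, not a te.lang.cce function.

```python
def normalize_axis(axis: int, ndim: int) -> int:
    """Map an axis in [-ndim, ndim - 1] to its non-negative equivalent."""
    if not -ndim <= axis <= ndim - 1:
        raise ValueError(f"axis {axis} is outside [{-ndim}, {ndim - 1}]")
    # A negative axis counts from the end: axis + len(shape).
    return axis + ndim if axis < 0 else axis


print(normalize_axis(-1, 3))  # 2 (the last axis of a 3-D tensor)
print(normalize_axis(-3, 3))  # 0
print(normalize_axis(1, 3))   # 1 (non-negative axes are unchanged)
```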
For example:
t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
# t1 and t2 have shape [2, 2, 2], so axis -1 is equivalent to -1 + 3 = 2.
concat([t1, t2], -1)
The output is as follows:
[[[ 1, 2, 7, 4],
  [ 2, 3, 8, 4]],
 [[ 4, 4, 2, 10],
  [ 5, 3, 15, 11]]]
The API is defined in python/site-packages/te/lang/cce/te_compute/concat_compute.py under the ATC installation path.
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
All input tensors must have the same size in every dimension except axis.
The supported data types are as follows: int8, uint8, int16, int32, float16, and float32.
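The data type restrictions can be expressed as a small pre-check. This is a hypothetical validator written for illustration; it only encodes the rules stated in this document (one shared data type, drawn from the supported list) and does not reflect how te.lang.cce itself reports errors.

```python
from typing import Sequence

# Supported data types as listed in the Restrictions section.
SUPPORTED_DTYPES = frozenset({"int8", "uint8", "int16", "int32", "float16", "float32"})


def check_concat_dtypes(dtypes: Sequence[str]) -> str:
    """Return the common dtype, or raise if the inputs violate the restrictions."""
    if not dtypes:
        raise ValueError("at least one input tensor is required")
    first = dtypes[0]
    # All input tensors must share one data type.
    if any(d != first for d in dtypes):
        raise TypeError(f"input dtypes differ: {list(dtypes)}")
    if first not in SUPPORTED_DTYPES:
        raise TypeError(f"unsupported dtype {first!r}")
    return first


print(check_concat_dtypes(["float16", "float16"]))  # float16
```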
Prototype
te.lang.cce.concat(raw_tensors, axis)
Parameters
- raw_tensors: list of input tensors. Each element is a tvm.tensor, and the last dimension of each tensor shape must be 32-byte aligned.
- axis: axis along which the concatenation is performed. The value range is [-d, d - 1], where d is the number of dimensions of the tensors in raw_tensors.
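One way to pre-check the alignment requirement is sketched below. It assumes that "32-byte aligned" means the byte size of the last dimension (element count multiplied by element size) is a multiple of 32; that interpretation, the DTYPE_BYTES table, and the helper name are assumptions for illustration, not part of te.lang.cce.

```python
# Element sizes in bytes for the supported data types.
DTYPE_BYTES = {"int8": 1, "uint8": 1, "int16": 2, "int32": 4, "float16": 2, "float32": 4}


def is_last_dim_32b_aligned(shape, dtype: str) -> bool:
    """Assumed check: last dimension size in bytes is a multiple of 32."""
    return (shape[-1] * DTYPE_BYTES[dtype]) % 32 == 0


print(is_last_dim_32b_aligned((64, 128), "float16"))  # True: 128 * 2 = 256 bytes
print(is_last_dim_32b_aligned((64, 10), "float16"))   # False: 10 * 2 = 20 bytes
print(is_last_dim_32b_aligned((64, 8), "float32"))    # True: 8 * 4 = 32 bytes
```

Under this assumption, the last dimension must be a multiple of 32 elements for int8/uint8, 16 for int16/float16, and 8 for int32/float32.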
Returns
res_tensor: a tvm.tensor for the result
Example
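import tvm
import te.lang.cce

# Two float16 inputs of shape (64, 128); the last dimension (128 * 2 bytes = 256 bytes) is 32-byte aligned.
shape1 = (64, 128)
shape2 = (64, 128)
input_dtype = "float16"
data1 = tvm.placeholder(shape1, name="data1", dtype=input_dtype)
data2 = tvm.placeholder(shape2, name="data2", dtype=input_dtype)
data = [data1, data2]
# Concatenate along axis 0: 64 + 64 = 128 rows.
res = te.lang.cce.concat(data, 0)  # res.shape = (128, 128)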