4D/5D Conversion APIs
compute_four2five
Description
Converts the 4D data format NCHW to the 5D data format NC1HWC0.
The API is defined in python/site-packages/te/lang/cce/te_compute/dim_conv.py under the ATC installation directory.
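The index mapping behind this conversion can be shown with a small pure-Python reference on nested lists. This is a sketch, not the TBE implementation. The function name nchw_to_nc1hwc0 is hypothetical. It assumes C0 = 16, which is the value used for float16 in the example for this API. It also assumes that channels which do not fill the last C1 block are zero-padded.

```python
# Hypothetical reference for the NCHW -> NC1HWC0 index mapping (not the TBE code).
# Assumptions: C0 = 16 for float16, and channels beyond C are zero-padded.
C0 = 16


def nchw_to_nc1hwc0(data, c0=C0):
    """Rearrange a nested list indexed [n][c][h][w] into [n][c1][h][w][k]."""
    n_dim, c_dim = len(data), len(data[0])
    h_dim, w_dim = len(data[0][0]), len(data[0][0][0])
    c1_dim = (c_dim + c0 - 1) // c0  # ceil(C / C0) channel blocks

    out = []
    for n in range(n_dim):
        blocks = []
        for c1 in range(c1_dim):
            plane = []
            for h in range(h_dim):
                row = []
                for w in range(w_dim):
                    # Channel c = c1 * C0 + k moves to the innermost (k) axis.
                    vec = []
                    for k in range(c0):
                        c = c1 * c0 + k
                        vec.append(data[n][c][h][w] if c < c_dim else 0)
                    row.append(vec)
                plane.append(row)
            blocks.append(plane)
        out.append(blocks)
    return out


# N=1, C=20, H=1, W=2; each value encodes its channel and column as c * 10 + w.
x = [[[[c * 10 + w for w in range(2)]] for c in range(20)]]
y = nchw_to_nc1hwc0(x)
print(len(y[0]))      # 2 channel blocks
print(y[0][1][0][1])  # [161, 171, 181, 191] followed by 12 zeros of padding
```

Channels are split into blocks of C0. Within each block, the channel offset k becomes the fastest-varying axis. This is why C1 is rounded up and the tail block may carry padding.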
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
The supported data type is float16.
Prototype
te.lang.cce.compute_four2five(input, raw_shape_4D)
Parameters
- input: the input tensor, a 4D tvm.tensor of shape (N, C, H, W)
- raw_shape_4D: shape of the 4D input tensor, (N, C, H, W)
Returns
res_tensor: the result, a 5D tvm.tensor of shape (N, C1, H, W, C0). For float16, C0 is 16 and C1 = (C + 15) // 16, as in the example below.
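The result shape follows directly from the raw 4D shape. Here is a minimal sketch with a hypothetical helper, five_d_shape. It assumes C0 = 16 for float16, consistent with the (C + 15) // 16 used in the example for this API.

```python
# Hypothetical helper: shape of the NC1HWC0 result for a given NCHW shape.
# Assumption: C0 = 16 for float16.
def five_d_shape(raw_shape_4d, c0=16):
    n, c, h, w = raw_shape_4d
    c1 = (c + c0 - 1) // c0  # ceil(C / C0): channels are grouped into blocks of C0
    return (n, c1, h, w, c0)


print(five_d_shape((2, 32, 16, 128)))  # (2, 2, 16, 128, 16)
print(five_d_shape((1, 3, 224, 224)))  # (1, 1, 224, 224, 16): C is padded up to 16
```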
Example
```python
import tvm
import te.lang.cce

raw_shape = (2, 32, 16, 128)  # N, C, H, W
in_dtype = "float16"  # the only supported data type
input_tensor = tvm.placeholder(raw_shape, name='input', dtype=in_dtype)
res = te.lang.cce.compute_four2five(input_tensor, raw_shape)
# res.shape = (2, (32 + 15) // 16, 16, 128, 16), that is, (2, 2, 16, 128, 16)
```
compute_five2four
Description
Converts the 5D data format NC1HWC0 to the 4D data format NCHW.
The API is defined in python/site-packages/te/lang/cce/te_compute/dim_conv.py under the ATC installation directory.
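The reverse mapping can be sketched the same way. The function nc1hwc0_to_nchw is hypothetical and illustrates only the index mapping, not the TBE implementation. It reads C0 from the innermost axis and uses the original 4D shape to decide how many channels to keep. As a result, padded channels in the last C1 block are never read.

```python
# Hypothetical reference for the NC1HWC0 -> NCHW index mapping (not the TBE code).
def nc1hwc0_to_nchw(data, raw_shape_4d):
    """Rearrange [n][c1][h][w][k] back to [n][c][h][w] for the given NCHW shape."""
    n_dim, c_dim, h_dim, w_dim = raw_shape_4d
    c0 = len(data[0][0][0][0])  # block size, read from the innermost axis
    out = []
    for n in range(n_dim):
        channels = []
        for c in range(c_dim):
            c1, k = divmod(c, c0)  # channel c sits in block c1 at offset k
            channels.append(
                [[data[n][c1][h][w][k] for w in range(w_dim)] for h in range(h_dim)]
            )
        out.append(channels)
    return out  # padded channels (c >= C) are never read, so they are dropped


# A 5D tensor with C1=2, H=W=1, C0=16 whose values equal the channel index.
x5 = [[[[[c1 * 16 + k for k in range(16)]]] for c1 in range(2)]]
y4 = nc1hwc0_to_nchw(x5, (1, 20, 1, 1))
print(len(y4[0]))                              # 20: the 12 padded channels are gone
print([y4[0][c][0][0] for c in (0, 15, 19)])   # [0, 15, 19]
```

The 5D shape alone records only C1 * C0 channels. The original C must therefore be supplied separately, which is the role raw_shape_4D plays in the sketch.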
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
The supported data type is float16.
Prototype
te.lang.cce.compute_five2four(input, raw_shape_4D)
Parameters
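- input: the input tensor, a 5D tvm.tensor of shape (N, C1, H, W, C0)
- raw_shape_4D: shape of the 4D result tensor, (N, C, H, W). It supplies the original channel count C, which the 5D shape alone does not determine when C is not a multiple of C0.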
Returns
res_tensor: the result, a 4D tvm.tensor of shape (N, C, H, W)
Example
```python
import tvm
import te.lang.cce

raw_shape = (2, 32, 16, 128)  # original 4D shape: N, C, H, W
input_shape = (2, (32 + 15) // 16, 16, 128, 16)  # 5D shape: N, C1, H, W, C0
in_dtype = "float16"  # the only supported data type
input_tensor = tvm.placeholder(input_shape, name='input', dtype=in_dtype)
res = te.lang.cce.compute_five2four(input_tensor, raw_shape)
# res.shape = (2, 32, 16, 128)
```