4D/5D互转接口
compute_four2five
功能说明
把给定4-D “NCHW”数据格式转换为5-D “NC1HWC0”数据格式。
您可以在ATC包的安装目录下的“python/site-packages/te/lang/cce/te_compute/dim_conv.py”查看接口定义。
约束说明
此接口暂不支持与其他TBE DSL接口混合使用。
支持的数据类型:float16。
函数原型
te.lang.cce.compute_four2five(input, raw_shape_4D)
参数说明
- input:输入tensor,4-D格式(N, C, H, W),tvm.tensor类型。
- raw_shape_4D:输入tensor的维度。
返回值:
res_tensor:转换为5-D格式(N, C1, H, W, C0)后的tensor,tvm.tensor类型
调用示例
import tvm import te.lang.cce raw_shape = (2, 32, 16, 128) in_dtype = "float16" input = tvm.placeholder(raw_shape, name='input', dtype=in_dtype) res = te.lang.cce.compute_four2five(input, raw_shape) # res.shape = (2,(32+15)//16,16,128,16)
compute_five2four
功能说明
把给定5-D “NC1HWC0”数据格式转换为4-D “NCHW”数据格式。
您可以在ATC包的安装目录下的“python/site-packages/te/lang/cce/te_compute/dim_conv.py”查看接口定义。
约束说明
此接口暂不支持与其他TBE DSL接口混合使用。
支持的数据类型:float16。
函数原型
te.lang.cce.compute_five2four(input, raw_shape_4D)
参数说明
- input:输入tensor,5-D格式(N, C1, H, W, C0),tvm.tensor类型。
- raw_shape_4D:转换后tensor的维度。
返回值
res_tensor:转换为4-D格式(N, C, H, W)后的tensor,tvm.tensor类型。
调用示例
import tvm import te.lang.cce raw_shape = (2, 32, 16, 128) input_shape = (2,(32+15)//16,16,128,16) in_dtype = "float16" input = tvm.placeholder(input_shape, name='input', dtype=in_dtype) res = te.lang.cce.compute_five2four(input, raw_shape) # res.shape = (2, 32, 16, 128)