算子原型InferShape接口
IMPLEMT_INFERFUNC
函数功能
封装算子的InferShape函数。
该函数传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取输入输出描述的方法,从而进行InferShape的实现。
基于OpType派生出来的子类op的成员函数如下:
- op.set_input_x(Operator &v, const string &srcName):将网络中算子v的输出srcName设置为当前算子的输入x。
- op.get_input_desc_x():获取该算子的输入x的描述信息,返回对象为TensorDesc类型。
op.update_input_desc_x(const TensorDesc& tensorDesc):更新输入x的描述信息,包括shape、datetype与format。
- op.get_output_desc_y():获取该算子的输出y的描述信息,返回对象TensorDesc类型。
- op.update_output_desc_y(const TensorDesc& tensorDesc):更新输出y的描述信息,包括shape、datetype与format。
- op.get_attr_attr1():获取算子属性attr1的值。
函数原型
IMPLEMT_INFERFUNC(op_name, func_name)
约束说明
无。
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
op_name |
输入 |
算子类型。 |
func_name |
输入 |
InferShape函数名,用户自定义。 |
返回值
无。
IMPLEMT_COMMON_INFERFUNC
函数功能
封装算子的Common_InferShape函数。
与IMPLEMT_INFERFUNC的区别是,此函数自动生成的一个类型为Operator类的对象op,可直接调用Operator类接口进行InferShape的实现。若InferShape方法具有通用性,可被多个算子的原型实现调用,可选择此接口实现。
函数原型
IMPLEMT_COMMON_INFERFUNC(func_name)
约束说明
无。
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
func_name |
输入 |
InferShape函数名,用户自定义。 |
返回值
无。
INFER_FUNC_REG
函数功能
注册算子的InferShape函数。
函数原型
INFER_FUNC_REG(op_name, x)
约束说明
无。
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
op_name |
输入 |
算子类型。 |
x |
输入 |
InferShape函数名,和IMPLEMT_INFERFUNC的InferShape函数名保持一致。 |
返回值
无。
COMMON_INFER_FUNC_REG
函数功能
注册算子的InferShape函数。
与INFER_FUNC_REG的区别是,此函数注册的InferShape函数入参为operator基类而非子类,此接口支持多算子共用同一个InferShape函数。
函数原型
COMMON_INFER_FUNC_REG(op_name, x)
约束说明
无。
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
op_name |
输入 |
算子类型。 |
x |
输入 |
InferShape函数名,和IMPLEMT_COMMON_INFERFUNC的InferShape函数名保持一致。 |
返回值
无。
ELMTWISE_INFER_SHAPEANDTYPE
函数功能
提供公共函数宏封装,供算子开发者开发InferShape函数。该函数基于输入的shape和dtype,设置输出的shape和dtype。
例如,输入shape为(1,2,3,4),dtype为float,则该宏会设置算子的输出shape为(1,2,3,4),输出dtype为float。
函数原型
ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name)
约束说明
无。
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
in_name |
输入 |
算子输入。 |
out_name |
输入 |
算子输出。 |
返回值
执行成功或失败。
调用示例
COMMON_INFER_FUNC_REG(DiagD, ELMTWISE_INFER_SHAPEANDTYPE("assist", "y"));
BROADCAST_INFER
函数功能
提供公共函数宏封装,供算子开发者开发InferShape函数。该函数基于2个输入的shape,设置输出的shape。该宏只是设置shape,未设置dtype。
- 如果2个输入的shape一致,会按输入的shape设置输出shape。
- 如果2个输入的shape不一致,会按照broadcast的策略,取2个输入shape的并集。
比如输入shape分别为(1,2,3,4)和(3,1,3,4),则该宏会设置算子的输出shape为(3,2,3,4)。
函数原型
BROADCAST_INFER(in1_name, in2_name, out_name)
该函数会自动调用如下函数:
graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, const function<vector<int64_t>()> &get_in2_shape, const function<void(const vector<int64_t> &y_shape)> &set_out_shape);
约束说明
无。
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
in1_name |
输入 |
算子第一个输入。 |
in2_name |
输入 |
算子第二个输入。 |
out_name |
输入 |
算子输出。 |
返回值
执行成功或失败。
调用示例
IMPLEMT_INFERFUNC(RightShift, RightShiftInfer) { DataType type = op.GetInputDesc("x").GetDataType(); SET_OUTPUT_TYPE(op, "z", type); return BROADCAST_INFER("x", "y", "z")(op); }