Operator注册类
注册算子类型以REG_OP为起始,以“.”链接INPUT、OUTPUT、ATTR等接口注册算子的输入、输出和属性信息,最终以OP_END接口结束。注册算子类型成功后,自动生成以算子类型名称命名的类。
例如:
REG_OP(FullConnection)
.INPUT(x, TensorType::ALL())
.INPUT(w, TensorType::ALL())
.INPUT(b, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(num_output, AttrValue::INT{0})
.INFER_SHAPE_AND_TYPE(FullConnectionInfer)
.ATTR_ALL_VERIFY(FullConnectionVerify)
.OP_END()
Operator注册类接口在operator_reg.h中定义。已注册的算子及对应的头文件,请参见内置的算子类型列表。
REG_OP
函数原型
REG_OP(x)
功能说明
注册算子类型,同时自动生成算子类型的两个构造函数。
例如,注册算子的类型名称Conv2D,可调用REG_OP(Conv2D)接口,调用该接口后,定义了算子的类型名称Conv2D,同时产生Conv2D的两个构造函数,其中,Conv2D(const string& name)需指定算子名称,Conv2D()使用默认算子名称,例如“Conv2D唯一编号”。
class Conv2D : public Operator { typedef Conv2D _THIS_TYPE; public: explicit Conv2D(const string& name); explicit Conv2D(); }
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,被注册算子的类型名称 |
返回值
无。
异常处理
无。
约束说明
注册的算子类型名称需保持唯一,不能重复。
ATTR
函数原型
ATTR(x, default_value)
功能说明
注册算子属性,必须指定默认值,用户不设置算子对象的属性值时使用默认值。
注册算子属性成功后,自动生成算子属性的3个对外接口,用于获取属性的名称、获取属性的值、设置属性的值。
现以注册类型为int64_t的属性、类型为int64_t列表两种场景为例,说明所生成的算子属性接口:
- 调用ATTR(mode, AttrValue::INT{1})接口,注册属性mode,属性类型为int64_t,默认值为1。
注册属性成功后,自动生成以下接口:
static const string name_attr_mode(); // 返回属性的名称,即“mode” int64_t get_attr_mode() const; // 返回mode属性的值 _THIS_TYPE& set_attr_mode(int64_t v); // 设置mode属性的值,返回算子对象本身
- 调用ATTR(pad, AttrValue::LIST_INT{0, 0, 0, 0})接口,注册属性pad,属性类型为int64_t列表,默认值为{0,0,0,0}。
注册属性成功后,自动生成以下接口:
static const string name_attr_pad(); // 返回属性的名称,即“pad” vector<int64_t> get_attr_pad() const; ; // 返回属性pad的值 _THIS_TYPE& set_attr_pad(vector<int64_t> v); // 设置属性pad的值,返回算子对象本身
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子属性的名称。 |
default_value |
输入 |
- |
算子属性的值,根据不同类型指定默认值,支持的属性类型包括:
|
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子属性名称需保持唯一,不能重复。
REQUIRED_ATTR
函数原型
REQUIRED_ATTR (x, type)
功能说明
注册算子属性,没有默认值,用户必须设置算子对象的属性值。
注册算子属性成功后,自动生成算子属性的3个对外接口,用于获取属性的名称、获取属性的值、设置属性的值。
例如,注册类型为int64_t的属性mode,可调用REQUIRED_ATTR (mode, Int)接口,注册算子属性成功后,会自动生成如下接口:
static const string name_attr_mode(); // 返回属性的名称,即“mode” OpInt get_attr_mode() const; // 返回mode属性的值,OpInt即int64_t _THIS_TYPE& set_attr_mode(const OpInt& v); // 设置mode属性的值,返回this对象
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子属性的名称。 |
type |
输入 |
- |
算子属性的类型,包括:
|
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子属性名称需保持唯一,不能重复。
INPUT
函数原型
INPUT (x, t)
功能说明
注册算子输入信息。
注册算子输入信息成功后,自动生成算子输入的相关接口,用于获取算子输入的名称、设置算子输入的对应描述等。
例如,注册算子输入x,算子输入接收的数据类型为TensorType{DT_FLOAT},可调用INPUT(x, TensorType{DT_FLOAT})接口,注册算子输入成功后,自动生成以下相关接口:
static const string name_in_x(); // 返回输入的名称,即“x” _THIS_TYPE& set_input_x(Operator& v, const string& srcName);// 指定输入x与算子对象v的输出srcName存在连接关系,返回算子对象本身 _THIS_TYPE& set_input_x(Operator& v); // 指定输入x与算子对象v的索引0的输出存在连接关系,返回算子对象本身 TensorDesc get_input_desc_x(); // 返回输入x对应的描述 graphStatus update_input_desc_x(const TensorDesc& tensorDesc);// 设置输入x对应的描述,包括Shape、DataType、Format等信息,graphStatus即uint32_t类型,返回非0表示出错
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子输入的名称 |
t |
输入 |
- |
算子输入接收的数据类型,可以是TensorType定义的一个或多个,如果多个,通过“,”隔离,例如: TensorType{DT_FLOAT} TensorType({DT_FLOAT, DT_INT8} 关于TensorType类,请参见TensorType类说明。 |
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子输入名称需保持唯一,不能重复。
OPTIONAL_INPUT
函数原型
OPTIONAL_INPUT(x, t)
功能说明
注册可选算子输入信息。
注册可选算子输入信息成功后,自动生成算子输入的相关接口,用于获取算子输入的名称、设置算子输入的对应描述等。
例如,注册算子输入b,算子输入接收的数据类型为TensorType{DT_FLOAT},可调用OPTIONAL_INPUT(b, TensorType{DT_FLOAT})接口,注册算子输入成功后,自动生成以下相关接口:
static const string name_in_b(); // 返回输入的名称,即“b” _THIS_TYPE& set_input_b(Operator& v, const string& srcName);// 指定输入b与算子对象v的输出srcName存在连接关系,返回算子对象本身 _THIS_TYPE& set_input_b(Operator& v); // 指定输入b与算子对象v的索引0的输出存在连接关系,返回算子对象本身 TensorDesc get_input_desc_b(); // 返回输入b对应的描述 graphStatus update_input_desc_b(const TensorDesc& tensorDesc);// 设置输入b对应的描述,包括Shape、DataType、Format等信息
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子输入的名称。 |
t |
输入 |
- |
算子输入接收的数据类型,可以是TensorType定义的一个或多个,如果多个,通过“,”隔离,例如: TensorType{DT_FLOAT} TensorType({DT_FLOAT, DT_INT8} |
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子输入名称需保持唯一,不能重复。
DYNAMIC_INPUT
函数原型
DYNAMIC_INPUT (x, t)
功能说明
注册动态算子输入信息。
注册动态算子输入信息成功后,自动生成算子输入的相关接口,用于创建动态输入、设置算子输入的对应描述等。
例如,注册动态算子输入d,算子输入接收的数据类型为TensorType{DT_FLOAT},可调用DYNAMIC_INPUT(d, TensorType{DT_FLOAT})接口,注册动态算子输入成功后,自动生成以下相关接口:
_THIS_TYPE& create_dynamic_input_d(unsigned int num); // 创建动态输入d,包括num个输入 TensorDesc get_dynamic_input_desc_d(unsigned int index);// 返回动态输入d第index个描述,包括Shape、DataType、Format等信息 graphStatus update_dynamic_input_desc_d(unsigned int index, const TensorDesc& tensorDesc);// 更新动态输入d的第index个描述 _THIS_TYPE& set_dynamic_input_d(unsigned int dstIndex, Operator &v); // 指定输入d的第dstIndex个输入与算子对象v的索引0的输出存在连接关系,返回算子对象本身 _THIS_TYPE& set_dynamic_input_d(unsigned int dstIndex, Operator &v, const string &srcName); / /指定动态输入d的第dstIndex个输入与算子对象v的输出srcName存在连接关系,返回算子对象本身
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子输入的名称 |
t |
输入 |
- |
算子输入接收的数据类型,可以是TensorType定义的一个或多个,如果多个,通过“,”隔离,例如: TensorType{DT_FLOAT} TensorType({DT_FLOAT, DT_INT8} |
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子输入名称需保持唯一,不能重复。
OUTPUT
函数原型
OUTPUT (x, t)
功能说明
注册算子输出信息。
注册算子输出信息成功后,自动生成算子输出的相关接口,用户获取算子输出的名称、获取算子输出的描述、设置算子输出的描述。
例如,注册算子输出y,算子输出接收的数据类型为TensorType{DT_FLOAT},可调用OUTPUT(y, TensorType{DT_FLOAT})接口,注册算子输出成功后,自动生成以下相关接口
static const string name_out_y();// 返回输出的名称,即“y” TensorDesc get_output_desc_y();// 返回输出y对应的描述 graphStatus update_output_desc_y(const TensorDesc& tensorDesc); );// 设置输出y对应的描述,包括Shape、DataType、Format等信息
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子输出的名称。 |
t |
输入 |
- |
算子输出接收的数据类型,可以是TensorType定义的一个或多个,如果多个,通过“,”隔离,例如: TensorType{DT_FLOAT} TensorType({DT_FLOAT, DT_INT8} |
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子输出名称需保持唯一,不能重复。
DYNAMIC_OUTPUT
函数原型
DYNAMIC_OUTPUT (x, t)
功能说明
注册动态算子输出信息。
注册动态算子输出信息成功后,自动生成动态算子输出的相关接口,包括用于创建动态输出、设置算子输出的对应描述等
例如,注册动态算子输出d,算子输出接收的数据类型为TensorType{DT_FLOAT},可调用DYNAMIC_OUTPUT (d, TensorType{DT_FLOAT})接口,注册动态算子输出成功后,自动生成以下相关接口
_THIS_TYPE& create_dynamic_output_d(unsigned int num); // 创建动态输出d,包括num个输出 TensorDesc get_dynamic_output_desc_d(unsigned int index);// 返回动态输出d第index个描述,包括Shape、DataType、Format等信息 graphStatus update_dynamic_output_desc_d(unsigned int index, const TensorDesc& tensorDesc);// 更新动态输出d的第index个描述
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子输出的名称。 |
t |
输入 |
- |
算子输出接收的数据类型,可以是TensorType定义的一个或多个,如果多个,通过“,”隔离,例如: TensorType{DT_FLOAT} TensorType({DT_FLOAT, DT_INT8} |
返回值
无。
异常处理
无。
约束说明
对于同一个算子,注册的算子输出名称需保持唯一,不能重复。
INFER_SHAPE_AND_TYPE
函数原型
INFER_SHAPE_AND_TYPE (x)
功能说明
注册用于推理算子的Shape和DataType的函数。
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子Shape和DataType推理函数。 例如,INFER_SHAPE_AND_TYPE(FullConnectionInfer)用于注册FullConnectionInfer函数,用于推理算子的Shape和DataType。 FullConnectionInfer通过DECLARE_INFERFUNC声明,通过IMPLEMT_INFERFUNC定义,详情请参见DECLARE_INFERFUNC和IMPLEMT_INFERFUNC宏说明。 |
返回值
无。
异常处理
无。
约束说明
无。
DECLARE_INFERFUNC和IMPLEMT_INFERFUNC宏说明
在注册用于推理算子的Shape和DataType的函数前,需要先用DECLARE_INFERFUNC宏申明函数、用IMPLEMT_INFERFUNC宏定义函数。
- 声明函数
DECLARE_INFERFUNC(FullConnection, FullConnectionInfer)
DECLARE_INFERFUNC宏展开后的实现为:
namespace op { class FullConnection; } static graphStatus FullConnectionInfer(op::FullConnection& op);
- 定义函数
IMPLEMT_INFERFUNC(FullConnection, FullConnectionInfer) { // 实现细节 }
IMPLEMT_INFERFUNC宏展开后的实现为:
static graphStatus FullConnectionInfer(op::FullConnection& op){ // 实现细节 }
ATTR_ALL_VERIFY
函数原型
ATTR_ALL_VERIFY (x)
功能说明
注册算子校验函数。
参数说明
参数名 |
输入/输出 |
类型 |
描述 |
---|---|---|---|
x |
输入 |
- |
宏参数,算子校验函数。 例如,ATTR_ALL_VERIFY(FullConnectionVerify)用于注册算子校验函数为FullConnectionVerify。FullConnectionVerify通过DECLARE_VERIFIER声明,通过IMPLEMT_VERIFIER定义,详情请参见DECLARE_VERIFIER和IMPLEMT_VERIFIER宏说明。 |
返回值
无。
异常处理
无。
约束说明
无。
DECLARE_VERIFIER和IMPLEMT_VERIFIER宏说明
在注册算子校验函数前,需要先用DECLARE_VERIFIER宏申明函数、用IMPLEMT_VERIFIER宏定义函数。
- 声明函数
DECLARE_VERIFIER(FullConnection, FullConnectionVerify)
DECLARE_VERIFIER宏展开后的实现为:
namespace op { class FullConnection; } static graphStatus FullConnectionVerify(op::FullConnection op);
- 定义函数
IMPLEMT_VERIFIER(FullConnection, FullConnectionVerify) { // 实现细节 }
IMPLEMT_VERIFIER宏展开后的实现为:
static graphStatus FullConnectionVerify(op::FullConnection op){ // 实现细节 }
内置的算子类型列表
您可以在注册算子类型的头文件中查看内置的算子类型,注册各算子类型时设置的算子输入、输出、属性等信息。
算子名称 |
所在头文件 |
---|---|
Data |
array_defs.h |
Concat |
|
Flatten |
|
Reshape |
|
Split |
|
Const |
const_defs.h |
Permute |
detection_defs.h |
Add |
math_defs.h |
Mul |
|
Activation |
nn_defs.h |
BatchNorm |
|
Convolution |
|
Eltwise |
|
LRN |
|
ConvolutionDepthwise |
|
FullConnection |
|
Pooling |
|
Scale |
|
ShuffleChannel |
|
Softmax |