Conv2D Operator Definition
The following uses Conv2D, a relatively complex operator, as an example to describe how to define an operator.
Conv2D operator prototype definition:
REG_OP(Conv2D)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8}))
    .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32}))
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
    .ATTR(dilations, ListInt, {1, 1, 1, 1})
    .ATTR(groups, Int, 1)
    .ATTR(data_format, String, "NHWC")
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(Conv2D)
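The strides, pads, and dilations attributes declared in this prototype determine the spatial size of the output y. The following is a minimal sketch based on standard convolution arithmetic, not on GE code. It assumes (this is not stated in the prototype) that strides and dilations are 4-element lists ordered like the data format (NCHW here, so H is index 2 and W is index 3) and that pads is ordered {top, bottom, left, right}. SpatialSize, ConvOutDim, and Conv2DOutputHW are hypothetical helpers.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Sketch only. Assumed layout: strides/dilations follow NCHW order
// (H = index 2, W = index 3); pads = {top, bottom, left, right}.
struct SpatialSize {
  int64_t h;
  int64_t w;
};

// Output length along one spatial axis (standard convolution arithmetic).
int64_t ConvOutDim(int64_t in, int64_t kernel, int64_t stride, int64_t dilation,
                   int64_t pad_before, int64_t pad_after) {
  const int64_t dilated_kernel = dilation * (kernel - 1) + 1;  // extent covered by a dilated kernel
  const int64_t padded = in + pad_before + pad_after;           // input length after padding
  assert(stride > 0 && padded >= dilated_kernel);               // precondition for a non-empty output
  return (padded - dilated_kernel) / stride + 1;
}

// Output H and W for an input of in_h x in_w and a filter of k_h x k_w.
SpatialSize Conv2DOutputHW(int64_t in_h, int64_t in_w, int64_t k_h, int64_t k_w,
                           const std::array<int64_t, 4>& strides,
                           const std::array<int64_t, 4>& pads,
                           const std::array<int64_t, 4>& dilations) {
  return SpatialSize{ConvOutDim(in_h, k_h, strides[2], dilations[2], pads[0], pads[1]),
                     ConvOutDim(in_w, k_w, strides[3], dilations[3], pads[2], pads[3])};
}
```

For example, a 32 × 32 input with a 3 × 3 filter and strides {1, 1, 1, 1}, pads {0, 0, 0, 0}, and dilations {1, 1, 1, 1} gives a 30 × 30 output; changing strides to {1, 1, 2, 2} and pads to {1, 1, 1, 1} gives 16 × 16.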
The prototype definition shows that the Conv2D operator has two required inputs (x and filter), two optional inputs (bias and offset_w), one output (y), two required attributes (strides and pads), and four optional attributes with default values (dilations, groups, data_format, and offset_x). The following code creates and configures a Conv2D operator instance:
auto conv2d = op::Conv2D("Conv2d")
    .set_input_x(quant)
    .set_input_filter(conv_weight)
    .set_input_bias(conv_bias)
    .set_attr_strides({1, 1, 1, 1})
    .set_attr_pads({0, 0, 0, 0})
    .set_attr_dilations({1, 1, 1, 1});
TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);
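The prototype separates required inputs and attributes from optional ones, and only the optional attributes carry defaults. The following is a minimal sketch, not the GE API: a hypothetical OpPrototype struct and MissingRequired helper that check whether an instance such as the one above sets every required input and attribute of Conv2D.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified model of an operator prototype (not the GE REG_OP machinery).
struct OpPrototype {
  std::vector<std::string> required_inputs;
  std::vector<std::string> optional_inputs;
  std::vector<std::string> required_attrs;
  std::map<std::string, std::string> optional_attrs;  // attribute name -> default value (as text)
};

// Returns the required inputs and attributes that were not set, prefixed with their kind.
std::vector<std::string> MissingRequired(const OpPrototype& proto,
                                         const std::set<std::string>& set_inputs,
                                         const std::set<std::string>& set_attrs) {
  std::vector<std::string> missing;
  for (const auto& name : proto.required_inputs) {
    if (set_inputs.count(name) == 0) missing.push_back("input:" + name);
  }
  for (const auto& name : proto.required_attrs) {
    if (set_attrs.count(name) == 0) missing.push_back("attr:" + name);
  }
  return missing;
}

// The Conv2D prototype, transcribed from the REG_OP definition.
OpPrototype Conv2DPrototype() {
  return OpPrototype{
      {"x", "filter"},
      {"bias", "offset_w"},
      {"strides", "pads"},
      {{"dilations", "{1, 1, 1, 1}"}, {"groups", "1"}, {"data_format", "NHWC"}, {"offset_x", "0"}}};
}
```

For the instance above, setting x, filter, and bias plus strides, pads, and dilations leaves nothing missing; omitting set_attr_pads would report attr:pads, while omitting set_attr_dilations would simply fall back to the default.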
The major steps are as follows:
- Call the operator constructor, for example, Conv2D(const string& name), to create an operator instance, passing in the operator name (Conv2d in this example).
auto conv2d = op::Conv2D("Conv2d")
- Call set_input_<input name> to set each operator input.
    .set_input_x(data)
    .set_input_filter(conv_weight)
    .set_input_bias(conv_bias)
data is the input node of the entire graph. It is constructed by the Data operator. For details, see Data Operator Definition.
conv_weight is constant data constructed by the Const operator. For details, see Const Operator Definition.
conv_bias is constant data constructed by the Const operator. For details, see Const Operator Definition.
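- Call set_attr_<attribute name> to set operator attributes. The required attributes (strides and pads) must be set; optional attributes such as dilations take their prototype defaults when not set.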
    .set_attr_strides({1, 1, 1, 1})     // Set the strides attribute.
    .set_attr_pads({0, 0, 0, 0})        // Set the pads attribute.
    .set_attr_dilations({1, 1, 1, 1});  // Set the dilations attribute.
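- For convolution operators such as Conv2D, and for other operators whose processing depends on the C (channel) axis, you are advised to create each TensorDesc with the format set to NCHW or NHWC and apply it with update_input_desc_<input name> (and update_output_desc_<output name> for outputs). The format must match the layout of the data that the operator actually processes.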
TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);
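The declared format matters for channel-sensitive operators because the C axis sits at a different position in NCHW and NHWC. The following hypothetical sketch (the Layout enum and the ChannelAxis and ChannelCount helpers are illustrative only, not part of GE) shows that the same 4-D shape yields a different channel count depending on which format it is declared with.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the two tensor formats discussed above; not GE's Format enum.
enum class Layout { kNCHW, kNHWC };

// Index of the channel (C) axis in a 4-D shape for the given layout.
std::size_t ChannelAxis(Layout layout) {
  switch (layout) {
    case Layout::kNCHW:
      return 1;
    case Layout::kNHWC:
      return 3;
  }
  throw std::invalid_argument("unsupported layout");
}

// Reads the channel count from a 4-D shape, interpreting the shape in the given layout.
int64_t ChannelCount(const std::array<int64_t, 4>& shape, Layout layout) {
  return shape[ChannelAxis(layout)];
}
```

For a shape of {1, 3, 224, 224} that is really laid out as NCHW, the channel count is 3; if the same tensor were declared as NHWC, a channel-sensitive operator would read 224 channels instead, which is why the format in the TensorDesc must match the actual data.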