Function check_supported
To have operator arguments verified during the operator fusion phase, implement the check_supported function in the operator implementation file, and set the flag parameter of the needCheckSupport configuration item to true in the operator information definition file. For details about configuring the operator information definition, see Operator Information Library Definition.
The check_supported function is declared as follows:
def check_supported(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):
The parameters of check_supported must be consistent with those of the operator's implementation function, that is, the same inputs, outputs, attributes, and kernel_name.
The function returns True if verification passes and False otherwise.
Within check_supported, you can implement custom dtype and shape verification for the operator's inputs and outputs.
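As an illustration of these rules, here is a minimal sketch for a hypothetical element-wise operator, my_add, which does not come from the source. It assumes, as in the examples that follow, that each input and output is passed as a dict with "shape" and "dtype" keys; the dtype whitelist and the equal-shape rule are made up for illustration.

```python
import inspect


def my_add(x1, x2, y, kernel_name="my_add"):
    """Hypothetical operator implementation function (body omitted)."""
    raise NotImplementedError("compute and schedule logic would go here")


def check_supported(x1, x2, y, kernel_name="my_add"):
    """Takes exactly the same parameters as my_add.

    Returns True if the operator can handle these inputs, otherwise False.
    """
    supported_dtypes = ("float16", "float32")  # illustrative whitelist only
    dtype_x1 = x1.get("dtype").lower()
    dtype_x2 = x2.get("dtype").lower()
    # dtype verification: both inputs must share one supported dtype
    if dtype_x1 not in supported_dtypes or dtype_x1 != dtype_x2:
        return False
    # shape verification: this sketch does not support broadcasting
    if tuple(x1.get("shape")) != tuple(x2.get("shape")):
        return False
    return True


# The two functions expose identical parameter lists.
print(inspect.signature(my_add) == inspect.signature(check_supported))

a = {"shape": (16, 32), "dtype": "float16"}
out = {"shape": (16, 32), "dtype": "float16"}
print(check_supported(a, {"shape": (16, 32), "dtype": "Float16"}, out))
print(check_supported(a, {"shape": (16, 1), "dtype": "float16"}, out))
```

The first and second calls print True; the third prints False because the shapes differ.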
For example, the check_supported function of the InTopK operator can be implemented as follows to verify the input data types (predictions must be float32 and targets must be int32):
def check_supported(predictions, targets, precision, k, kernel_name='in_top_k'):
    # dtype strings are lower-cased so the comparison is case-insensitive
    prediction_dtype = predictions.get("dtype").lower()
    target_dtype = targets.get("dtype").lower()
    # only float32 predictions and int32 targets are supported
    if prediction_dtype != "float32":
        return False
    if target_dtype != "int32":
        return False
    return True
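When an operator has more inputs to check, the chain of if statements can be replaced by a lookup table. The sketch below is an equivalent, table-driven rewrite of the InTopK check; _REQUIRED_DTYPES and _dtype_matches are hypothetical names, not part of any CANN API, and the inputs are again assumed to be dicts with a "dtype" key.

```python
# Required dtype for each checked parameter (same rules as the InTopK example).
_REQUIRED_DTYPES = {"predictions": "float32", "targets": "int32"}


def _dtype_matches(desc, expected):
    """Case-insensitive comparison of a tensor description's dtype."""
    return desc.get("dtype").lower() == expected


def check_supported(predictions, targets, precision, k, kernel_name="in_top_k"):
    args = {"predictions": predictions, "targets": targets}
    # every listed parameter must carry its required dtype
    return all(_dtype_matches(args[name], dtype)
               for name, dtype in _REQUIRED_DTYPES.items())


preds = {"shape": (8, 10), "dtype": "FLOAT32"}
tgts = {"shape": (8,), "dtype": "int32"}
out = {"shape": (8,), "dtype": "bool"}
print(check_supported(preds, tgts, out, k=3))
print(check_supported({"shape": (8, 10), "dtype": "float16"}, tgts, out, k=3))
```

The first call prints True and the second prints False. Keeping the requirements in one mapping makes adding another checked input a one-line change.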
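The check_supported function of the InplaceUpdate operator can be implemented as follows to verify the inputs: indices must be one-dimensional, and each slice of v along its first dimension must occupy a multiple of 32 bytes (the dtype of v is used only to compute that size):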
def check_supported(x, indices, v, y, kernel_name="inplace_update"):
    shape_indices = indices.get("shape")
    shape_v = v.get("shape")
    dtype_v = v.get("dtype").lower()
    # number of elements in one slice of v (all dimensions except the first)
    reg_v_len = 1
    for i in range(1, len(shape_v)):
        reg_v_len = reg_v_len * shape_v[i]
    # element size in bytes: 4 for float32/int32, 2 for any other dtype
    if dtype_v in ("float32", "int32"):
        dtype_size = 4
    else:
        dtype_size = 2
    reg_v_size = reg_v_len * dtype_size
    try:
        # indices must be 1-D and each slice of v must be 32-byte aligned
        if len(shape_indices) != 1 or (reg_v_size % 32 != 0):
            return False
    except RuntimeError:
        return False
    return True
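To see which shapes the alignment rule accepts, the sketch below repeats only the size calculation from the InplaceUpdate example: the element count of v excluding the first dimension, multiplied by 4 bytes for float32/int32 and 2 bytes for any other dtype (a simplification inherited from the example). v_row_bytes is a hypothetical helper name. The 32-byte requirement presumably reflects the 32-byte data-block alignment of Ascend AI Core on-chip buffers, although the source does not state the reason.

```python
def v_row_bytes(shape_v, dtype_v):
    """Bytes occupied by one slice v[i, ...], computed as in the InplaceUpdate example."""
    reg_v_len = 1
    for dim in shape_v[1:]:  # all dimensions except the first
        reg_v_len *= dim
    dtype_size = 4 if dtype_v.lower() in ("float32", "int32") else 2
    return reg_v_len * dtype_size


for shape, dtype in [((4, 16), "float16"), ((4, 3, 8), "float16"), ((2, 8), "float32")]:
    size = v_row_bytes(shape, dtype)
    print(shape, dtype, size, "aligned" if size % 32 == 0 else "not aligned")
```

For example, a float16 v of shape (4, 3, 8) gives 3 × 8 × 2 = 48 bytes per slice, so check_supported returns False; a float16 v of shape (4, 16) gives 16 × 2 = 32 bytes and passes the alignment check (indices must still be one-dimensional).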