评分并提供意见反馈 :
华为采用机器翻译与人工审校相结合的方式将此文档翻译成不同语言,希望能帮助您更容易理解此文档的内容。 请注意:即使是最好的机器翻译,其准确度也不及专业翻译人员的水平。 华为对于翻译的准确性不承担任何责任,并建议您参考英文文档(已提供链接)。
check_supported函数实现
若开发者需要在算子融合阶段进行算子参数校验,则可在算子实现文件中实现check_supported函数,并在算子信息定义文件中将配置项needCheckSupport的flag参数配置为true,算子信息定义的配置可参见算子信息库定义。
check_supported函数的声明如下所示:
def check_supported(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="xx"):
check_supported函数的入参和算子接口函数保持一致(即算子的输入、输出、属性及kernel_name)。
若校验成功,则返回True;若校验失败,则返回False。
check_supported函数中可自定义实现算子输入输出dtype的校验以及shape的校验。
例如,InTopK算子的check_supported函数实现如下,实现对输入参数的数据类型的校验。
def check_supported(predictions, targets, precision, k, kernel_name='in_top_k'): prediction_dtype = predictions.get("dtype").lower() target_dtype = targets.get("dtype").lower() if prediction_dtype != "float32": return False if target_dtype != "int32": return False return True
InplaceUpdate算子的check_supported函数实现如下,实现对输入参数的数据类型以及shape的校验。
def check_supported(x, indices, v, y, kernel_name="inplace_update"): shape_indices = indices.get("shape") shape_v = v.get("shape") dtype_v = v.get("dtype").lower() reg_v_len = 1 for i in range(1, len(shape_v)): reg_v_len = reg_v_len * shape_v[i] if dtype_v in ("float32", "int32"): dtype_size = 4 else: dtype_size = 2 reg_v_size = reg_v_len * dtype_size try: if len(shape_indices) != 1 or (reg_v_size % 32 != 0): return False except RuntimeError: return False return True