ScatterNdAdd算子(TIK方式)
功能描述
ScatterNdAdd算子通过对输入数据中的单个值或切片应用稀疏算法,从而得到输出数据。
该算子具有var、indices和updates三个关键输入。其功能为使用updates更新var中indices指定位置的数据,即在var指定位置的数据上加上update的值。
三个输入之间的关系分别为:
- 张量var的shape的维度为P。
- indices是整数张量,shape的维度(rank)为Q,索引为ref,最后一维的元素个数为K(0<K<=P),shape为[d_0, ..., d_{Q-2}, K]。
- 张量updates的shape的维度(rank)为Q-1+P-K,shape为[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]。
例子:
如上面示例所示,indices总共有4个index,每个index有两个数表示一个二维坐标,指示var中所要更新的位置。
第一个index(0, 0)表示更新的var的起始位置为(0, 0, 0)到(0, 0, 3)总共四个数。updates的第一个分片为(1, 1, 1, 1),所以更新完之后的output的结果为(2, 2, 2, 2)。
第二个index(0, 1)表示更新的var的起始位置为(0, 1, 0)到(0, 1, 3)总共四个数。updates的第二个分片为(2, 2, 2, 2),所以更新完之后的output的结果为(3, 3, 3, 3)。
第三个index(0, 2)表示更新的var的起始位置为(0, 2, 0)到(0, 2, 3)总共四个数。updates的第三个分片为(3, 3, 3, 3),所以更新完之后的output的结果为(4, 4, 4, 4)。
第四个index(1, 1)表示更新的var的起始位置为(1, 1, 0)到(1, 1, 3)总共四个数。updates的第四个分片为(4, 4, 4, 4),所以更新完之后的output的结果为(5, 5, 5, 5)。
算子分析
开发ScatterNdAdd算子前,我们需要确定算子的功能,输入、输出,算子开发方式,算子类型以及算子实现函数名称等。
- 明确算子的功能。
ScatterNdAdd算子的功能是通过对输入数据中的单个值或切片应用稀疏算法,从而得到输出数据,详细功能示例可参考功能描述。
- 明确输入和输出。
- ScatterNdAdd算子有三个输入,一个输出。
输入:var(需要更新的tensor),indices(指定需要更新的索引位置),updates(更新数据)。
输出:var(输出更新后的tensor)。
- 本样例中算子的输入var支持的数据类型为float16、float32、int32、int8、uint8;输入indices支持的数据类型为int32;输入updates支持的数据类型为float16、float32、int32、int8、uint8;输出var支持的数据类型为:float16、float32、int32、int8、uint8。
- 算子输入支持所有shape,输出shape与输入shape相同。
- 算子输入支持的format为:ND。
- ScatterNdAdd算子有三个输入,一个输出。
- 确定算子开发方式及使用的计算流程设计。
由于ScatterNdAdd算子涉及对tensor的不同维度上的不同元素同时计算,TBE DSL接口都无法满足此算子的计算要求,所以考虑使用TIK方式进行此算子的实现。
该算子实现核心的计算流程如下:
- 将indices读入到UB Buffer中,然后遍历计算var中需要更新位置的index。
- 每算出一个index,就将相应的需要更新的var分片和updates分片搬入到UB Buffer中。
- 将两个分片每个对应位置的元素相加后再搬出到GM内存中。
- 当遍历完所有的index后,就可以得到最终的计算结果。
针对这个核心计算流程,我们要设计相应的schedule策略。schedule策略在设计的时候主要考虑两个基本问题:
- 首先是shape的泛化。
由于我们的UB空间有大小限制,所以不是所有的输入shape都能在UB上放下,需要考虑分片搬运入UB buffer进行计算。此时,就需要我们根据输入的shape大小,数据类型,UB buffer空间来计算每次分片搬运的大小和需要搬入的次数。在UB空间划分的时候,要充分合理的利用UB空间来提升性能。相同的输入shape,分10次搬入UB计算完之后再搬回到GM,比分100次搬运和计算性能更优。因此,要满足不同的shape泛化,我们要根据输入的shape来计算和划分UB buffer空间,计算各个指令的参数。
- 其次是多核,double buffer等策略。
当前昇腾AI处理器有多个AI Core可以做并行计算,可以极大的提升计算的性能。对于ScatterNdAdd这个算子,总共需要三层循环,如下所示
表14-3 ScatterNdAdd算子循环列表空间
循环
GM (拆分var)
多核LOOP
GM2UB (indices)
indices LOOP
GM2UB (var + update)
Split indices LOOP
核心是遍历indices,对每个index取出var和update做计算。为了实现多核并行计算,我们将var进行分拆,让不同AI Core进行不同的index位置的计算。例如,假设我们根据前面的公式算出index的取值范围为[0, 6],这时候我们把index=0对应的updates分片放在第一个核处理,index=2对应的updates分片放在第二个核处理,以此类推,最后一个核处理index=6对应的updates分片,这样可以实现多个核的并行处理。
- 明确算子实现文件名称、算子实现函数名称以及算子的类型(OpType)。
- 算子类型需要采用大驼峰的命名方式,即采用大写字符区分不同的语义。
- 算子文件名称和算子函数名称,可选用以下任意一种命名规则:
- 用户自定义,此时需要在算子信息定义中配置opFile.value与opInterface.value。
- 不配置算子信息定义中的opFile.value与opInterface.value,FE会将OpType按照如下方式进行转换后进行算子文件名和算子函数名的匹配。转换规则如下:
- 首字符的大写字符转换为小写字符。
例如:Abc -> abc
- 小写字符后的大写字符转换为下划线+小写字符。
例如:AbcDef -> abc_def
- 紧跟数字以及大写字符后的大写字符,作为同一语义字符串,查找此字符串后的第一个小写字符,并将此小写字符的前一个大写字符转换为下划线+小写字符,其余大写字符转换为小写字符。若此字符串后不存在小写字符,则直接将此字符串中的大写字符转换为小写字符。
例如:ABCDef -> abc_def;Abc2DEf -> abc2d_ef;Abc2DEF -> abc2def;ABC2dEF -> abc2d_ef。
- 首字符的大写字符转换为小写字符。
因此本例中,算子类型定义为ScatterNdAdd;算子的实现文件名称及实现函数名称定义为scatter_nd_add。
通过以上分析,得到ScatterNdAdd算子的设计规格如下:
表14-4 ScatterNdAdd算子设计规格算子类型(OpType)
ScatterNdAdd
算子输入
name:var
Type:Tensor
shape:all
data type:
float16、float32、int32、int8、uint8
format:ND
name:indices
Type:Tensor
shape:all
data type:int32
format:ND
name:updates
Type:Tensor
shape:all
data type:
float16、float32、int32、int8、uint8
format:ND
算子输出
name:var
Type:Tensor
shape:all
data type:
float16、float32、int32、int8、uint8
format:ND
算子实现文件/实现函数名称
scatter_nd_add
算子实现
算子代码实现
ScatterNdAdd算子的详细实现代码请参见“tbe/impl/scatter_nd_add.py”,下面主要介绍关键代码原理。
ScatterNdAdd的算子实现的关键点是进行算子schedule策略的实现,包含tiling参数的计算、多核实现等。
- 接口定义。
def scatter_nd_add(var, indices, updates, var_out, use_locking=False, kernel_name="scatter_nd_add"): scatter_nd = Scatter(var, indices, updates, var_out, True, kernel_name, "vadd") scatter_nd.scatter_operator()
主要包括以下关键点:
- tiling参数计算。
定义Scatter类,并在初始化函数中进行tiling参数的计算。核心计算主要是计算每个输入的shape的大小,再根据数据类型计算需要UB多少空间。我们可以通过tbe_platform.cce_conf.get_soc_spec()接口获取到UB的实际物理空间后,根据UB大小来划分UB空间,为定义UB上的tensor做准备。后续的步骤中,我们还会使用这些数据来计算data_move、vec_add等接口的参数。设置独立的tiling模块,将其与算子计算逻辑分离可以很好的做到算子的shape泛化。对于不同的shape,我们可以在不改变计算逻辑的情况下,只改变tiling参数来优化搬运和计算的次数,来做到泛化和高性能。
class Scatter(): def __init__(self, var, indices, updates, var_out, nd_flag, kernel_name, compute_type): # 初始化tik容器 self.tik_instance = tik.Tik(tik.Dprofile()) self.nd_flag = nd_flag # 初始化三个输入的shape和数据类型 self.var_shape = var.get("shape") self.var_dtype = var.get("dtype").lower() self.indices_shape = indices.get("shape") self.indices_dtype = indices.get("dtype").lower() self.updates_shape = updates.get("shape") self.updates_dtype = updates.get("dtype").lower() # 计算三个输入的tensor的大小 self.var_ele_num = functools_reduce(lambda x, y: x * y, self.var_shape) self.indices_num = functools_reduce(lambda x, y: x * y, self.indices_shape) self.updates_num = functools_reduce(lambda x, y: x * y, self.updates_shape) self.kernel_name = kernel_name self.check_param(var_out) # 计算index的个数和最大值,用于遍历和分核。ND和非ND场景按不同的分支计算 if nd_flag: if self.indices_shape[-1] == len(self.var_shape): self.update_data_num = 1 else: self.update_data_num = functools_reduce( lambda x, y: x * y, self.var_shape[self.indices_shape[-1]:]) self.max_indice = functools_reduce( lambda x, y: x * y, self.var_shape[0:self.indices_shape[-1]]) self.index_dims = self.indices_shape[-1] else: if len(self.var_shape) > 1: self.update_data_num = functools_reduce(lambda x, y: x * y, self.var_shape[1:]) else: self.update_data_num = 1 self.max_indice = self.var_shape[0] self.index_dims = 1 # 初始化算类型,用于兼容add和sub等方法,方便实现不同的操作 self.compute_type = compute_type # 获取UB buffer空间大小,并计算一个block可以存储多少相应数据类型的数据 self.ub_size_bytes = ( tbe_platform.cce_conf.get_soc_spec( tbe_platform.cce_conf.UB_SIZE) - 8192) self.var_dtype_bytes_size = cce.cce_intrin.get_bit_len( self.var_dtype) // 8 self.indices_dtype_bytes_size = cce.cce_intrin.get_bit_len( self.indices_dtype) // 8 self.var_data_each_block = 32 // self.var_dtype_bytes_size self.indices_data_each_block = 32 // self.indices_dtype_bytes_size self.indices_ub_number = 0 self.updates_ub_number = 0 self.index_loop_num = 0 self.max_num_one_repeat = 128 if self.var_dtype in ("float32", "int32"): self.max_num_one_repeat = 64 # 计算使用的AI Core的个数,以及每个AI Core处理多少个index,对于updates分片小于32B场景采用单核 if self.update_data_num < self.var_data_each_block: self.block_num = 1 else: ai_core_num = tbe_platform.cce_conf.get_soc_spec(tbe_platform.cce_conf.CORE_NUM) self.indice_step = math.ceil(self.max_indice / ai_core_num) self.block_num = math.ceil(self.max_indice / self.indice_step) # 定义输入和输出在GM中的tensor self.var_gm = self.tik_instance.Tensor( self.var_dtype, self.var_shape, name="var_gm", scope=tik.scope_gm) self.indices_gm = self.tik_instance.Tensor( self.indices_dtype, self.indices_shape, name="indices_gm", scope=tik.scope_gm) self.updates_gm = self.tik_instance.Tensor( self.updates_dtype, self.updates_shape, name="updates_gm", scope=tik.scope_gm) self.out_gm = self.tik_instance.Tensor( self.var_dtype, self.var_shape, name="out_gm", scope=tik.scope_gm) self.vconv_dst_dtype = "float16" self.init_ub_tensor_para() self.var_vconv_ub = None self.updates_vconv_ub = None self.var_tile_vconv_ub = None self.updates_tile_vconv_ub = None self.var_ub = None self.updates_ub = None self.indices_ub = None self.var_tile_ub = None self.updates_tile_ub = None self.var_read_index = None self.updates_read_index = None self.indices_loop_index = None self.indices_tmp = None # 计算UB大小的划分,根据输入的shape大小和数据类型计算 def init_ub_tensor_para(self): updates_size_bytes = self.var_dtype_bytes_size * self.update_data_num indices_size_bytes = self.indices_dtype_bytes_size * self.indices_num need_vconv_dtype = ("int8", "uint8") # update数据类型为int8或者uint8时的计算方法 if self.var_dtype in need_vconv_dtype: vconv_dtype_bytes_size = cce.cce_intrin.get_bit_len( self.vconv_dst_dtype) vconv_data_each_block = 32 // vconv_dtype_bytes_size vconv_size_bytes = ( updates_size_bytes // self.var_dtype_bytes_size * vconv_dtype_bytes_size) # 当update和var分片能在UB上放下时优先存储这两个数据 if (updates_size_bytes + vconv_size_bytes) * 2 < ( self.ub_size_bytes * 0.9): self.updates_ub_number = math.ceil( self.update_data_num / self.var_data_each_block) * self.var_data_each_block self.vconv_ub_number = math.ceil( self.update_data_num / vconv_data_each_block) * vconv_data_each_block self.indices_ub_number = ( self.ub_size_bytes - updates_size_bytes * 2 - vconv_size_bytes * 2) // self.indices_dtype_bytes_size self.indices_ub_number = math.ceil( self.indices_ub_number / self.indices_data_each_block) * self.indices_data_each_block # 当update和var分片在UB上放不下时,如果indices能放下,优先存储indices数据 elif indices_size_bytes < (self.ub_size_bytes * 0.9): self.indices_ub_number = math.ceil( self.indices_num / self.indices_data_each_block) * self.indices_data_each_block self.updates_ub_number = ( self.ub_size_bytes - indices_size_bytes) // self.var_dtype_bytes_size // 6 self.updates_ub_number = math.ceil( self.updates_ub_number / self.var_data_each_block) * self.var_data_each_block self.vconv_ub_number = math.ceil( self.updates_ub_number / vconv_data_each_block) * vconv_data_each_block # 都放不下时,UB内存对半分 else: self.updates_ub_number = (self.ub_size_bytes // 2 // (vconv_dtype_bytes_size + self.var_dtype_bytes_size) // 2 // self.var_data_each_block * self.var_data_each_block) self.indices_ub_number = (self.ub_size_bytes // self.indices_dtype_bytes_size // 2 // self.var_data_each_block * self.var_data_each_block) self.vconv_ub_number = self.updates_ub_number # update数据类型非int8或者uint8时的处理方法 else: # 当update和var分片能在UB上放下时优先存储这两个数据 if updates_size_bytes * 2 < self.ub_size_bytes * 0.9: self.updates_ub_number = math.ceil( self.update_data_num / self.var_data_each_block) * self.var_data_each_block self.indices_ub_number = ( self.ub_size_bytes - updates_size_bytes * 2) // self.indices_dtype_bytes_size self.indices_ub_number = math.ceil( self.indices_ub_number / self.indices_data_each_block) * self.indices_data_each_block if self.indices_num < self.indices_ub_number: self.indices_ub_number = math.ceil( self.indices_num / self.indices_data_each_block ) * self.indices_data_each_block # 当update和var分片在UB上放不下时,如果indices能放下,优先存储indices数据 elif indices_size_bytes < self.ub_size_bytes * 0.9: self.indices_ub_number = math.ceil( self.indices_num / self.indices_data_each_block) * self.indices_data_each_block self.updates_ub_number = ( self.ub_size_bytes - indices_size_bytes) // 2 // self.var_dtype_bytes_size self.updates_ub_number = math.ceil( self.updates_ub_number / self.var_data_each_block) * self.var_data_each_block # 都放不下时,UB内存对半分 else: self.indices_ub_number = (self.ub_size_bytes // self.indices_dtype_bytes_size // 2 // self.indices_data_each_block * self.indices_data_each_block) self.updates_ub_number = (self.indices_ub_number // 2 // self.var_data_each_block * self.var_data_each_block) last_num = self.update_data_num % self.updates_ub_number if (last_num < self.var_data_each_block and self.update_data_num > self.updates_ub_number): self.updates_ub_number -= self.var_data_each_block
- 计算过程实现。根据tiling的计算结果,我们判断要不要使用多核。如果要使用多核,就需要设置多核循环。并且定义UB tensor的操作必须定义在多核循环内,防止编译时出现冲突。对于多核场景,每次循环都会遍历输入张量indices,在计算出index后判断该index是否在当前核的处理范围内再进行计算。
def scatter_operator(self): # 根据tiling计算结果判断能否开多核,如果需要开多核,需要指定多核循环 if self.block_num > 1: with self.tik_instance.for_range( 0, self.block_num, block_num=self.block_num) as indices_loop_index: # 初始化UB中的tensor self.init_ub_tensor() self.indices_loop_index.set_as(indices_loop_index) # 遍历indices索引计算 self.traversing_indices() else: self.init_ub_tensor() self.traversing_indices() # 通过BuildCCE接口进行算子编译,最终生成算子目标文件.o与算子描述文件.json self.tik_instance.BuildCCE( kernel_name=self.kernel_name, inputs=(self.var_gm, self.indices_gm, self.updates_gm), outputs=(self.out_gm), enable_l2=False) return self.tik_instance
- traversing_indices函数定义。该函数主要操作是将indices分片搬入到UB中,然后遍历和计算出需要更新的var对应的index。搬运的时候需要考虑最后一个分片,搬运的burst_len需要单独计算。将一个indice分片搬入到UB后,在self.updates_the_var函数中遍历当前UB中的indices,做相应的计算和处理。
def traversing_indices(self): # 计算indices需要分多少次搬入UB进行遍历,根据给indices分配的UB大小来计算 max_ub_idx_num = (self.indices_ub_number // self.index_dims * self.index_dims) indices_loop_num = self.indices_num // max_ub_idx_num if indices_loop_num > 0: with self.tik_instance.for_range( 0, indices_loop_num) as indices_loop_index: # 封装计算var分片的函数,对每一个index做更新操作,输入的参数为var和updates读取的偏移量 self.updates_the_var(indices_loop_index * max_ub_idx_num, max_ub_idx_num) # 遍历的尾巴,或者只需要一次搬入遍历的场景 indices_last_num = self.indices_num % max_ub_idx_num if indices_last_num > 0: self.updates_the_var(indices_loop_num * max_ub_idx_num, indices_last_num)
- updates_the_var函数定义。该函数的入参为当前搬运到UB的indices的位置和个数。indices的位置主要用来计算当前的indices对应的updates分片的位置,indices的个数主要用来计算需要遍历多少个index。对于当前遍历计算出来的index,判断是否在当前核心的处理范围,如果不是,就跳过不进行处理。对于每个updates分片的处理,我们仍然需要考虑UB放不下后需要分片处理。对于每个分片的处理,我们可以封装相同的规则进行处理。
def updates_the_var(self, indices_in_index, indice_num): # 计算数据搬运的burst_len indices_burst_len = math.ceil(indice_num / self.indices_data_each_block) # 将indices搬运到UB if self.indices_num == 1: self.tik_instance.data_move(self.indices_ub, self.indices_gm, 0, 1, indices_burst_len, 0, 0) else: self.tik_instance.data_move(self.indices_ub, self.indices_gm[indices_in_index], 0, 1, indices_burst_len, 0, 0) if self.nd_flag: indice_loop_num = indice_num // self.indices_shape[-1] else: indice_loop_num = indice_num # 遍历搬运到UB的indices with self.tik_instance.for_range(0, indice_loop_num) as indices_ub_index: self.get_var_read_index(indices_ub_index) if self.block_num > 1: # 判断index是否在当前核的计算范围内,如果在,进行对应的计算 with self.tik_instance.if_scope( self.indices_loop_index * self.indice_step <= self.var_read_index): with self.tik_instance.if_scope( (self.indices_loop_index + 1) * self.indice_step > self.var_read_index): if self.nd_flag: indices_in_index = indices_in_index // \ self.indices_shape[ -1] self.get_updates_read_index(indices_ub_index + indices_in_index) self.var_read_index.set_as(self.var_read_index * self.update_data_num) # 计算update和var的函数 self.calc_updates() else: if self.nd_flag: indices_in_index = indices_in_index // self.indices_shape[ -1] self.get_updates_read_index(indices_ub_index + indices_in_index) self.var_read_index.set_as(self.var_read_index * self.update_data_num) self.calc_updates() # 对updates数据进行分段遍历 def calc_updates(self): updates_loop = self.update_data_num // self.updates_ub_number if updates_loop > 0: with self.tik_instance.for_range(0, updates_loop) as loop_index: self.calc_updates_small(loop_index * self.updates_ub_number, self.updates_ub_number) last_num = self.update_data_num % self.updates_ub_number if last_num > 0: self.calc_updates_small(updates_loop * self.updates_ub_number, last_num)
- calc_updates_small函数定义。该函数主要实现每个updates分片的处理。在实现的过程中主要需要考虑非32B对齐场景,多核时序问题导致的写覆盖规避。同时,由于vec_add计算指令单条指令最大只能计算255*128(32640)个float16数据(255次repeat,每个repeat计算128个数,每次repeat计算的最大个数和数据类型相关)。因此,我们需要进行三步处理。第一步,通过tik.for_range循环计算多次,每次计算255*128个数据。剩下的通过设置repeat次数,将N*128个元素通过一条指令计算完毕。最后小于128个元素,通过设置mask进行精确计算。相同的数据量,vec_add调用次数越少,性能越高。通过三次处理,可以做到shape的泛化和最优性能。
def calc_updates_small(self, read_index_offset, element_num): # 计算一次搬运到UB的burst_len参数 updates_burst_len = math.ceil(element_num / self.var_data_each_block) # 将需要更新的var分片搬运到UB buffer self.tik_instance.data_move( self.var_ub, self.var_gm[self.var_read_index + read_index_offset], 0, 1, updates_burst_len, 0, 0) # 将需要更新的updates分片搬运到UB buffer上 self.tik_instance.data_move( self.updates_ub, self.updates_gm[self.updates_read_index + read_index_offset], 0, 1, updates_burst_len, 0, 0) # 计算非32B对齐场景尾巴的数据有多少,需要两次计算和搬运防止写覆盖 tile_ele_num = element_num % self.var_data_each_block align_offset = 0 # 非32B对齐,且大于32B的场景进行计算。并将计算结果搬出 if (tile_ele_num != 0 and self.update_data_num > self.var_data_each_block): align_ele_num = ( element_num // self.var_data_each_block * self.var_data_each_block) align_offset = ( read_index_offset + align_ele_num - (self.var_data_each_block - tile_ele_num)) self.tik_instance.data_move( self.var_tile_ub, self.var_gm[self.var_read_index + align_offset], 0, 1, 1, 0, 0) self.tik_instance.data_move( self.updates_tile_ub, self.updates_gm[self.updates_read_index + align_offset], 0, 1, 1, 0, 0) compute_loop = element_num // self.max_num_one_repeat // 255 // 对于vec_add指令,根据updates数量大小判断需要调用多少次 if compute_loop > 0: with self.tik_instance.for_range(0, compute_loop) as index: index_offset = index * self.max_num_one_repeat * 255 self.calc_process(self.max_num_one_repeat, index_offset, index_offset, index_offset, 255, False) last_loop = element_num % (self.max_num_one_repeat * 255) // self.max_num_one_repeat if last_loop > 0: index_offset = compute_loop * self.max_num_one_repeat * 255 self.calc_process(self.max_num_one_repeat, index_offset, index_offset, index_offset, last_loop, False) compute_mask = element_num % self.max_num_one_repeat if compute_mask > 0: index_offset = ( element_num // self.max_num_one_repeat * self.max_num_one_repeat) # 32B对齐场景,只需要将数据一次搬出去 if (tile_ele_num == 0 or self.update_data_num < self.var_data_each_block): self.calc_process(compute_mask, index_offset, index_offset, index_offset, 1, False) self.tik_instance.data_move( self.out_gm[self.var_read_index + read_index_offset], self.var_ub, 0, 1, updates_burst_len, 0, 0) # 非32B对齐场景,需要把对齐部分和非对齐部分分两次计算,然后搬出 else: self.calc_process(self.var_data_each_block, 0, 0, 0, 1, True) self.tik_instance.data_move( self.out_gm[self.var_read_index + align_offset], self.var_tile_ub, 0, 1, 1, 0, 0) self.calc_process(compute_mask, index_offset, index_offset, index_offset, 1, False) self.tik_instance.data_move( self.out_gm[self.var_read_index + read_index_offset], self.var_ub, 0, 1, updates_burst_len - 1, 0, 0) else: self.tik_instance.data_move( self.out_gm[self.var_read_index + read_index_offset], self.var_ub, 0, 1, updates_burst_len, 0, 0)
- calc_process函数定义。对于核心的计算指令,我们封装成calc_process函数,主要来做数据类型和计算类型的泛化。首先,对于int8和uint8类型的数据,无法直接使用vec_add接口进行计算。此时需要使用vconv进行数据类型转换再进行计算,并在计算完成之后转换回之前的数据类型。其次,由于scatter_nd_add和scatter_nd_sub计算过程一样,只是最后调用的计算指令不一样。我们可以通过参数来进行控制进行那种计算,以实现一个模板适配多个算子类型。
def calc_process(self, mask, dest_addr, src_addr1, src_addr2, repeat_times, is_tile): need_vconv_dtype = ("int8", "uint8") # 对于int8和uint8数据类型,需要进行转换后再进行计算 if self.var_dtype in need_vconv_dtype: if is_tile: self.tik_instance.vec_conv(mask, "", self.var_tile_vconv_ub[dest_addr], self.var_tile_ub[src_addr1], repeat_times, 8, 4) self.tik_instance.vec_conv(mask, "", self.updates_tile_vconv_ub[dest_addr], self.updates_tile_ub[src_addr2], repeat_times, 8, 4) compute_repeat_strid = 8 src1_ub = self.var_tile_vconv_ub src2_ub = self.updates_tile_vconv_ub dst_ub = self.var_tile_vconv_ub mask = self.var_data_each_block else: self.tik_instance.vec_conv(mask, "", self.var_vconv_ub[dest_addr], self.var_ub[src_addr1], repeat_times, 8, 4) self.tik_instance.vec_conv(mask, "", self.updates_vconv_ub[dest_addr], self.updates_ub[src_addr2], repeat_times, 8, 4) compute_repeat_strid = 8 src1_ub = self.var_vconv_ub[src_addr1] src2_ub = self.updates_vconv_ub[src_addr2] dst_ub = self.var_vconv_ub[dest_addr] else: if is_tile: compute_repeat_strid = ( self.max_num_one_repeat // self.var_data_each_block) src1_ub = self.var_tile_ub src2_ub = self.updates_tile_ub dst_ub = self.var_tile_ub mask = self.var_data_each_block else: compute_repeat_strid = ( self.max_num_one_repeat // self.var_data_each_block) src1_ub = self.var_ub[src_addr1] src2_ub = self.updates_ub[src_addr2] dst_ub = self.var_ub[dest_addr] if self.compute_type == "vadd": self.tik_instance.vec_add(mask, dst_ub, src1_ub, src2_ub, repeat_times, compute_repeat_strid, compute_repeat_strid, compute_repeat_strid) elif self.compute_type == "vsub": self.tik_instance.vec_sub(mask, dst_ub, src1_ub, src2_ub, repeat_times, compute_repeat_strid, compute_repeat_strid, compute_repeat_strid) else: raise RuntimeError("the operator [%s] is not supported" % self.compute_type) if self.var_dtype in need_vconv_dtype: if is_tile: self.tik_instance.vec_conv(mask, "", self.var_tile_ub, self.var_tile_vconv_ub, repeat_times, 4, 8) else: self.tik_instance.vec_conv(mask, "", self.var_ub[src_addr1], self.var_vconv_ub[dest_addr], repeat_times, 4, 8)
- traversing_indices函数定义。
算子适配插件实现
将原始Tensorflow的ScatterNdAdd算子或者ResourceScatterNdAdd算子解析并映射为适配昇腾AI处理器的ScatterNdAdd算子,算子属性的映射可直接调用AutoMappingFn( )接口进行实现,完整代码可参考sample样例中的“framework/tf_plugin/scatter_nd_add_plugin.cpp”文件。
算子原型定义
ScatterNdAdd算子的原型定义详细代码可参见“op_proto/scatter_nd_add.h”与“op_proto/scatter_nd_add.cpp”文件。
scatter_nd_add.h对ScatterNdAdd算子进行原型定义。
#ifndef GE_OP_ARG_MAX_H #define GE_OP_ARG_MAX_H #include "graph/operator_reg.h" namespace ge { REG_OP(ScatterNdAdd) .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) .INPUT(indices, TensorType::IndexNumberType()) .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ScatterNdAdd) #endif // GE_OP_ARG_MAX_H
IndexNumberType()的数据类型定义请参见“inc/graph/types.h”文件,此文件中定义了所有GE使用的数据类型。
scatter_nd_add.cpp对算子基本类型进行校验并推理算子的输出shape。
由于输入tensor var与updates的数据类型要求相同,所以需要对其进行校验:
IMPLEMT_VERIFIER(ScatterNdAdd, ScatterNdAddVerify) { if (!CheckTwoInputDtypeSame(op, "var", "updates")) { return GRAPH_FAILED; } return GRAPH_SUCCESS; }
将输入tensor var的shape与数据类型更新到输出tensor。
IMPLEMT_COMMON_INFERFUNC(ScatterNdAddInferShape) { Shape var_shape = op.GetInputDesc("var").GetShape(); DataType input_dtype = op.GetInputDesc("var").GetDataType(); TensorDesc td = op.GetOutputDesc("var"); td.SetShape(ge::Shape(var_shape)); td.SetDataType(input_dtype); (void)op.UpdateOutputDesc("var", td); return GRAPH_SUCCESS; }
算子信息定义
ScatterNdAdd算子的信息定义文件请参见“tbe/op_info_cfg/ai_core/<soc_version>/scatter_nd_add.ini”,由于信息定义中未配置算子实现代码的Python文件的名字opFile.vaule以及算子定义函数的名字opInterface.vaule,所以FE默认按照将算子类型中的大写字符转换为下划线加小写字符的形式去匹配算子实现文件与算子定义函数名字,匹配规则可参见4。