Computing Tensor Nodes for Collective Communication
You can call the operator prototype APIs to build the tensor nodes involved in collective communication, as shown in the following examples.
AllReduce
#---------------------AllReduce test (two devices)---------------------------------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops
tensor = tf.random_uniform((1, 3), minval=1, maxval=10, dtype=tf.float32)
allreduce_test = hccl_ops.allreduce(tensor, "sum")
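The snippets in this section only build graph nodes; to actually compute a result, each device runs the node in a session. The following is a minimal sketch, assuming the standard NpuOptimizer session configuration of the TF Adapter; the configuration details of your training job may differ.
#---------------- Evaluating the AllReduce node (sketch, two devices) --------------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops

tensor = tf.random_uniform((1, 3), minval=1, maxval=10, dtype=tf.float32)
allreduce_test = hccl_ops.allreduce(tensor, "sum")

config = tf.ConfigProto()
custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
custom_op.parameter_map["use_off_line"].b = True   # execute on the Ascend device

with tf.Session(config=config) as sess:
    print(sess.run(allreduce_test))   # each rank obtains the element-wise sum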
AllGather
#---------------------AllGather test (two devices)---------------------------------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops
cCon = tf.constant([1.0, 2.0, 3.0])
allgather_test = hccl_ops.allgather(cCon, 2)
#---------- rank 0/1: allgather_test = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0] ----------
Broadcast
#---------------------Broadcast test (two devices)---------------------------------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops
cCon = tf.Variable([1.0, 2.0, 3.0])
input = [cCon]
broadcast_test = hccl_ops.broadcast(input, 0)
#---------------- rank 0/1: broadcast_test = [1.0, 2.0, 3.0] --------------------
ReduceScatter
#---------------------ReduceScatter test (two devices)-----------------------------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops
cCon = tf.constant([1.0, 2.0, 3.0, 4.0])
# Both ranks hold [1, 2, 3, 4]; the element-wise sum [2, 4, 6, 8] is split across the two ranks.
reducescatter_test = hccl_ops.reduce_scatter(cCon, "sum", 2)
#----------------- rank 0: reducescatter_test = [2.0, 4.0] ----------------------
#----------------- rank 1: reducescatter_test = [6.0, 8.0] ----------------------
Send
#---------------------------------Send test----------------------------------------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops
tensor = tf.constant([1.0, 2.0, 3.0])   # example tensor to send
sr_tag = 0        # tag that pairs this send with the matching receive
dest_rank = 1     # rank that receives the tensor
hccl_ops.send(tensor, sr_tag, dest_rank)
Receive
#---------------------Receive test (two devices)-----------------------------------
from npu_bridge.hccl import hccl_ops
sr_tag = 0       # must match the tag used by the sender
src_rank = 0     # rank that sent the tensor
# The shape and dtype passed to receive must match those of the tensor sent by src_rank
# ('tensor' here is defined as in the Send example above).
tensor = hccl_ops.receive(tensor.shape, tensor.dtype, sr_tag, src_rank)
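Send and receive are point-to-point operations and must be paired across ranks with the same sr_tag. Below is a minimal pairing sketch for a two-device job; how the local rank ID is obtained (rank_id below) is an assumption and depends on how the training job is launched.
#-------------- Send/Receive pairing sketch (two devices, assumptions noted) -------
import tensorflow as tf
from npu_bridge.hccl import hccl_ops

sr_tag = 0                                   # tag must be identical on both sides
payload = tf.constant([1.0, 2.0, 3.0])       # example tensor exchanged between ranks

rank_id = 0   # assumption: the local rank ID comes from the job's rank configuration
if rank_id == 0:
    # rank 0 sends the tensor to rank 1
    send_op = hccl_ops.send(payload, sr_tag, 1)
else:
    # rank 1 receives a tensor with the same shape/dtype from rank 0
    received = hccl_ops.receive(payload.shape, payload.dtype, sr_tag, 0)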