Mojo function
scatter_pull_kernel
scatter_pull_kernel[dtype: DType, BLOCK_SIZE: Int, ngpus: Int, tp_size: Int, dp_size: Int, simd_width: Int = simd_width_of[dtype, get_gpu_target()](), pdl_level: PDLLevel = PDLLevel()](output_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin], input_ptrs: InlineArray[UnsafePointer[Scalar[dtype], ImmutAnyOrigin], dp_size], chunk_num_elems: InlineArray[Int, dp_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], my_rank: Int)
Pull-based scatter+broadcast: each GPU reads its chunk from root.
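Each GPU computes its replica index as my_rank // tp_size, then copies the chunk at input_ptrs[replica] on the root GPU into its own output buffer. Because of the integer division, every run of tp_size consecutive ranks shares one replica index and reads the same chunk (the broadcast part), while each replica index selects its own entry of input_ptrs (the scatter part).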
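The rank-to-chunk mapping described above can be sketched on the host in plain Python. This is only a model of the index arithmetic, not the Mojo kernel: the helper names `replica_index` and `simulate_scatter_pull` are hypothetical, Python lists stand in for device buffers, and the sketch assumes the number of GPUs equals tp_size * dp_size, which the signature does not state.

```python
def replica_index(my_rank: int, tp_size: int) -> int:
    """Replica index used to select the chunk on root: my_rank // tp_size."""
    return my_rank // tp_size


def simulate_scatter_pull(input_chunks, tp_size):
    """Model each rank pulling its replica's chunk from the root.

    input_chunks[r] stands in for the data behind input_ptrs[r]; its length
    plays the role of chunk_num_elems[r]. Assumes tp_size * dp_size ranks.
    """
    dp_size = len(input_chunks)
    ngpus = tp_size * dp_size  # assumption, not taken from the source
    outputs = []
    for my_rank in range(ngpus):
        replica = replica_index(my_rank, tp_size)
        # Each rank copies the whole chunk of its replica into its own buffer.
        outputs.append(list(input_chunks[replica]))
    return outputs


# Two replicas (dp_size = 2), two ranks per replica (tp_size = 2).
chunks = [[1, 2, 3], [10, 20]]
outputs = simulate_scatter_pull(chunks, tp_size=2)
for rank, buf in enumerate(outputs):
    print(rank, replica_index(rank, 2), buf)
```

In this model ranks 0 and 1 end up with the first chunk and ranks 2 and 3 with the second, which matches the scatter across replicas and broadcast within each group of tp_size ranks.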