Mojo function
broadcast_pull_2stage_kernel
broadcast_pull_2stage_kernel[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, pdl_level: PDLLevel = PDLLevel()](result: NDBuffer[dtype, rank, MutAnyOrigin], root_input_ptr: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], num_elements: Int, my_rank: Int, root: Int)
Two-stage broadcast: scatter from root, then allgather among all GPUs.
Stage 1 (Scatter): Root's data is split into ngpus chunks. Each GPU reads its assigned chunk directly from root's input buffer and writes it to its signal payload. Non-root GPUs also write to their result buffer. Root copies all N elements from source to dest (local operation).
Stage 2 (Allgather): Non-root GPUs gather the remaining chunks from all other GPUs' signal payloads (including root's). Root skips this stage since it already has all data.
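The two stages above can be sketched as a host-side simulation in Python. This is not the kernel itself, only an illustration of the data movement: the lists named payloads stand in for the Signal payload buffers, chunk_bounds is a hypothetical helper showing one even-split chunking scheme, and all synchronization is elided.

```python
def chunk_bounds(num_elements, ngpus, gpu):
    # Even split of num_elements into ngpus chunks; earlier GPUs
    # absorb the remainder. (Illustrative; the real kernel's
    # partitioning may differ.)
    base = num_elements // ngpus
    rem = num_elements % ngpus
    start = gpu * base + min(gpu, rem)
    end = start + base + (1 if gpu < rem else 0)
    return start, end

def broadcast_pull_2stage(root_input, ngpus, root):
    n = len(root_input)
    payloads = [[None] * n for _ in range(ngpus)]
    results = [[None] * n for _ in range(ngpus)]

    # Stage 1 (scatter): every GPU pulls its assigned chunk from
    # root's input and publishes it in its signal payload; non-root
    # GPUs also write the chunk into their result buffer, while
    # root copies all N elements locally.
    for gpu in range(ngpus):
        s, e = chunk_bounds(n, ngpus, gpu)
        payloads[gpu][s:e] = root_input[s:e]
        if gpu == root:
            results[gpu] = list(root_input)
        else:
            results[gpu][s:e] = root_input[s:e]

    # Stage 2 (allgather): non-root GPUs pull the remaining chunks
    # from every other GPU's payload; root skips this stage since
    # it already holds the full buffer.
    for gpu in range(ngpus):
        if gpu == root:
            continue
        for other in range(ngpus):
            if other == gpu:
                continue
            s, e = chunk_bounds(n, ngpus, other)
            results[gpu][s:e] = payloads[other][s:e]
    return results
```

Running the simulation with any root and GPU count leaves every GPU's result buffer equal to root's input, which is the invariant the kernel establishes.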
Parameters:
- dtype (DType): Data dtype of tensor elements.
- rank (Int): Number of dimensions in tensors.
- ngpus (Int): Number of GPUs participating.
- BLOCK_SIZE (Int): Number of threads per block.
- pdl_level (PDLLevel): Controls PDL behavior for the kernel.
Args:
- result (NDBuffer): Output buffer for the broadcast result.
- root_input_ptr (UnsafePointer): Pointer to root's input data (all GPUs read from this).
- rank_sigs (InlineArray): Signal pointers for synchronization. IMPORTANT: Signal pointers have trailing buffers for communication.
- num_elements (Int): Number of elements to broadcast.
- my_rank (Int): Current GPU rank.
- root (Int): Root GPU rank (source of the broadcast).