
Mojo function

broadcast_pull_2stage_kernel

broadcast_pull_2stage_kernel[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, pdl_level: PDLLevel = PDLLevel()](result: NDBuffer[dtype, rank, MutAnyOrigin], root_input_ptr: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], num_elements: Int, my_rank: Int, root: Int)

Two-stage broadcast: scatter from root, then allgather among all GPUs.

Stage 1 (Scatter): Root's data is split into ngpus chunks. Each GPU reads its assigned chunk directly from root's input buffer and writes it to its signal payload. Non-root GPUs also write to their result buffer. Root copies all N elements from source to dest (local operation).
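The chunk assignment in Stage 1 can be pictured with a small host-side sketch. The even-split-with-remainder scheme below (first `rem` GPUs take one extra element) is an illustrative assumption, not necessarily the kernel's exact partitioning; `chunk_bounds` is a hypothetical helper.

```python
# Sketch of how Stage 1 might partition the root's buffer into
# per-GPU chunks. The exact remainder handling is an assumption
# made for illustration.
def chunk_bounds(num_elements: int, ngpus: int, gpu_rank: int):
    """Return the [start, end) element range assigned to gpu_rank."""
    base = num_elements // ngpus
    rem = num_elements % ngpus
    # One possible scheme: the first `rem` GPUs each take one extra element.
    start = gpu_rank * base + min(gpu_rank, rem)
    end = start + base + (1 if gpu_rank < rem else 0)
    return start, end

# 10 elements over 4 GPUs -> ranges covering the buffer without gaps
print([chunk_bounds(10, 4, r) for r in range(4)])
```

Each GPU then reads only its own `[start, end)` range from the root's input buffer, so the scatter involves no redundant cross-GPU traffic.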

Stage 2 (Allgather): Non-root GPUs gather the remaining chunks from all other GPUs' signal payloads (including root's). Root skips this stage since it already has all data.
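The two stages above can be simulated on the host to check the data flow. This is a minimal sketch, assuming `num_elements` divides evenly by `ngpus`; per-GPU signal payloads are modeled as plain lists, and names like `payloads` and `results` are illustrative rather than the kernel's actual API.

```python
# Host-side simulation of the two-stage (scatter + allgather) broadcast.
def broadcast_pull_2stage(root_input, ngpus, root):
    n = len(root_input)
    chunk = n // ngpus                      # assumes n is divisible by ngpus
    payloads = [[None] * n for _ in range(ngpus)]
    results = [[None] * n for _ in range(ngpus)]

    # Stage 1 (scatter): each GPU pulls its assigned chunk from the
    # root's input and publishes it in its own signal payload.
    for g in range(ngpus):
        lo, hi = g * chunk, (g + 1) * chunk
        payloads[g][lo:hi] = root_input[lo:hi]
        if g != root:
            results[g][lo:hi] = root_input[lo:hi]
    # Root already has all the data: a purely local copy.
    results[root] = list(root_input)

    # Stage 2 (allgather): non-root GPUs pull the remaining chunks
    # from every other GPU's payload (including root's).
    for g in range(ngpus):
        if g == root:
            continue                        # root skips this stage
        for src in range(ngpus):
            if src == g:
                continue
            lo, hi = src * chunk, (src + 1) * chunk
            results[g][lo:hi] = payloads[src][lo:hi]
    return results

data = list(range(8))
out = broadcast_pull_2stage(data, ngpus=4, root=0)
assert all(r == data for r in out)
```

The pull-based structure means each non-root GPU issues reads against peers' payloads rather than waiting for pushed writes, which is why the signal payload buffers in `rank_sigs` must remain valid for both stages.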

Parameters:

  • dtype (DType): Data type of tensor elements.
  • rank (Int): Number of dimensions in tensors.
  • ngpus (Int): Number of GPUs participating.
  • BLOCK_SIZE (Int): Number of threads per block.
  • pdl_level (PDLLevel): Controls PDL (programmatic dependent launch) behavior for the kernel.

Args:

  • result (NDBuffer): Output buffer for broadcast result.
  • root_input_ptr (UnsafePointer): Pointer to root's input data (all GPUs read from this).
  • rank_sigs (InlineArray): Signal pointers for synchronization. IMPORTANT: Signal pointers have trailing buffers for communication.
  • num_elements (Int): Number of elements to broadcast.
  • my_rank (Int): Current GPU rank.
  • root (Int): Root GPU rank (source of broadcast).
