
Mojo function

broadcast_2stage

broadcast_2stage[dtype: DType, //, ngpus: Int, pdl_level: PDLLevel = PDLLevel()](input_buffer: TileTensor[dtype, input_buffer.LayoutType, input_buffer.origin, address_space=input_buffer.address_space, linear_idx_type=input_buffer.linear_idx_type, element_size=input_buffer.element_size], output_buffer: TileTensor[dtype, output_buffer.LayoutType, output_buffer.origin, address_space=output_buffer.address_space, linear_idx_type=output_buffer.linear_idx_type, element_size=output_buffer.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, root: Int, _max_num_blocks: Optional[Int] = None)

Two-stage broadcast: scatter from root, then allgather among all GPUs.

Note: This path is only used with 3+ GPUs. With 2 GPUs, broadcast uses the simpler 1-stage path for better performance.

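This algorithm achieves higher bandwidth than a simple pull broadcast, in which every non-root GPU reads the entire buffer from root, by splitting the transfer into two stages: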

  1. Stage 1 (Scatter): Each GPU reads its 1/ngpus chunk of the data from root and writes it to its own payload buffer, so that root's outbound NVLink bandwidth is shared by all GPUs reading concurrently.
  2. Stage 2 (Allgather): All GPUs gather from each other in parallel; each GPU reads the (ngpus-1) chunks it does not hold from the other GPUs' payload buffers.
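The two stages can be simulated on the host with plain Python lists to show which chunk each GPU stages and where each GPU's output comes from. This is a minimal sketch, not the Mojo kernel: "GPUs" are list indices, and the chunk partitioning (ceil-sized chunks, the last one possibly shorter) is an assumption, since this page does not say how a remainder is handled. The helper names `chunk_range` and `broadcast_2stage_sim` are hypothetical.

```python
def chunk_range(num_elements, ngpus, rank):
    """Half-open [start, end) range of the chunk owned by `rank`.

    Assumption: chunks are ceil(num_elements / ngpus) long and the last
    one may be shorter; the real kernel's partitioning may differ.
    """
    chunk = -(-num_elements // ngpus)  # ceiling division
    start = min(rank * chunk, num_elements)
    end = min(start + chunk, num_elements)
    return start, end


def broadcast_2stage_sim(inputs, root):
    """Simulate the scatter + allgather data movement across len(inputs) GPUs."""
    ngpus = len(inputs)
    num_elements = len(inputs[root])

    # Stage 1 (scatter): every rank, root included, copies its own chunk
    # of root's input into its payload buffer. Only inputs[root] is read.
    payloads = []
    for rank in range(ngpus):
        start, end = chunk_range(num_elements, ngpus, rank)
        payloads.append(list(inputs[root][start:end]))

    # Stage 2 (allgather): every rank assembles the full output from all
    # payloads; (ngpus - 1) of those chunks live on other ranks.
    outputs = []
    remote_reads = []
    for rank in range(ngpus):
        out = [None] * num_elements
        reads = 0
        for owner in range(ngpus):
            start, end = chunk_range(num_elements, ngpus, owner)
            out[start:end] = payloads[owner]
            if owner != rank:
                reads += 1  # chunk fetched from another GPU's payload
        outputs.append(out)
        remote_reads.append(reads)
    return outputs, remote_reads


# Usage: 4 simulated GPUs, root = 2; non-root inputs are ignored.
inputs = [[0] * 10 for _ in range(4)]
inputs[2] = list(range(10))
outputs, remote_reads = broadcast_2stage_sim(inputs, root=2)
```

Every simulated GPU ends up with root's data, and each one reads exactly ngpus - 1 chunks from its peers in stage 2.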

All GPUs (including root) participate uniformly in both stages, which better utilizes root's NVLink bandwidth and simplifies partitioning.
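A back-of-the-envelope count of elements leaving root illustrates why the two-stage path helps once there are 3 or more GPUs. This is a simplified traffic model derived from the description above, not a measurement: it assumes num_elements is divisible by ngpus, counts only reads served by root to other GPUs, and ignores root's local copies, synchronization cost, and topology. The function names are hypothetical.

```python
def root_egress_pull(num_elements, ngpus):
    # Simple pull broadcast: each of the (ngpus - 1) non-root GPUs
    # reads the whole buffer from root.
    return (ngpus - 1) * num_elements


def root_egress_2stage(num_elements, ngpus):
    # Assumes num_elements % ngpus == 0.
    chunk = num_elements // ngpus
    # Stage 1: each non-root GPU reads its chunk from root's input.
    stage1 = (ngpus - 1) * chunk
    # Stage 2: each non-root GPU reads root's staged chunk from its payload.
    stage2 = (ngpus - 1) * chunk
    return stage1 + stage2


for ngpus in (2, 4, 8):
    print(ngpus, root_egress_pull(1024, ngpus), root_egress_2stage(1024, ngpus))
```

In this model the two paths move the same amount of data out of root when ngpus = 2, while for larger ngpus root's egress stays below 2 * num_elements instead of growing as (ngpus - 1) * num_elements. That is consistent with the note that 2 GPUs use the 1-stage path.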

IMPORTANT: Each signal buffer must be at least size_of(Signal) + (num_elements / ngpus) * size_of(dtype) bytes, that is, room for the Signal struct plus payload space for one GPU's chunk.
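The sizing rule is simple arithmetic, so a small helper can compute the minimum allocation per GPU. This is a sketch under stated assumptions: `signal_size` stands in for size_of(Signal), whose actual value is not given on this page, and the helper name is hypothetical. It rounds the chunk length up when num_elements is not a multiple of ngpus. That covers a ceil-sized partition, but the kernel's actual remainder handling is not documented here and should be checked.

```python
def min_signal_buffer_bytes(num_elements, ngpus, dtype_size, signal_size):
    """Minimum bytes per signal buffer: Signal struct + one chunk of payload.

    signal_size is a placeholder for size_of(Signal); the value is assumed.
    """
    # Ceiling division so a non-divisible element count is not under-allocated
    # under a ceil-sized partition (assumption).
    chunk_elems = -(-num_elements // ngpus)
    return signal_size + chunk_elems * dtype_size


# Example: 4096 bfloat16 elements (2 bytes each) on 4 GPUs, with a
# hypothetical 512-byte Signal struct.
print(min_signal_buffer_bytes(4096, 4, 2, 512))
```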

Parameters:

  • dtype (DType): Data type of the tensor elements.
  • ngpus (Int): Number of GPUs participating.
  • pdl_level (PDLLevel): Controls Programmatic Dependent Launch (PDL) behavior for the kernel.

Args:

  • input_buffer (TileTensor): Input buffer. Only root's input buffer is read, but every GPU must pass a valid one.
  • output_buffer (TileTensor): Output buffer for THIS GPU.
  • rank_sigs (InlineArray): Signal pointers, one per GPU, each with the payload space used for staging (see the sizing requirement above).
  • ctx (DeviceContext): Device context for THIS GPU.
  • root (Int): Root GPU rank (source of broadcast data).
  • _max_num_blocks (Optional): Optional maximum number of thread blocks.
