Mojo function
broadcast_2stage
```mojo
broadcast_2stage[dtype: DType, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //, ngpus: Int, pdl_level: PDLLevel = PDLLevel()](input_tensor: TileTensor[dtype, in_layout, in_origin], output_tensor: TileTensor[dtype, in_layout, output_tensor.origin], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext, root: Int, _max_num_blocks: Optional[Int] = None)
```
Two-stage broadcast: scatter from root, then allgather among all GPUs.
Note: This path is only used with 3+ GPUs. With 2 GPUs, broadcast uses the simpler 1-stage path for better performance.
This algorithm achieves better bandwidth than simple pull broadcast by:
- Stage 1 (Scatter): Each GPU reads 1/ngpus of the data from root and writes to its payload buffer, utilizing root's outbound GPU link bandwidth.
- Stage 2 (Allgather): All GPUs gather from each other in parallel, with each GPU reading (ngpus-1) chunks from other GPUs' payloads.
All GPUs (including root) participate uniformly in both stages, which better utilizes root's GPU link bandwidth and simplifies partitioning.
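The two stages above can be sketched on the host with each GPU modeled as a plain Python list. This is a hypothetical simulation of the data movement only (the function name and structure here are illustrative, not the actual Mojo kernel): stage 1 stages one 1/ngpus chunk per GPU, stage 2 concatenates all staged chunks on every GPU.

```python
def broadcast_2stage_sim(root_input, ngpus):
    """Simulate the two-stage broadcast; each GPU is a Python list."""
    n = len(root_input)
    chunk = n // ngpus  # assume n is divisible by ngpus for simplicity

    # Stage 1 (scatter): GPU r reads its 1/ngpus slice from root's input
    # and stages it in its own payload buffer.
    payloads = [root_input[r * chunk:(r + 1) * chunk] for r in range(ngpus)]

    # Stage 2 (allgather): every GPU assembles the full output by reading
    # each peer's staged payload (including its own), in rank order.
    outputs = [
        [x for r in range(ngpus) for x in payloads[r]]
        for _ in range(ngpus)
    ]
    return outputs
```

Because every GPU (root included) stages exactly one chunk and then reads ngpus chunks, the per-link traffic is balanced rather than concentrated on root's outbound links.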
IMPORTANT: Signal buffers must be sized to hold at least `size_of(Signal) + (num_elements / ngpus) * size_of(dtype)`. This is the payload space needed for each GPU's chunk.
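The sizing formula above can be checked with simple arithmetic. The helper and the concrete byte sizes below are illustrative assumptions (the real `Signal` struct size depends on the target), not values from the library:

```python
def min_signal_buffer_bytes(num_elements, ngpus, dtype_bytes, signal_bytes):
    # size_of(Signal) + (num_elements / ngpus) * size_of(dtype)
    return signal_bytes + (num_elements // ngpus) * dtype_bytes

# Example: 1024 float16 elements (2 bytes each) across 4 GPUs, with a
# hypothetical 128-byte Signal header: 128 + 256 * 2 = 640 bytes per buffer.
required = min_signal_buffer_bytes(1024, 4, 2, 128)
```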
Parameters:
- dtype (DType): Data dtype of tensor elements.
- in_layout (TensorLayout): Layout of the input TileTensor.
- in_origin (Origin): Origin of the input TileTensor.
- ngpus (Int): Number of GPUs participating.
- pdl_level (PDLLevel): Control PDL behavior for the kernel.
Args:
- input_tensor (TileTensor): Input tensor (only root's is read, but all must be valid).
- output_tensor (TileTensor): Output tensor for THIS GPU.
- rank_sigs (InlineArray): Signal pointers with payload space for staging.
- ctx (DeviceContext): Device context for THIS GPU.
- root (Int): Root GPU rank (source of broadcast data).
- _max_num_blocks (Optional): Optional maximum number of thread blocks.