Mojo function

allreduce

allreduce[type: DType, rank: Int, ngpus: Int, outputs_lambda: fn[Int, DType, Int, Int, Int](Index[$2], SIMD[$1, $3]) capturing -> None](ctxs: List[DeviceContext], input_buffers: InlineArray[NDBuffer[type, rank], ngpus], output_buffers: InlineArray[NDBuffer[type, rank], ngpus], rank_sigs: InlineArray[UnsafePointer[Signal], 8], _max_num_blocks: Optional[Int] = Optional(None))

Performs an allreduce operation across multiple GPUs.

This function serves as the main entry point for performing allreduce operations across multiple GPUs. It automatically selects between two implementations:

  • A peer-to-peer (P2P) based implementation when P2P access is possible between GPUs
  • A naive implementation as a fallback when P2P access is not available

The allreduce operation combines values from all GPUs using element-wise addition and distributes the result back to all GPUs.
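The semantics described above can be illustrated with a small host-side reference in Python. This is only a sketch of the expected result, not the GPU implementation: the function name `allreduce_sum_reference` and the use of plain lists in place of per-GPU buffers are assumptions made for illustration.

```python
from typing import List


def allreduce_sum_reference(input_buffers: List[List[float]]) -> List[List[float]]:
    """Hypothetical host-side reference for a sum allreduce.

    Each inner list stands in for the flattened input buffer of one GPU.
    """
    if not input_buffers:
        raise ValueError("need at least one input buffer")
    length = len(input_buffers[0])
    if any(len(buf) != length for buf in input_buffers):
        raise ValueError("all input buffers must have the same number of elements")

    # Reduce: add the i-th element of every buffer together.
    reduced = [sum(column) for column in zip(*input_buffers)]

    # Distribute: every participant receives its own copy of the result.
    return [list(reduced) for _ in input_buffers]


outputs = allreduce_sum_reference([[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]])
print(outputs)  # [[111.0, 222.0], [111.0, 222.0], [111.0, 222.0]]
```

A reference like this can be handy for checking the output buffers of a real multi-GPU run against values computed on the host.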

Note:

  • Input and output buffers must have identical shapes across all GPUs.
  • The number of elements must be identical across all input/output buffers.
  • Performance is typically better with P2P access enabled between GPUs.

Parameters:

  • type (DType): The data type of the tensor elements (e.g. DType.float32).
  • rank (Int): The number of dimensions in the input/output tensors.
  • ngpus (Int): The number of GPUs participating in the allreduce.
  • outputs_lambda (fn[Int, DType, Int, Int, Int](Index[$2], SIMD[$1, $3]) capturing -> None): An output elementwise lambda.
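The signature of `outputs_lambda` suggests that it receives an index and a vector of reduced values and returns nothing, so the lambda itself is responsible for storing (and optionally transforming) the result. The Python sketch below models that pattern on the host. The exact calling contract is not described here, so the chunk width, the flat start index, and the names `reduce_with_epilogue` and `store_scaled` are all assumptions for illustration.

```python
from typing import Callable, List, Sequence

# Hypothetical callback type: flat start index of a chunk, plus its reduced values.
OutputFn = Callable[[int, Sequence[float]], None]


def reduce_with_epilogue(
    input_buffers: List[List[float]], width: int, output_fn: OutputFn
) -> None:
    """Sum the inputs chunk by chunk and hand each reduced chunk to output_fn."""
    n = len(input_buffers[0])
    for start in range(0, n, width):
        stop = min(start + width, n)
        chunk = [sum(buf[i] for buf in input_buffers) for i in range(start, stop)]
        # The callback decides where and how the reduced values are stored.
        output_fn(start, chunk)


output = [0.0] * 4


def store_scaled(start: int, values: Sequence[float]) -> None:
    # Example epilogue: scale each reduced value before writing it out.
    for offset, value in enumerate(values):
        output[start + offset] = 0.5 * value


reduce_with_epilogue([[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]], 2, store_scaled)
print(output)  # [2.0, 2.0, 2.0, 2.0]
```

Passing the store as a callback lets a caller fuse a follow-up elementwise operation into the reduction instead of making a second pass over the output.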

Args:

  • ctxs (List[DeviceContext]): List of device contexts for each participating GPU.
  • input_buffers (InlineArray[NDBuffer[type, rank], ngpus]): Array of input tensors from each GPU, one per GPU.
  • output_buffers (InlineArray[NDBuffer[type, rank], ngpus]): Array of output tensors for each GPU to store results.
  • rank_sigs (InlineArray[UnsafePointer[Signal], 8]): Array of Signal pointers used for cross-GPU synchronization.
  • _max_num_blocks (Optional[Int]): Optional maximum number of blocks used to compute the grid configuration. If not passed, a dispatch table sets the grid configuration.