
Mojo struct

DistributedBroadcast

struct DistributedBroadcast

Distributed broadcast: copies a tensor from the root GPU to all participating GPUs.

A single instance of this op handles all participating GPUs. It receives:

  • input: The source tensor from the root GPU (P2P accessible)
  • outputs: Destination tensors, one per GPU
  • signal_buffers: Synchronization buffers for all participating GPUs
  • dev_ctxs_input: Device contexts for all participating GPUs
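The data movement described above can be modeled on the host as a plain copy from one source buffer into every destination buffer. The following is a minimal Python sketch of those semantics only, not the Mojo kernel: `broadcast_reference`, `root_tensor`, and `per_gpu_outputs` are hypothetical names, Python lists stand in for device tensors, and the signal buffers and device contexts are left out.

```python
from typing import List, Sequence


def broadcast_reference(
    root_input: Sequence[float],
    outputs: List[List[float]],
) -> None:
    """Copy the root GPU's tensor into every per-GPU output buffer in place.

    Conceptual model only: lists stand in for device tensors, and no
    cross-GPU synchronization is simulated.
    """
    for gpu_index, out in enumerate(outputs):
        if len(out) != len(root_input):
            raise ValueError(
                f"output {gpu_index} has {len(out)} elements, "
                f"expected {len(root_input)}"
            )
        # Slice assignment overwrites the existing buffer instead of rebinding it.
        out[:] = root_input


# Hypothetical example: 4 GPUs, each with a 3-element output buffer.
root_tensor = [1.0, 2.0, 3.0]
per_gpu_outputs = [[0.0] * 3 for _ in range(4)]
broadcast_reference(root_tensor, per_gpu_outputs)
```

After the call, every entry of `per_gpu_outputs` holds the same values as `root_tensor`, which is the postcondition a broadcast is expected to satisfy.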

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[dtype: DType, rank: Int, root: Int, target: StringSlice[StaticConstantOrigin], _trace_name: StringSlice[StaticConstantOrigin]](outputs: VariadicTensors[Output, static_specs=outputs.static_specs], input: ManagedTensorSlice[Input, static_spec=input.static_spec], signal_buffers: VariadicTensors[MutableInput, static_specs=signal_buffers.static_specs], dev_ctxs_input: DeviceContextList)

Execute distributed broadcast operation.

Limitations:

  • Maximum of 8 GPUs supported (MAX_GPUS).
  • Requires P2P access between GPUs (NVLink or PCIe P2P).
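The limitations above can be expressed as preconditions to check before launching a broadcast. This is a hedged Python sketch, not part of the Mojo API: `check_broadcast_preconditions` and the `can_access` matrix are hypothetical, the value `MAX_GPUS = 8` is taken from the limit stated above, and the requirements that `root` is a valid GPU index and that each GPU can reach the root's memory are assumptions about how the P2P condition applies.

```python
from typing import Sequence

MAX_GPUS = 8  # Limit stated in the documentation above.


def check_broadcast_preconditions(
    num_gpus: int,
    root: int,
    can_access: Sequence[Sequence[bool]],
) -> None:
    """Raise if the stated broadcast limitations are not met.

    can_access[i][j] is a hypothetical flag meaning GPU i can read
    GPU j's memory through P2P.
    """
    if not 1 <= num_gpus <= MAX_GPUS:
        raise ValueError(f"num_gpus must be in [1, {MAX_GPUS}], got {num_gpus}")
    # Assumption: root indexes into the participating GPUs.
    if not 0 <= root < num_gpus:
        raise ValueError(f"root must be in [0, {num_gpus}), got {root}")
    for gpu in range(num_gpus):
        # Every non-root GPU must be able to read the root's tensor.
        if gpu != root and not can_access[gpu][root]:
            raise RuntimeError(f"GPU {gpu} has no P2P access to root GPU {root}")


# Hypothetical example: 4 fully connected GPUs, root is GPU 0.
fully_connected = [[True] * 4 for _ in range(4)]
check_broadcast_preconditions(4, 0, fully_connected)
```

Checking these conditions up front turns a misconfigured topology into an immediate, descriptive error rather than a failure inside the kernel.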

Parameters:

  • dtype (DType): Data type of the tensor.
  • rank (Int): Tensor rank (number of dimensions).
  • root (Int): Index of the root GPU (source of data).
  • target (StringSlice[StaticConstantOrigin]): Target device string for tracing.
  • _trace_name (StringSlice[StaticConstantOrigin]): Trace name for profiling.

Args:

  • outputs (VariadicTensors[Output, static_specs=outputs.static_specs]): Output tensors (one per GPU) to store broadcast results.
  • input (ManagedTensorSlice[Input, static_spec=input.static_spec]): Input tensor from root GPU (P2P accessible from all GPUs).
  • signal_buffers (VariadicTensors[MutableInput, static_specs=signal_buffers.static_specs]): Synchronization buffers for cross-GPU coordination.
  • dev_ctxs_input (DeviceContextList): Device contexts for participating GPUs.