IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

scatter

def scatter[dtype: DType, //, ngpus: Int, dp_size: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], pdl_level: PDLLevel = PDLLevel()](input_buffers: InlineArray[TileTensor[dtype, in_layout, in_origin], dp_size], output_buffer: TileTensor[dtype, Storage=output_buffer.Storage, address_space=output_buffer.address_space, linear_idx_type=output_buffer.linear_idx_type, element_size=output_buffer.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], Int(8)], ctx: DeviceContext)

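Pull-based scatter plus broadcast.

Each GPU reads its replica's chunk from the root GPU over peer-to-peer (P2P) memory access. The receiving GPUs drive the transfer; the root does not push the data. All participating GPUs must call this function.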
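The data movement described above can be modeled on the host as a small sketch. This is not the MAX API and does not call `scatter`. It is a plain-Python illustration that rests on these assumptions:

- The root GPU holds all `dp_size` input chunks.
- `ngpus` is divisible by `dp_size`.
- GPUs are assigned to replicas in contiguous blocks of `ngpus // dp_size`.

The helper names `replica_of` and `pull_scatter` are hypothetical.

```python
# Conceptual host-side model of a pull-based scatter plus broadcast.
# ASSUMPTIONS (not taken from the MAX API): the root GPU holds all dp_size
# input chunks, ngpus is divisible by dp_size, and GPUs are grouped into
# contiguous blocks of ngpus // dp_size GPUs per replica.


def replica_of(gpu: int, ngpus: int, dp_size: int) -> int:
    """Return the data-parallel replica a GPU belongs to (assumed contiguous grouping)."""
    if ngpus % dp_size != 0:
        raise ValueError("ngpus must be divisible by dp_size in this model")
    gpus_per_replica = ngpus // dp_size
    return gpu // gpus_per_replica


def pull_scatter(root_chunks: list, ngpus: int) -> list:
    """Each GPU pulls a private copy of its replica's chunk from the root's buffers."""
    dp_size = len(root_chunks)
    outputs = []
    for gpu in range(ngpus):
        # Every GPU performs its own read (a "pull"); the root never pushes.
        chunk = root_chunks[replica_of(gpu, ngpus, dp_size)]
        # Copy the chunk, standing in for a P2P read into the GPU's local output buffer.
        outputs.append(list(chunk))
    return outputs


chunks = [[1.0, 2.0], [3.0, 4.0]]  # dp_size = 2 chunks held by the root GPU
print(pull_scatter(chunks, ngpus=4))
# → [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
```

In this model, GPUs in the same replica receive identical copies (the broadcast). Different replicas receive different chunks (the scatter).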

Parameters:

  • dtype (DType): Data type of the tensor elements.
  • ngpus (Int): Number of GPUs participating.
  • dp_size (Int): Number of data-parallel replicas.
  • in_layout (TensorLayout): Layout of the input TileTensors.
  • in_origin (Origin[mut=in_origin.mut]): Origin of the input TileTensors.
  • pdl_level (PDLLevel): Controls Programmatic Dependent Launch (PDL) behavior for the P2P kernels.

Args: