Mojo function

scatter

```mojo
scatter[
    dtype: DType, //,
    ngpus: Int,
    dp_size: Int,
    in_layout: TensorLayout,
    in_origin: Origin[mut=in_origin.mut],
    pdl_level: PDLLevel = PDLLevel(),
](
    input_buffers: InlineArray[TileTensor[dtype, in_layout, in_origin], dp_size],
    output_buffer: TileTensor[
        dtype,
        output_buffer.LayoutType,
        output_buffer.origin,
        address_space=output_buffer.address_space,
        linear_idx_type=output_buffer.linear_idx_type,
        element_size=output_buffer.element_size,
    ],
    rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8],
    ctx: DeviceContext,
)
```

Pull-based scatter+broadcast.

Each GPU reads its replica's chunk from the root GPU via P2P. All GPUs must call this function.

Parameters:

  • dtype (DType): Data type of the tensor elements.
  • ngpus (Int): Number of GPUs participating.
  • dp_size (Int): Number of data-parallel replicas.
  • in_layout (TensorLayout): Layout of the input TileTensors.
  • in_origin (Origin): Origin of the input TileTensors.
  • pdl_level (PDLLevel): Controls PDL behavior for P2P kernels.

Args:

  • input_buffers (InlineArray): Input buffers (one per DP replica) as TileTensors.
  • output_buffer (TileTensor): Output buffer for THIS GPU as a TileTensor.
  • rank_sigs (InlineArray): Per-GPU Signal pointers used for cross-device synchronization; the array is fixed at 8 slots regardless of ngpus.
  • ctx (DeviceContext): Device context for THIS GPU.
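
The following is a minimal call-site sketch. The import paths (`gpu.comm`, `layout`), the driver-style per-rank loop, and the pre-built `input_buffers`, `output_buffers`, and `rank_sigs` arrays are assumptions for illustration, not confirmed API; in particular, the real output layout would typically describe one chunk rather than matching the input layout.

```mojo
from gpu.host import DeviceContext
from layout import TensorLayout          # assumed import path
from gpu.comm import Signal, scatter     # assumed import paths

alias dtype = DType.bfloat16
alias ngpus = 4    # GPUs participating in the collective
alias dp_size = 2  # data-parallel replicas

fn scatter_all_ranks[
    in_layout: TensorLayout
](
    # One input buffer per DP replica, resident on the root GPU.
    input_buffers: InlineArray[
        TileTensor[dtype, in_layout, MutAnyOrigin], dp_size
    ],
    # One output buffer per GPU. For simplicity this sketch reuses
    # in_layout; a real chunk layout would differ.
    output_buffers: InlineArray[
        TileTensor[dtype, in_layout, MutAnyOrigin], ngpus
    ],
    # 8 Signal pointer slots, as required by the signature, even when
    # ngpus < 8.
    rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8],
    # One DeviceContext per participating GPU.
    ctxs: List[DeviceContext],
) raises:
    # All GPUs must call scatter: each rank pulls its replica's chunk
    # from the root GPU over P2P, synchronizing via rank_sigs.
    # dtype, in_layout, and in_origin are inferred from the arguments.
    for rank in range(ngpus):
        scatter[ngpus=ngpus, dp_size=dp_size](
            input_buffers, output_buffers[rank], rank_sigs, ctxs[rank]
        )
```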
