Mojo function
scatter
scatter[dtype: DType, //, ngpus: Int, dp_size: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], pdl_level: PDLLevel = PDLLevel()](input_buffers: InlineArray[TileTensor[dtype, in_layout, in_origin], dp_size], output_buffer: TileTensor[dtype, output_buffer.LayoutType, output_buffer.origin, address_space=output_buffer.address_space, linear_idx_type=output_buffer.linear_idx_type, element_size=output_buffer.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext)
Pull-based scatter+broadcast.
Each GPU reads its replica's chunk from the root GPU via P2P. All GPUs must call this function.
Parameters:
- dtype (DType): Data type of the tensor elements.
- ngpus (Int): Number of GPUs participating.
- dp_size (Int): Number of data-parallel replicas.
- in_layout (TensorLayout): Layout of the input TileTensors.
- in_origin (Origin): Origin of the input TileTensors.
- pdl_level (PDLLevel): Controls PDL behavior for P2P kernels.
Args:
- input_buffers (InlineArray): Input buffers (one per DP replica) as TileTensors.
- output_buffer (TileTensor): Output buffer for THIS GPU as a TileTensor.
- rank_sigs (InlineArray): Per-GPU Signal pointers.
- ctx (DeviceContext): Device context for THIS GPU.
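
Below is a minimal calling sketch, not an example from the library itself: the names `replica_inputs`, `my_output`, `signals`, and `my_ctx` are hypothetical and assumed to have been created during collective setup, and the fragment is assumed to sit inside a per-GPU launch function that `raises`. Only the shape of the call follows the signature above; `dtype`, `in_layout`, and `in_origin` are left to be inferred from the arguments.

```mojo
alias NUM_GPUS = 4  # total GPUs participating in the collective (assumed)
alias DP_SIZE = 2   # number of data-parallel replicas (assumed)

# Every GPU must issue this call; each rank pulls its replica's chunk from
# the root GPU over P2P, synchronizing through the shared Signal pointers.
scatter[ngpus=NUM_GPUS, dp_size=DP_SIZE](
    replica_inputs,  # InlineArray of DP_SIZE input TileTensors (root GPU data)
    my_output,       # output TileTensor for THIS GPU
    signals,         # InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8]
    my_ctx,          # DeviceContext for THIS GPU
)
```

In a typical launcher, the same call is repeated once per rank, each time with that rank's output buffer and DeviceContext, so that every GPU participates in the pull.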