
Mojo function

shard_and_stack

shard_and_stack[axis: Int](outputs: VariadicTensors[outputs.dtype, outputs.rank, outputs.static_specs.size, Output, static_specs=outputs.static_specs], inputs: VariadicTensors[outputs.dtype, (outputs.rank - 1), inputs.static_specs.size, Input, static_specs=inputs.static_specs], dev_ctxs_input: DeviceContextPtrList[dev_ctxs_input.size])

Shard weight tensors across multiple devices for tensor parallelism.

This operation takes multiple input tensors with identical shapes, shards each along the specified axis, and distributes the shards to different devices (typically GPUs in tensor-parallel inference).

Parameters:

  • axis (Int): The dimension along which to shard the weights.

Args:

  • outputs (VariadicTensors): Output tensors, one per device/shard. Per the signature, each output has one more dimension than the inputs, gained by stacking the shards.
  • inputs (VariadicTensors): Input tensors to be sharded; all must have identical shapes.
  • dev_ctxs_input (DeviceContextPtrList): Device contexts used for the multi-device transfers.
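To make the semantics concrete, here is a hypothetical, host-only Python sketch of one plausible reading of the signature: each input is split along the shard axis, and for each device the corresponding shard of every input is stacked along a new leading dimension (which is why the outputs carry one more rank than the inputs). The sketch is restricted to 2-D inputs sharded along axis 0 and is not the actual device-aware Mojo implementation.

```python
def shard_and_stack(inputs, num_devices):
    # Hypothetical pure-Python model of the operation's semantics,
    # restricted to 2-D inputs (lists of rows) sharded along axis 0.
    rows = len(inputs[0])
    assert all(len(t) == rows for t in inputs), "inputs must share a shape"
    assert rows % num_devices == 0, "shard axis must divide evenly"
    per = rows // num_devices
    outputs = []
    for d in range(num_devices):
        # Each device receives the d-th shard of every input, stacked
        # along a new leading dimension (hence outputs gain one rank).
        shard = [t[d * per:(d + 1) * per] for t in inputs]
        outputs.append(shard)
    return outputs

# Two 4x2 "weight" tensors, sharded across 2 devices:
w0 = [[1, 1], [2, 2], [3, 3], [4, 4]]
w1 = [[5, 5], [6, 6], [7, 7], [8, 8]]
outs = shard_and_stack([w0, w1], 2)
# outs[0] stacks the first half of w0 and of w1:
# [[[1, 1], [2, 2]], [[5, 5], [6, 6]]]
```

In the real kernel the per-device copies would go through the supplied device contexts rather than host lists; the sketch only illustrates how the shard boundaries and the extra stacked dimension line up.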
