Mojo function
shard_and_stack
shard_and_stack[axis: Int](outputs: VariadicTensors[outputs.dtype, outputs.rank, outputs.static_specs.size, Output, static_specs=outputs.static_specs], inputs: VariadicTensors[outputs.dtype, (outputs.rank - 1), inputs.static_specs.size, Input, static_specs=inputs.static_specs], dev_ctxs_input: DeviceContextPtrList[dev_ctxs_input.size])
Shard weight tensors across multiple devices for tensor parallelism.
This operation takes multiple input tensors with identical shapes and shards them along a specified axis, distributing the shards to different devices (typically GPUs for tensor parallel inference).
Parameters:
- axis (Int): The dimension along which to shard the weights.
Args:
- outputs (VariadicTensors): Output tensors, one per device/shard.
- inputs (VariadicTensors): Input tensors to be sharded, all with identical shapes.
- dev_ctxs_input (DeviceContextPtrList): Device contexts for multi-device transfers.
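The semantics can be illustrated outside Mojo. The NumPy sketch below is an assumption about the operation's behavior inferred from the signature (each output has one more dimension than the inputs): every input is split into one shard per device along `axis`, and each device receives the matching shards of all inputs stacked into a single tensor. The function name `shard_and_stack_sketch` is hypothetical, not part of the API.

```python
# Illustrative NumPy sketch of the assumed shard-and-stack semantics;
# this is not the Mojo implementation.
import numpy as np

def shard_and_stack_sketch(inputs, num_devices, axis):
    """Split each input into `num_devices` shards along `axis`, then
    stack the shard for device d across all inputs into output d."""
    outputs = []
    for device in range(num_devices):
        shards = [np.array_split(t, num_devices, axis=axis)[device]
                  for t in inputs]
        # Stacking adds a leading dimension, matching the rank
        # relationship in the signature (output rank = input rank + 1).
        outputs.append(np.stack(shards))
    return outputs

# Example: two 4x8 weight tensors sharded across 2 devices along axis 1.
weights = [np.ones((4, 8)), np.zeros((4, 8))]
per_device = shard_and_stack_sketch(weights, num_devices=2, axis=1)
# Each device holds a (2, 4, 4) tensor: both inputs' 4x4 shards, stacked.
```

In a tensor-parallel setting, each per-device tensor would then be transferred to its GPU using the corresponding device context.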