Python module

# max.experimental.sharding
Placement types and sharding specifications for distributed tensors.

This package is the single source of truth for describing how tensor data is
distributed across a `DeviceMesh`. It contains:

Placement types (mesh-axis-indexed primitives):

- `Replicated`: every device holds a full copy.
- `Sharded`: the tensor is split along a dimension.
- `Partial`: each device holds a partial result needing reduction.

Sharding specifications (high-level wrappers):

- `PlacementMapping`: mesh-axis-indexed (PyTorch DTensor style). One `Placement` per mesh axis. Suitable for eager dispatch.
- `NamedMapping`: tensor-dimension-indexed (JAX PartitionSpec style). One entry per tensor dimension names the mesh axis that shards it. Suitable for compiler-driven sharding propagation.

Both spec types share the same `DeviceMesh` and can be converted to each other
for the standard placement vocabulary. Conversions that would lose information
raise `ConversionError`.
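The two indexing styles describe the same layout from different directions, which is why lossless conversion between them is possible. A minimal, library-independent sketch of the idea (the class names and the `placements_to_named` helper here are illustrative stand-ins, not this module's actual API):

```python
from dataclasses import dataclass

# Illustrative stand-ins for the module's placement types.
@dataclass(frozen=True)
class Replicated:
    pass

@dataclass(frozen=True)
class Sharded:
    dim: int  # tensor dimension split along this mesh axis

def placements_to_named(placements, ndim):
    """Convert mesh-axis-indexed placements (DTensor style) to a
    tensor-dimension-indexed spec (PartitionSpec style).

    Returns one entry per tensor dimension: the mesh-axis index that
    shards it, or None if the dimension is unsharded.
    """
    named = [None] * ndim
    for mesh_axis, p in enumerate(placements):
        if isinstance(p, Sharded):
            if named[p.dim] is not None:
                # Two mesh axes shard the same tensor dim: the
                # dim-indexed form cannot express this, so the
                # conversion would lose information.
                raise ValueError(f"dim {p.dim} sharded by multiple mesh axes")
            named[p.dim] = mesh_axis
    return named

# A rank-3 tensor on a 2-D mesh: mesh axis 0 shards tensor dim 1,
# mesh axis 1 replicates.
print(placements_to_named([Sharded(dim=1), Replicated()], ndim=3))
# → [None, 0, None]
```

The error branch mirrors the role `ConversionError` plays in this package: it fires exactly when the target representation cannot express the source layout.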
## Device mesh

| `DeviceMesh` | An N-dimensional logical grid of devices. |
|---|---|
## Placements

| `Partial` | Every device holds a partial result that must be reduced. |
|---|---|
| `Placement` | Abstract base for all placement types. |
| `ReduceOp` | Reduction operations for partial placements. |
| `Replicated` | Every device on this mesh axis holds the same copy of the data. |
| `Sharded` | Every device on this mesh axis holds a slice along one tensor dimension. |
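To make `Sharded` concrete: each device on the mesh axis owns a contiguous slice of one tensor dimension. A small sketch of the slice bookkeeping (a hypothetical helper, not part of this module; it assumes the dimension divides evenly):

```python
def shard_slices(dim_size, num_devices):
    """Half-open [start, stop) index ranges along a sharded dimension,
    one per device on the mesh axis. Assumes even divisibility; a real
    implementation must also define behavior for ragged shards."""
    assert dim_size % num_devices == 0, "uneven sharding not handled here"
    step = dim_size // num_devices
    return [(d * step, (d + 1) * step) for d in range(num_devices)]

# A dimension of size 6 sharded across 3 devices on one mesh axis:
print(shard_slices(6, 3))  # → [(0, 2), (2, 4), (4, 6)]
```

A `Replicated` placement, by contrast, gives every device the full `[0, dim_size)` range, and `Partial` gives every device a full-shape buffer whose values must still be combined with a `ReduceOp`.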
## Sharding specifications

| `DeviceMapping` | Abstract base for all sharding specifications. |
|---|---|
| `NamedMapping` | Tensor-dimension-indexed sharding (JAX PartitionSpec style). |
| `PlacementMapping` | Mesh-axis-indexed sharding (PyTorch DTensor style). |
| `SpecEntry` | Represents a PEP 604 union type. |
## Distributed types

| `DistributedBufferType` | A symbolic type for a mutable buffer distributed across a device mesh. |
|---|---|
| `DistributedTensorType` | A symbolic type for a tensor distributed across a device mesh. |
| `DistributedType` | Shared state and shard-shape logic for distributed type descriptors. |
| `TensorLayout` | Metadata snapshot of a distributed tensor for rule evaluation. |
## Exceptions

| `ConversionError` | Raised when a sharding spec conversion would lose information. |
|---|---|
## Functions

| `global_shape_from_local` | Derives the global shape from one local shard's shape and placements. |
|---|---|
| `local_shard_shape_from_global` | Maps a global tensor shape to each device's local shard shape. |
| `shard_shape` | Computes the per-shard shape from a global shape and placements. |
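The global-to-local and local-to-global shape functions are inverses of each other. The core arithmetic can be sketched as follows (hypothetical helpers written for illustration, not this module's implementation; they assume even divisibility and take a `{tensor_dim: mesh_axis_size}` map in place of real placements):

```python
def local_shard_shape(global_shape, sharded_dim_to_axis_size):
    """Per-shard shape from a global shape: each sharded dimension is
    divided by the device count of the mesh axis that shards it.
    Replicated and partial dimensions keep their global extent."""
    local = list(global_shape)
    for dim, n in sharded_dim_to_axis_size.items():
        assert local[dim] % n == 0, "uneven sharding not handled here"
        local[dim] //= n
    return tuple(local)

def global_shape_from_local(local_shape, sharded_dim_to_axis_size):
    """Inverse: recover the global shape from one shard's shape by
    multiplying each sharded dimension back up."""
    g = list(local_shape)
    for dim, n in sharded_dim_to_axis_size.items():
        g[dim] *= n
    return tuple(g)

# A (8, 1024, 512) tensor sharded across 4 devices along dim 1:
loc = local_shard_shape((8, 1024, 512), {1: 4})
print(loc)                                   # → (8, 256, 512)
print(global_shape_from_local(loc, {1: 4}))  # → (8, 1024, 512)
```

Round-tripping through both functions returns the original shape, which is the invariant the module's `shard_shape` and `global_shape_from_local` pair relies on.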