Python module

max.experimental.sharding

Placement types and sharding specifications for distributed tensors.

This package is the single source of truth for describing how tensor data is distributed across a DeviceMesh. It contains:

Placement types (mesh-axis-indexed primitives):

  • Replicated: every device holds a full copy.
  • Sharded: tensor is split along a dimension.
  • Partial: each device holds a partial result needing reduction.
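
As a rough illustration of these semantics, the following sketch uses plain NumPy stand-ins on a one-dimensional mesh of two devices. It shows the data layout each placement describes; it is not the max.experimental.sharding API.

```python
import numpy as np

# Plain NumPy stand-ins for the three placement semantics on a
# one-dimensional mesh of two devices (illustration only; not the
# max.experimental.sharding API).
tensor = np.arange(8).reshape(2, 4)
devices = [0, 1]

# Replicated: every device holds a full copy.
replicated = [tensor.copy() for _ in devices]

# Sharded: each device holds one slice along a tensor dimension.
shards = np.split(tensor, len(devices), axis=0)

# Partial: each device holds a partial result; the global value is
# only recovered after a reduction (here, a sum).
partials = [tensor * 0.5, tensor * 0.5]
assert (sum(partials) == tensor).all()
```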

Sharding specifications (high-level wrappers):

  • PlacementMapping, mesh-axis-indexed (PyTorch DTensor style). One Placement per mesh axis. Suitable for eager dispatch.
  • NamedMapping, tensor-dimension-indexed (JAX PartitionSpec style). One entry per tensor dimension names the mesh axis that shards it. Suitable for compiler-driven sharding propagation.

Both spec types carry the same DeviceMesh and can be converted to one another as long as they stay within the standard placement vocabulary. Conversions that would lose information raise ConversionError.
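
The difference between the two indexing conventions, and why some conversions are lossy, can be sketched in plain Python. The placement and error names below mirror this page, but the data structures and the conversion helper are illustrative assumptions, not the library's API:

```python
# Hedged sketch of the two indexing conventions. ("sharded", d)
# shards tensor dimension d; "replicated" copies; "partial" marks an
# unreduced result. These encodings are assumptions for illustration.

MESH_AXES = ("dp", "tp")  # hypothetical 2-D mesh axis names

# Mesh-axis-indexed (PlacementMapping style): one placement per mesh axis.
placement_style = {"dp": ("sharded", 0), "tp": "replicated"}

def to_named_style(placements, tensor_rank):
    """Convert mesh-axis-indexed placements to tensor-dimension-indexed
    entries (NamedMapping style): one mesh-axis name, or None, per
    tensor dimension."""
    entries = [None] * tensor_rank
    for axis_name, placement in placements.items():
        if placement == "replicated":
            continue  # replication needs no per-dimension entry
        if placement == "partial":
            # A partial result has no tensor-dimension encoding, so the
            # conversion would lose information (ConversionError in the
            # library).
            raise ValueError("lossy conversion")
        _, dim = placement
        entries[dim] = axis_name
    return tuple(entries)

# Tensor dim 0 is sharded over mesh axis "dp"; dim 1 is unsharded.
assert to_named_style(placement_style, tensor_rank=2) == ("dp", None)
```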

Device mesh

  • DeviceMesh: An N-dimensional logical grid of devices.
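
Logically, a device mesh is just a grid of device IDs. A minimal picture, assuming a hypothetical 2 x 4 mesh with made-up axis names (the real DeviceMesh constructor may look different):

```python
import numpy as np

# Illustrative assumption: eight devices arranged as a 2 x 4 logical
# grid with hypothetical axis names ("dp", "tp").
mesh_axis_names = ("dp", "tp")
device_grid = np.arange(8).reshape(2, 4)
print(device_grid[0])  # the four devices along "tp" in the first "dp" row
```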

Placements

  • Partial: Every device holds a partial result that must be reduced.
  • Placement: Abstract base for all placement types.
  • ReduceOp: Reduction operations for partial placements.
  • Replicated: Every device on this mesh axis holds the same copy of the data.
  • Sharded: Every device on this mesh axis holds a slice of the tensor along a given axis.

Sharding specifications

  • DeviceMapping: Abstract base for all sharding specifications.
  • NamedMapping: Tensor-dimension-indexed sharding (JAX PartitionSpec style).
  • PlacementMapping: Mesh-axis-indexed sharding (PyTorch DTensor style).
  • SpecEntry: Represents a PEP 604 union type.

Distributed types

  • DistributedBufferType: A symbolic type for a mutable buffer distributed across a device mesh.
  • DistributedTensorType: A symbolic type for a tensor distributed across a device mesh.
  • DistributedType: Shared state and shard-shape logic for distributed type descriptors.
  • TensorLayout: Metadata snapshot of a distributed tensor for rule evaluation.

Exceptions

  • ConversionError: Raised when a sharding spec conversion would lose information.

Functions

  • global_shape_from_local: Derives the global shape from one local shard's shape and placements.
  • local_shard_shape_from_global: Maps a global tensor shape to each device's local shard shape.
  • shard_shape: Computes the per-shard shape from a global shape and placements.
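
The shape arithmetic behind these helpers can be sketched as follows. The direction of each conversion echoes the functions above, but the helper names, signatures, and placement encoding are illustrative assumptions:

```python
# Hedged sketch of shard-shape arithmetic: each sharded tensor
# dimension is divided by the size of the mesh axis that shards it
# (assuming it divides evenly, as in the simple case).

def local_from_global(global_shape, mesh_shape, placements):
    """Per-device shard shape from a global shape, as
    local_shard_shape_from_global / shard_shape describe."""
    local = list(global_shape)
    for axis_size, placement in zip(mesh_shape, placements):
        if placement == "replicated":
            continue
        _, dim = placement
        assert local[dim] % axis_size == 0  # simple, evenly divisible case
        local[dim] //= axis_size
    return tuple(local)

def global_from_local(local_shape, mesh_shape, placements):
    """Inverse direction: multiply each sharded dimension back up, as
    global_shape_from_local does for one shard's shape."""
    global_ = list(local_shape)
    for axis_size, placement in zip(mesh_shape, placements):
        if placement != "replicated":
            _, dim = placement
            global_[dim] *= axis_size
    return tuple(global_)

# An (8, 1024) tensor on a 2 x 4 mesh, sharded over dim 0 by the first
# mesh axis and dim 1 by the second, yields (4, 256) shards per device.
mesh = (2, 4)
placements = (("sharded", 0), ("sharded", 1))
assert local_from_global((8, 1024), mesh, placements) == (4, 256)
assert global_from_local((4, 256), mesh, placements) == (8, 1024)
```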