Python class
NamedMapping
class max.experimental.sharding.NamedMapping(_mesh, _spec=(), _unreduced=frozenset({}), _priorities=(), _memory_kind=None, _original_spec=None)
Bases: DeviceMapping
Tensor-dimension-indexed sharding (JAX PartitionSpec style).
Each entry in _spec corresponds to a tensor dimension and names the
mesh axis (or axes) that shard it:
- "dp": shard this tensor dim across mesh axis "dp".
- ("dp", "tp"): shard this tensor dim across both axes (multi-axis).
- None: this tensor dim is replicated.
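To make the spec semantics concrete, here is a small pure-Python sketch (not the MAX implementation) that computes the per-device shape a spec implies. The helper local_shard_shape and the dict of mesh axis sizes are hypothetical; the sketch assumes every sharded dimension divides evenly by the product of the sizes of the mesh axes that shard it.

```python
from math import prod
from typing import Optional, Union

# Hypothetical alias for one spec entry: an axis name, a tuple of axis names, or None.
SpecEntry = Optional[Union[str, tuple[str, ...]]]


def local_shard_shape(
    global_shape: tuple[int, ...],
    spec: tuple[SpecEntry, ...],
    mesh_axis_sizes: dict[str, int],
) -> tuple[int, ...]:
    """Per-device shape implied by a PartitionSpec-style spec (toy model)."""
    if len(spec) != len(global_shape):
        raise ValueError("spec must have one entry per tensor dimension")
    local = []
    for dim_size, entry in zip(global_shape, spec):
        if entry is None:
            axes: tuple[str, ...] = ()  # replicated: not split
        elif isinstance(entry, str):
            axes = (entry,)  # split across one mesh axis
        else:
            axes = entry  # split across several mesh axes at once
        # prod(()) == 1, so a replicated dim keeps its full size.
        num_shards = prod(mesh_axis_sizes[a] for a in axes)
        if dim_size % num_shards:
            raise ValueError(f"size {dim_size} does not split evenly into {num_shards} shards")
        local.append(dim_size // num_shards)
    return tuple(local)


mesh_sizes = {"dp": 2, "tp": 4}
print(local_shard_shape((8, 16), ("dp", None), mesh_sizes))          # (4, 16)
print(local_shard_shape((8, 16), (("dp", "tp"), None), mesh_sizes))  # (1, 16)
print(local_shard_shape((8, 16), (None, "tp"), mesh_sizes))          # (8, 4)
```

A multi-axis entry divides its dimension by the product of the axis sizes, which is why ("dp", "tp") on a size-8 dimension leaves one row per device on a 2 x 4 mesh.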
Additionally:
- _unreduced names mesh axes with pending reductions (analogous to Partial in the placement world). Contracting a sharded dimension produces an unreduced result that needs a collective reduction.
- _priorities assigns a per-dimension propagation priority for the compiler (for example, batch parallelism at priority 0, tensor parallelism at priority 1). It is compiler-only and cannot be used in eager mode.
- _memory_kind specifies the memory tier for this tensor's shards (for example, "device" or "pinned_host"). It mirrors JAX's NamedSharding.memory_kind.
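Why contracting a sharded dimension leaves the result unreduced can be seen with a toy matrix product. In this pure-Python sketch (hypothetical helpers, no MAX calls), the contracting dimension K is split across an assumed mesh axis "tp" of size 2, and K is assumed to divide evenly by that size. Each shard's local product is only a partial sum; the full result appears only after the partials are summed, which is the job of the collective reduction (for example, an all-reduce).

```python
# Toy model of y = x @ w with the contracting dim K split across a hypothetical
# mesh axis "tp" of size 2. Plain Python lists stand in for tensors.


def matmul(a, b):
    """Row-major nested-list matrix product."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def sharded_contraction(x, w, tp_size):
    """Return one partial product per shard of K (the unreduced result)."""
    k_per_shard = len(w) // tp_size
    partials = []
    for rank in range(tp_size):
        ks = slice(rank * k_per_shard, (rank + 1) * k_per_shard)
        x_shard = [row[ks] for row in x]  # columns of x in this K-slice
        w_shard = w[ks]                   # rows of w in this K-slice
        partials.append(matmul(x_shard, w_shard))
    return partials


x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]      # shape (M=2, K=4)
w = [[1, 0],
     [0, 1],
     [1, 1],
     [2, 0]]            # shape (K=4, N=2)

partials = sharded_contraction(x, w, tp_size=2)
print(partials)         # [[[1, 2], [5, 6]], [[11, 3], [23, 7]]]: neither equals x @ w

# Summing across shards plays the role of the collective reduction over "tp".
reduced = partials[0]
for p in partials[1:]:
    reduced = add(reduced, p)
print(reduced)          # [[12, 5], [28, 13]]
assert reduced == matmul(x, w)
```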
Parameters:
- _mesh (DeviceMesh): The device mesh.
- _spec (tuple[str | tuple[str, ...] | None, ...]): One entry per tensor dimension.
- _unreduced (frozenset[str]): The mesh axes with pending reductions.
- _priorities (tuple[int | None, ...]): The per-dimension propagation priority (compiler-only).
- _memory_kind (str | None): The memory tier for shard placement (for example, "device").
- _original_spec (tuple[str | tuple[str, ...] | None, ...] | None): The caller-supplied spec before mesh resolution.
from_spec()
classmethod from_spec(spec=(), mesh=None, *, _unreduced=frozenset({}), _priorities=(), _memory_kind=None)
Creates a NamedMapping, resolving the mesh from context if needed.
This is the preferred constructor when the mesh may come from an ambient context rather than being passed explicitly.
Parameters:
- spec (tuple[str | tuple[str, ...] | None, ...]): One entry per tensor dimension.
- mesh (DeviceMesh | None): The device mesh. If None, falls back to DeviceMesh.default().
- _unreduced (frozenset[str]): The mesh axes with pending reductions.
- _priorities (tuple[int | None, ...]): The per-dimension propagation priority (compiler-only).
- _memory_kind (str | None): The memory tier for shard placement.
Return type: NamedMapping
is_fully_replicated
property is_fully_replicated: bool
Returns True if no dimension is sharded and there are no pending reductions.
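The is_fully_replicated condition can be written as a one-line predicate over a spec and its unreduced set. This is a hypothetical helper for illustration, not the property's actual source; it treats any non-None entry as sharding its dimension.

```python
from typing import Optional, Union

# Hypothetical alias for one spec entry: an axis name, a tuple of axis names, or None.
SpecEntry = Optional[Union[str, tuple[str, ...]]]


def is_fully_replicated(spec: tuple[SpecEntry, ...], unreduced: frozenset[str]) -> bool:
    """True when no dim names a mesh axis and no axis awaits a reduction (toy helper)."""
    return all(entry is None for entry in spec) and not unreduced


print(is_fully_replicated((None, None), frozenset()))        # True
print(is_fully_replicated(("dp", None), frozenset()))        # False: dim 0 is sharded
print(is_fully_replicated((None, None), frozenset({"tp"})))  # False: pending reduction on "tp"
```

The second condition matters: a tensor whose every dimension is None can still be unreduced, in which case each device holds only a partial value rather than a full replica.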
is_fully_resolved
property is_fully_resolved: bool
Returns True if every dimension has a concrete sharding decision.
memory_kind
The memory tier for shard placement (for example, "device").
mesh
property mesh: DeviceMesh
The device mesh this sharding is defined over.
original_spec
property original_spec: tuple[str | tuple[str, ...] | None, ...]
The caller-supplied spec before mesh resolution.
priorities
Per-dimension propagation priorities (compiler-only).
spec
The raw spec tuple (one entry per tensor dim).
to_named_sharding()
to_named_sharding(tensor_rank)
Returns self since this is already a NamedMapping.
Parameters:
- tensor_rank (int)
Return type: NamedMapping
to_placements()
to_placements()
Converts to mesh-axis-indexed placements for eager dispatch.
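The change of indexing that to_placements performs (from one entry per tensor dimension to one placement per mesh axis) can be sketched as follows. Shard, Replicate, and Partial here are hypothetical dataclasses modeled on DTensor-style placements, not MAX types, and mapping an unreduced axis to Partial follows the analogy drawn in the class description.

```python
from dataclasses import dataclass
from typing import Optional, Union

# Hypothetical alias for one spec entry: an axis name, a tuple of axis names, or None.
SpecEntry = Optional[Union[str, tuple[str, ...]]]


# Hypothetical placement types, modeled on DTensor-style Shard/Replicate/Partial.
@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Partial:
    pass


Placement = Union[Shard, Replicate, Partial]


def to_placements(
    spec: tuple[SpecEntry, ...],
    unreduced: frozenset[str],
    mesh_axis_names: tuple[str, ...],
) -> tuple[Placement, ...]:
    """Flip a tensor-dim-indexed spec into one placement per mesh axis (toy model)."""
    # Record which tensor dim each mesh axis shards.
    axis_to_dim: dict[str, int] = {}
    for dim, entry in enumerate(spec):
        if entry is None:
            axes: tuple[str, ...] = ()
        elif isinstance(entry, str):
            axes = (entry,)
        else:
            axes = entry
        for axis in axes:
            if axis in axis_to_dim:
                raise ValueError(f"mesh axis {axis!r} shards more than one tensor dim")
            axis_to_dim[axis] = dim

    placements: list[Placement] = []
    for axis in mesh_axis_names:
        if axis in axis_to_dim:
            placements.append(Shard(axis_to_dim[axis]))
        elif axis in unreduced:
            placements.append(Partial())  # pending reduction on this axis
        else:
            placements.append(Replicate())
    return tuple(placements)


print(to_placements(("dp", None), frozenset({"tp"}), ("dp", "tp")))
# (Shard(dim=0), Partial())
print(to_placements((None, "tp"), frozenset(), ("dp", "tp")))
# (Replicate(), Shard(dim=1))
```

In this toy version a multi-axis entry such as ("dp", "tp") maps both axes to Shard(dim=0); the order of the axes within the entry is not recorded, which a real conversion would have to account for.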
unreduced
Mesh axes with pending reductions.