Python class

NamedMapping

class max.experimental.sharding.NamedMapping(_mesh, _spec=(), _unreduced=frozenset({}), _priorities=(), _memory_kind=None, _original_spec=None)

Bases: DeviceMapping

Tensor-dimension-indexed sharding (JAX PartitionSpec style).

Each entry in _spec corresponds to a tensor dimension and names the mesh axis (or axes) that shard it:

  • "dp": shard this tensor dim across mesh axis "dp".
  • ("dp", "tp"): shard across both axes (multi-axis).
  • None: this tensor dim is replicated.

Additionally:

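  • _unreduced names mesh axes with pending reductions (analogous to a Partial placement in the mesh-axis-indexed placement representation). For example, contracting over a sharded dimension, as in a matrix multiply whose inner dimension is sharded, leaves each device holding a partial sum; the result stays unreduced until a collective reduction (such as an all-reduce) combines the partial sums.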
  • _priorities assigns per-dimension propagation priorities for the compiler (for example, batch parallelism at priority 0 and tensor parallelism at priority 1). Priorities are compiler-only and cannot be used in eager mode.
  • _memory_kind specifies the memory tier for this tensor’s shards (for example, "device" or "pinned_host"). Mirrors JAX’s NamedSharding.memory_kind.
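The unreduced case is easiest to see in a matrix multiply whose contracting dimension is sharded. The plain-Python simulation below uses no MAX APIs; the two-device "tp" axis and the helper names are hypothetical. Each simulated device receives its slice of the contracting dimension and ends up with a full-shaped output that holds only part of the sum. That is the situation `_unreduced` records for the "tp" axis, and summing the partials recovers the true product.

```python
def matmul(a, b):
    """Naive matrix multiply on nested lists (keeps the sketch dependency-free)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(x, y):
    """Elementwise sum of two equally shaped nested lists."""
    return [[xi + yi for xi, yi in zip(rx, ry)] for rx, ry in zip(x, y)]


A = [[1, 2, 3, 4],
     [5, 6, 7, 8]]   # shape (2, 4); contracting dimension K = 4
B = [[1, 0],
     [0, 1],
     [1, 1],
     [2, -1]]        # shape (4, 2)

# Simulate sharding K across a hypothetical 2-device "tp" axis, i.e. A with
# spec (None, "tp") and B with spec ("tp", None).
tp_size = 2
k_chunk = len(B) // tp_size
partials = []
for device in range(tp_size):
    ks = slice(device * k_chunk, (device + 1) * k_chunk)
    a_local = [row[ks] for row in A]   # this device's columns of A
    b_local = B[ks]                    # this device's rows of B
    # Full output shape, but only part of the sum over K: unreduced over "tp".
    partials.append(matmul(a_local, b_local))

# The collective reduction (an all-reduce in a real system) sums the partials.
reduced = partials[0]
for p in partials[1:]:
    reduced = add(reduced, p)

print(reduced)  # [[12, 1], [28, 5]]
```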

Parameters:

  • _mesh (DeviceMesh) – The device mesh.
  • _spec (tuple[str | tuple[str, ...] | None, ...]) – One entry per tensor dimension.
  • _unreduced (frozenset[str]) – The mesh axes with pending reductions.
  • _priorities (tuple[int | None, ...]) – The per-dimension propagation priority (compiler-only).
  • _memory_kind (str | None) – The memory tier for shard placement (for example, "device").
  • _original_spec (tuple[str | tuple[str, ...] | None, ...] | None) – The caller-supplied spec before mesh resolution.

from_spec()

classmethod from_spec(spec=(), mesh=None, *, _unreduced=frozenset({}), _priorities=(), _memory_kind=None)

Creates a NamedMapping, resolving the mesh from the ambient context if none is passed.

This is the preferred constructor when the mesh may come from an ambient context rather than being passed explicitly.

Parameters:

  • spec (tuple[str | tuple[str, ...] | None, ...]) – One entry per tensor dimension.
  • mesh (DeviceMesh | None) – The device mesh. If None, falls back to DeviceMesh.default().
  • _unreduced (frozenset[str]) – The mesh axes with pending reductions.
  • _priorities (tuple[int | None, ...]) – The per-dimension propagation priority (compiler-only).
  • _memory_kind (str | None) – The memory tier for shard placement.

Return type:

NamedMapping
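The "explicit mesh wins, otherwise use the ambient default" pattern behind from_spec can be sketched in plain Python. In MAX the fallback is DeviceMesh.default(); how that default is established is not described here, so this sketch models it with a `contextvars` context variable as an assumption. `Mesh`, `default_mesh`, and `resolve_mesh` are hypothetical names, not MAX APIs.

```python
from __future__ import annotations

import contextlib
import contextvars
from dataclasses import dataclass


@dataclass(frozen=True)
class Mesh:
    """Hypothetical stand-in for DeviceMesh: axis names and their sizes."""
    axis_names: tuple[str, ...]
    shape: tuple[int, ...]


# Hypothetical ambient default, modeled with a context variable.
_default_mesh: contextvars.ContextVar[Mesh | None] = contextvars.ContextVar(
    "default_mesh", default=None
)


@contextlib.contextmanager
def default_mesh(mesh: Mesh):
    """Make `mesh` the ambient default for the duration of a `with` block."""
    token = _default_mesh.set(mesh)
    try:
        yield mesh
    finally:
        _default_mesh.reset(token)


def resolve_mesh(mesh: Mesh | None = None) -> Mesh:
    """An explicitly passed mesh wins; otherwise use the ambient default."""
    if mesh is not None:
        return mesh
    ambient = _default_mesh.get()
    if ambient is None:
        raise RuntimeError("no mesh was passed and no default mesh is set")
    return ambient


dp_mesh = Mesh(("dp",), (2,))
tp_mesh = Mesh(("tp",), (4,))
with default_mesh(dp_mesh):
    assert resolve_mesh() is dp_mesh          # falls back to the context
    assert resolve_mesh(tp_mesh) is tp_mesh   # explicit argument wins
```

A context variable keeps the default scoped to the `with` block and safe across threads and async tasks, which is one reason it is a common choice for this kind of ambient state.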

is_fully_replicated

property is_fully_replicated: bool

Returns True if no dimension is sharded and there are no pending reductions.

is_fully_resolved

property is_fully_resolved: bool

Returns True if every dimension has a concrete sharding decision.

memory_kind

property memory_kind: str | None

The memory tier for shard placement (for example, "device").

mesh

property mesh: DeviceMesh

The device mesh this sharding is defined over.

original_spec

property original_spec: tuple[str | tuple[str, ...] | None, ...]

The caller-supplied spec before mesh resolution.

priorities

property priorities: tuple[int | None, ...]

Per-dimension propagation priorities (compiler-only).

spec

property spec: tuple[str | tuple[str, ...] | None, ...]

The raw spec tuple (one entry per tensor dim).

to_named_sharding()

to_named_sharding(tensor_rank)

Returns self since this is already a NamedMapping.

Parameters:

  • tensor_rank (int) – The rank (number of dimensions) of the tensor being mapped.

Return type:

NamedMapping

to_placements()

to_placements()

Converts this mapping to mesh-axis-indexed placements (one per mesh axis) for eager dispatch.

Return type:

tuple[Placement, ...]

unreduced

property unreduced: frozenset[str]

Mesh axes with pending reductions.