PlacementMapping

class max.experimental.sharding.PlacementMapping(_mesh, _placements)

Bases: DeviceMapping

Mesh-axis-indexed sharding (PyTorch DTensor style).

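Stores one Placement per mesh axis. Each placement describes how the tensor is distributed along that mesh axis: Shard(dim) splits tensor dimension dim across the devices on that axis, Replicate() gives every device on that axis a full copy, and Partial(op) means each device holds a partial value that still has to be reduced with op.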

A PlacementMapping is always fully resolved, so it can be used directly in eager dispatch.
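To make the mesh-axis-indexed layout concrete, here is a minimal sketch that computes the per-device shape implied by a placement tuple. Shard, Replicate, and Partial below are hypothetical stand-in dataclasses written for illustration, not the classes exported by max.experimental.sharding, and the helper local_shape is likewise hypothetical. It assumes every sharded dimension divides evenly by its mesh axis size.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical stand-ins for the three placement kinds; NOT the library classes.
@dataclass(frozen=True)
class Shard:
    dim: int  # tensor dimension split across this mesh axis


@dataclass(frozen=True)
class Replicate:
    pass  # every device on this mesh axis holds the full data


@dataclass(frozen=True)
class Partial:
    op: str = "sum"  # reduction still pending across this mesh axis


Placement = Union[Shard, Replicate, Partial]


def local_shape(
    global_shape: tuple[int, ...],
    mesh_shape: tuple[int, ...],
    placements: tuple[Placement, ...],
) -> tuple[int, ...]:
    """Per-device shape implied by one placement per mesh axis (even splits assumed)."""
    if len(placements) != len(mesh_shape):
        raise ValueError("need exactly one placement per mesh axis")
    shape = list(global_shape)
    for axis_size, p in zip(mesh_shape, placements):
        if isinstance(p, Shard):
            if shape[p.dim] % axis_size != 0:
                raise ValueError(
                    f"dim {p.dim} of size {shape[p.dim]} is not divisible by {axis_size}"
                )
            shape[p.dim] //= axis_size
        # Replicate and Partial change values or ownership, not the local shape.
    return tuple(shape)
```

For example, on a 2 × 4 mesh, local_shape((8, 16), (2, 4), (Shard(0), Replicate())) gives (4, 16), while (Shard(0), Shard(1)) gives (4, 4). Partial(op) leaves the shape untouched because it describes pending values, not layout.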

Parameters:

_mesh (DeviceMesh)

_placements (tuple[Placement, ...])

is_fully_replicated

property is_fully_replicated: bool

Returns True if every mesh axis placement is Replicated.
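The check described here reduces to a single all() over the placement tuple. The following is a minimal sketch, not the library's implementation; Replicated, Sharded, and Partial are hypothetical stand-in classes named after the placement types this page mentions, and is_fully_replicated is written as a free function for illustration.

```python
from dataclasses import dataclass


# Hypothetical stand-in placement types; the real ones live in max.experimental.sharding.
@dataclass(frozen=True)
class Replicated:
    pass


@dataclass(frozen=True)
class Sharded:
    dim: int


@dataclass(frozen=True)
class Partial:
    op: str = "sum"


def is_fully_replicated(placements: tuple) -> bool:
    # True only when every mesh axis replicates the tensor; any Sharded or
    # Partial placement means some device does not hold the full, final value.
    return all(isinstance(p, Replicated) for p in placements)
```

Note that Partial counts as not replicated: every device has a full-shaped tensor, but the values are incomplete until the reduction runs.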

is_fully_resolved

property is_fully_resolved: bool

Returns True; placement mappings are always fully concrete.

mesh

property mesh: DeviceMesh

The device mesh this sharding is defined over.

placements

property placements: tuple[Placement, ...]

The raw placement tuple (one per mesh axis).

to_named_sharding()

to_named_sharding(tensor_rank)

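Converts this mapping to a NamedMapping, which is indexed by tensor dimension rather than by mesh axis. tensor_rank is the number of dimensions of the tensor the sharding applies to.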

Raises:

ConversionError – If any placement is not one of the standard types (Replicated, Sharded, Partial).

Parameters:

tensor_rank (int)

Return type:

NamedMapping
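Converting from mesh-axis indexing to tensor-dim indexing amounts to inverting the placement tuple: for each tensor dimension, collect the mesh axes that shard it. The sketch below is illustrative only. It assumes a simple hypothetical output format (a tuple of mesh-axis indices per tensor dimension, plus a list of axes holding partial values) because the real NamedMapping structure is not described on this page. ConversionError, the placement classes, and to_dim_indexed are all hypothetical stand-ins.

```python
from dataclasses import dataclass


class ConversionError(Exception):
    """Hypothetical stand-in for the library's ConversionError."""


# Hypothetical stand-in placement types; NOT the library classes.
@dataclass(frozen=True)
class Replicated:
    pass


@dataclass(frozen=True)
class Sharded:
    dim: int


@dataclass(frozen=True)
class Partial:
    op: str = "sum"


def to_dim_indexed(placements: tuple, tensor_rank: int):
    """Invert mesh-axis-indexed placements into a per-tensor-dim view.

    Returns (dims, partial_axes): dims[d] holds the mesh axes that shard
    tensor dim d, in mesh-axis order; partial_axes lists axes with a
    pending reduction.
    """
    dims = [[] for _ in range(tensor_rank)]
    partial_axes = []
    for axis, p in enumerate(placements):
        if isinstance(p, Sharded):
            if not 0 <= p.dim < tensor_rank:
                raise ValueError(f"axis {axis}: dim {p.dim} out of range for rank {tensor_rank}")
            dims[p.dim].append(axis)
        elif isinstance(p, Partial):
            partial_axes.append(axis)
        elif not isinstance(p, Replicated):
            # Anything outside the three standard types cannot be converted.
            raise ConversionError(f"axis {axis}: unsupported placement {p!r}")
    return [tuple(d) for d in dims], partial_axes
```

For instance, placements (Sharded(1), Sharded(1), Partial()) with tensor_rank=3 yield ([(), (0, 1), ()], [2]): tensor dimension 1 is split across mesh axes 0 and 1, and axis 2 carries a pending reduction. The tensor rank has to be passed in because the placement tuple alone does not say how many dimensions the tensor has.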

to_placements()

to_placements()

Returns the stored placement tuple directly.

Return type:

tuple[Placement, ...]