DeviceMapping

class max.experimental.sharding.DeviceMapping


Bases: ABC

Abstract base for all sharding specifications.

A DeviceMapping pairs a DeviceMesh with a description of how tensor data is distributed across that mesh. Two concrete implementations exist:

  • PlacementMapping: mesh-axis-indexed, for eager per-op dispatch.
  • NamedMapping: tensor-dim-indexed, for future full-graph sharding search (for example, a Python-level transform over an op trace).
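The documented interface can be sketched as a self-contained Python ABC. This is an illustrative stand-in, not the real `max.experimental.sharding` implementation; the `DeviceMesh`, `Placement`, and `NamedMapping` stubs below are assumptions for the sketch.

```python
from abc import ABC, abstractmethod

# Illustrative stubs for the types the interface refers to; the real
# definitions live in the max library and are only assumed here.
class DeviceMesh: ...
class Placement: ...
class NamedMapping: ...

class DeviceMapping(ABC):
    """Pairs a DeviceMesh with a description of how tensor data is
    distributed across that mesh."""

    @property
    @abstractmethod
    def is_fully_replicated(self) -> bool:
        """True if every device holds a complete copy of the tensor."""

    @property
    @abstractmethod
    def is_fully_resolved(self) -> bool:
        """True if the spec can be used in eager dispatch."""

    @property
    @abstractmethod
    def mesh(self) -> DeviceMesh:
        """The device mesh this sharding is defined over."""

    @abstractmethod
    def to_named_sharding(self, tensor_rank: int) -> NamedMapping:
        """Convert to a tensor-dim-indexed spec for compiler lowering."""

    @abstractmethod
    def to_placements(self) -> tuple[Placement, ...]:
        """Convert to one Placement per mesh axis for eager dispatch."""
```

A concrete subclass (such as the documented `PlacementMapping` or `NamedMapping` implementations) must supply all five members before it can be instantiated.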

is_fully_replicated

abstract property is_fully_replicated: bool


Whether every device holds a complete copy of the tensor.

Returns True if no dimension is sharded and there are no pending reductions.
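The check can be illustrated with toy placement types. The names `Replicate`, `Shard`, and `Partial` are hypothetical stand-ins for the real placement classes, not the max API.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for placement types; names are assumptions.
@dataclass(frozen=True)
class Replicate:
    """Every device on this mesh axis holds the full data."""

@dataclass(frozen=True)
class Shard:
    """The tensor is split along `dim` across this mesh axis."""
    dim: int

@dataclass(frozen=True)
class Partial:
    """A pending reduction: devices hold unreduced partial results."""
    op: str = "sum"

def is_fully_replicated(placements):
    # True only when every mesh axis replicates: no dimension is
    # sharded and there are no pending reductions.
    return all(isinstance(p, Replicate) for p in placements)

print(is_fully_replicated((Replicate(), Replicate())))  # True
print(is_fully_replicated((Shard(0), Replicate())))     # False: dim 0 sharded
print(is_fully_replicated((Partial(), Replicate())))    # False: pending reduction
```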

is_fully_resolved

abstract property is_fully_resolved: bool


Whether this spec can be used in eager dispatch.

Returns False if the spec contains compiler-only annotations (for example, priorities) that cannot be resolved without a compiler.

mesh

abstract property mesh: DeviceMesh


The device mesh this sharding is defined over.

to_named_sharding()

abstract to_named_sharding(tensor_rank)


Converts to a tensor-dim-indexed spec for compiler lowering.

Parameters:

tensor_rank (int) – The number of dimensions in the tensor. Required because the spec must have one entry per tensor dim.

Raises:

ConversionError – If the spec contains custom placements that have no NamedMapping equivalent.

Return type:

NamedMapping
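One way to picture the mesh-axis to tensor-dim reindexing, and why `tensor_rank` is required, is the sketch below. The representations are assumptions for illustration only: an integer placement means "shard tensor dim i along this mesh axis" and `None` means replicate; the real conversion works on the library's own placement and mapping types.

```python
def to_named_sharding(placements, mesh_axis_names, tensor_rank):
    """Reindex a mesh-axis-indexed spec into a tensor-dim-indexed one:
    one entry per tensor dim, naming the mesh axis that shards it
    (or None if that dim is replicated)."""
    # tensor_rank is required because the output must have exactly one
    # entry per tensor dim, which placements alone do not determine.
    entries = [None] * tensor_rank
    for axis_name, p in zip(mesh_axis_names, placements):
        if isinstance(p, int):  # this mesh axis shards tensor dim p
            if p >= tensor_rank:
                raise ValueError(
                    f"placement refers to dim {p} of a rank-{tensor_rank} tensor"
                )
            entries[p] = axis_name
    return tuple(entries)

# Shard tensor dim 0 across mesh axis "dp"; axis "tp" replicates.
print(to_named_sharding([0, None], ["dp", "tp"], tensor_rank=2))  # ('dp', None)
```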

to_placements()

abstract to_placements()


Converts to mesh-axis-indexed placements for eager dispatch.

Returns one Placement per mesh axis.

Raises:

ConversionError – If the spec contains features that cannot be represented as placements (for example, priorities or custom placement types without a standard equivalent).

Return type:

tuple[Placement, ...]
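The error contract can be sketched with a toy mapping: a spec carrying compiler-only features (such as priorities) refuses to convert rather than silently dropping them. `ToyMapping` and this `ConversionError` class are hypothetical stand-ins, not the real max types.

```python
class ConversionError(Exception):
    """Raised when a spec cannot be represented as per-axis placements."""

class ToyMapping:
    def __init__(self, placements, priorities=None):
        self._placements = tuple(placements)   # one entry per mesh axis
        self._priorities = priorities          # compiler-only annotation

    def to_placements(self):
        # Refuse to convert if compiler-only data would be lost.
        if self._priorities is not None:
            raise ConversionError(
                "priorities cannot be represented as placements"
            )
        return self._placements

print(ToyMapping(["replicate", "shard(0)"]).to_placements())
# ('replicate', 'shard(0)')
```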