TensorLayout

class max.experimental.sharding.TensorLayout(dtype, shape, mapping)

Bases: DeviceMapping

Metadata snapshot of a distributed tensor for rule evaluation.

Bundles the tensor’s dtype, shape, and distribution mapping. The mapping stays abstract (DeviceMapping) so rules work with any concrete mapping type, such as PlacementMapping or NamedMapping.

The shape is a Shape (list[Dim]), supporting both static and symbolic dimensions so layouts remain compatible with graph compilation.

This class implements DeviceMapping, so sharding rules can return input TensorLayouts directly.
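The structure of this metadata snapshot can be sketched with a plain dataclass. All names below (ToyLayout, the string-valued dtype, and the int-or-string dims) are hypothetical stand-ins for illustration only; the real DType, Dim, and DeviceMapping types live in MAX:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for TensorLayout: bundles dtype, shape, and mapping.
# Dims may be ints (static) or strings (symbolic), mirroring Shape's support
# for both static and symbolic dimensions.
@dataclass(frozen=True)
class ToyLayout:
    dtype: str               # stand-in for DType, e.g. "float32"
    shape: list = field(default_factory=list)  # stand-in for Shape (list[Dim])
    mapping: object = None   # stand-in for a concrete DeviceMapping

# A 2-D tensor with one symbolic ("batch") and one static dimension.
layout = ToyLayout(dtype="float32", shape=["batch", 1024])
print(layout.shape)  # ['batch', 1024]
```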

Parameters:

dtype

dtype: DType

The element data type of the tensor.

is_fully_replicated

property is_fully_replicated: bool

Whether every device holds a complete copy of the tensor.

Returns True if no dimension is sharded and there are no pending reductions.
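With placement-style mappings, this check reduces to scanning the per-mesh-axis placements. A minimal sketch using hypothetical placement tags (not the real Placement type):

```python
# Hypothetical placement tags for illustration: "replicate", ("shard", dim),
# or ("partial", op) for a pending reduction. The real Placement type in MAX
# differs; only the rule itself is shown here.
def is_fully_replicated(placements):
    """True only if no mesh axis shards a tensor dimension and no axis
    holds a pending (partial) reduction."""
    return all(p == "replicate" for p in placements)

print(is_fully_replicated(["replicate", "replicate"]))   # True
print(is_fully_replicated([("shard", 0), "replicate"]))  # False: dim 0 is sharded
print(is_fully_replicated([("partial", "sum")]))         # False: pending reduction
```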

is_fully_resolved

property is_fully_resolved: bool

Whether this spec can be used in eager dispatch.

Returns False if the spec contains compiler-only annotations (e.g. priorities) that cannot be resolved without a compiler.
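The distinction can be sketched as follows, with a hypothetical spec that carries placements plus a dict of compiler-only annotations (illustrative names, not the MAX API):

```python
# Hypothetical: a spec is eager-dispatchable only once every placement
# decision is concrete, i.e. no compiler-only annotations (such as sharding
# priorities) remain to be resolved.
def is_fully_resolved(placements, annotations):
    """False if any compiler-only annotation is still pending."""
    return not annotations

print(is_fully_resolved(["replicate"], {}))               # True
print(is_fully_resolved(["replicate"], {"priority": 1}))  # False
```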

mapping

mapping: DeviceMapping

The distribution mapping over the device mesh.

mesh

property mesh: DeviceMesh

The device mesh derived from the mapping.

rank

property rank: int

The number of dimensions.

shape

shape: Shape

The global shape of the tensor.

to_named_sharding()

to_named_sharding(tensor_rank)

Converts to a tensor-dim-indexed spec for compiler lowering.

Parameters:

tensor_rank (int)

Return type:

NamedMapping

to_placements()

to_placements()

Converts to mesh-axis-indexed placements for eager dispatch.

Return type:

tuple[Placement, ...]
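The two indexing conventions behind these conversions can be contrasted side by side. Everything below is a hypothetical sketch; the real Placement and NamedMapping types in MAX differ:

```python
# Mesh-axis-indexed (as used for eager dispatch): one entry per mesh axis,
# saying what that axis does. For a 2-axis mesh ("dp", "tp") that shards
# tensor dim 0 over "dp" and replicates over "tp":
placements = (("shard", 0), "replicate")  # indexed by mesh axis

# Tensor-dim-indexed (as used for compiler lowering): one entry per tensor
# dim, naming the mesh axis that shards it (None = replicated everywhere).
def to_named(placements, mesh_axes, tensor_rank):
    """Invert mesh-axis-indexed placements into a per-tensor-dim spec."""
    named = [None] * tensor_rank
    for axis_name, p in zip(mesh_axes, placements):
        if isinstance(p, tuple) and p[0] == "shard":
            named[p[1]] = axis_name
    return tuple(named)

print(to_named(placements, ("dp", "tp"), tensor_rank=2))  # ('dp', None)
```

The mesh-axis view is convenient for dispatching collectives (each mesh axis maps to a communicator), while the tensor-dim view matches how a compiler annotates individual tensor dimensions.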