IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

PerShardDim


class max.experimental.sharding.PerShardDim(per_shard=(), *, global_dim=None)


Bases: Dim

A Dim whose per_shard tuple lists one cell per mesh shard.

Used on Sharded axes when shards hold different per-device sizes, such as uneven static splits or dynamic axes minted per shard. The wrapper must be projected per shard via local_dim_at() before it reaches MLIR.
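To make the projection rule concrete, here is a minimal sketch in plain Python that models the idea without the MAX API. `uneven_split`, `ToyPerShardDim`, and `project_all_shards` are hypothetical stand-ins, and giving the extra elements to the first `size % num_shards` shards is only one possible uneven-split convention.

```python
from dataclasses import dataclass
from typing import NoReturn


def uneven_split(size: int, num_shards: int) -> tuple[int, ...]:
    """Hypothetical helper: split `size` into one static cell per shard.

    Assumes the first `size % num_shards` shards get one extra element.
    """
    base, extra = divmod(size, num_shards)
    return tuple(base + (1 if i < extra else 0) for i in range(num_shards))


@dataclass(frozen=True)
class ToyPerShardDim:
    """Toy model of a per-shard dim: one static cell per mesh shard."""

    per_shard: tuple[int, ...]

    def local_dim_at(self, shard: int) -> int:
        # Project the wrapper onto a single shard's local size.
        return self.per_shard[shard]

    def to_mlir(self) -> NoReturn:
        # Mirrors the documented rule: the wrapper itself is never lowered.
        raise TypeError("project per shard with local_dim_at() before lowering")


def project_all_shards(dim: ToyPerShardDim) -> list[int]:
    # Hypothetical lowering step: only per-shard local sizes go further.
    return [dim.local_dim_at(i) for i in range(len(dim.per_shard))]


dim = ToyPerShardDim(uneven_split(10, 3))
print(dim.per_shard)            # (4, 3, 3)
print(project_all_shards(dim))  # [4, 3, 3]
```

Calling `dim.to_mlir()` directly raises, which is the point: every consumer has to go through the per-shard projection first.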

Allocates the wrapper. On a plain re-wrap, it returns per_shard itself instead of allocating a new wrapper.
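One plausible reading of a "plain re-wrap" is passing an existing wrapper back into the constructor with no new global_dim, in which case the same object comes back rather than a nested wrapper. The sketch below models that reading with `__new__`. `ToyPerShard` is hypothetical, it is not the MAX class, and what exactly counts as a plain re-wrap is an assumption.

```python
from typing import Optional


class ToyPerShard:
    """Toy wrapper illustrating a re-wrap short-circuit in __new__."""

    def __new__(cls, per_shard=(), *, global_dim: Optional[int] = None):
        # Assumed "plain re-wrap": the input is already a wrapper and no new
        # global_dim is supplied, so hand back the existing object unchanged.
        if isinstance(per_shard, ToyPerShard) and global_dim is None:
            return per_shard
        self = super().__new__(cls)
        # Unwrap a wrapper input so cells never nest.
        cells = per_shard.per_shard if isinstance(per_shard, ToyPerShard) else per_shard
        self.per_shard = tuple(cells)
        self.global_dim = global_dim
        return self


inner = ToyPerShard((4, 3, 3))
print(ToyPerShard(inner) is inner)                  # True
print(ToyPerShard(inner, global_dim=10).per_shard)  # (4, 3, 3)
```

Returning the input from `__new__` keeps identity stable, so repeated wrapping is cheap and never builds wrappers of wrappers.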

Parameters:

Return type:

PerShardDim

is_static

property is_static: bool


True if this axis’s global extent is a static size.

Folds to the global dim first (see global_dim()), so a sharded axis whose global extent is static reports True even though isinstance(self, StaticDim) is False.
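The distinction between "the wrapper's type" and "the axis's global extent" can be sketched with a toy model. Here `ToyShardedDim` is hypothetical, an `int` stands in for a static dim and a `str` for a named symbolic dim, and folding an all-static axis with no explicit global_dim to the sum of its cells is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Union

# Toy convention (hypothetical): an int is a static size, a str is a named
# symbolic dim.
ToyDim = Union[int, str]


@dataclass(frozen=True)
class ToyShardedDim:
    per_shard: tuple[ToyDim, ...]
    global_dim: Optional[ToyDim] = None

    def folded_global(self) -> Optional[ToyDim]:
        # Prefer an explicit global dim; otherwise (assumption) an all-static
        # axis folds to the sum of its cells, and anything else is unknown.
        if self.global_dim is not None:
            return self.global_dim
        if all(isinstance(cell, int) for cell in self.per_shard):
            return sum(self.per_shard)
        return None

    @property
    def is_static(self) -> bool:
        # Ask about the folded global extent, not the wrapper's own type.
        return isinstance(self.folded_global(), int)

    @property
    def is_symbolic(self) -> bool:
        return isinstance(self.folded_global(), str)


uneven = ToyShardedDim((4, 3, 3))
print(uneven.folded_global(), uneven.is_static)  # 10 True
print(isinstance(uneven, int))                   # False
```

The wrapper is never itself an `int`, yet `is_static` is True because the question is answered on the folded global extent.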

is_symbolic

property is_symbolic: bool


True if this axis’s global extent is a symbolic (named) dim.

parameters

property parameters: Iterable[SymbolicDim]


Distinct symbolic-dim parameters referenced across all cells.
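Collecting "distinct parameters across all cells" amounts to a de-duplicated union over the cells. The sketch below is a hypothetical model: `distinct_parameters` is not part of MAX, each cell is represented only by the symbolic names it references, and preserving first-seen order is an assumption (the documented property only promises an Iterable).

```python
from typing import Iterable


def distinct_parameters(cells: Iterable[Iterable[str]]) -> list[str]:
    """Collect each symbolic name once across all per-shard cells.

    Hypothetical model: each cell is given as the names it references.
    dict.fromkeys keeps the first occurrence and preserves order.
    """
    return list(dict.fromkeys(name for cell in cells for name in cell))


# Three shards: two reference "batch", one also references "seq",
# and the last cell is a static size with no symbolic names.
print(distinct_parameters([("batch",), ("batch", "seq"), ()]))  # ['batch', 'seq']
```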

per_shard

per_shard: tuple[Dim, ...]

The per-shard cells, one Dim per mesh shard.


to_mlir()

to_mlir()


Always raises: a wrapper must be projected per shard (via local_dim_at()) before it reaches MLIR.

Return type:

NoReturn