Python class

Sharded

class max.experimental.sharding.Sharded(axis, even=True)

Bases: Placement

Placement in which every device on this mesh axis holds one slice of the tensor, split along the tensor axis given by axis.

Parameters:

  • axis (int) – The tensor axis along which data is split.
  • even (bool) – When True (the default), the per-shard cells derived from a dynamic axis's parent dimension stay linked through a uniform split, parent // n, where n is the number of shards. When False, scatter mints fresh per-shard cells so that each rank can bind to a different extent.
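To make the layout concrete, here is a minimal pure-Python sketch that does not use the MAX API. A 2-D tensor is modeled as nested lists, and the hypothetical helper shard_along returns the slice each of n devices would hold under a Sharded(axis=...) placement, assuming the extent divides evenly by n.

```python
# Hypothetical illustration only: no MAX calls. A 2-D tensor is a list of rows,
# and shard_along() returns the slice each of n devices would hold when the
# tensor is sharded along `axis`. Assumes the extent divides evenly by n.

def shard_along(tensor, axis, n):
    """Return n slices of a 2-D nested list, split along axis 0 or 1."""
    if axis == 0:
        rows = len(tensor)
        if rows % n:
            raise ValueError("extent must divide evenly in this sketch")
        step = rows // n
        # Device i holds a contiguous block of rows.
        return [tensor[i * step:(i + 1) * step] for i in range(n)]
    if axis == 1:
        cols = len(tensor[0])
        if cols % n:
            raise ValueError("extent must divide evenly in this sketch")
        step = cols // n
        # Device i holds a contiguous block of columns from every row.
        return [[row[i * step:(i + 1) * step] for row in tensor] for i in range(n)]
    raise ValueError("this sketch only handles 2-D tensors")


t = [[0, 1, 2, 3],
     [4, 5, 6, 7]]
print(shard_along(t, axis=0, n=2))  # [[[0, 1, 2, 3]], [[4, 5, 6, 7]]]
print(shard_along(t, axis=1, n=2))  # [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
```

Each device's slice keeps the full extent of the unsharded axis; only the sharded axis shrinks.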

axis

axis: int

even

even: bool = True

global_dim()

global_dim(cells)

Computes the global dimension by summing the per-shard cells along this mesh axis.

Parameters:

cells (Dim)

Return type:

Dim

local_dim()

local_dim(parent, mesh, mesh_axis, *, allow_symbolic_mint=True)

Splits parent along mesh_axis into per-shard cells.

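The split depends on the parent's type. A StaticDim parent is divided with divmod, so the shards can be uneven when the extent is not divisible by the number of shards. A SymbolicDim parent gets freshly minted per-shard cells. An AlgebraicDim parent raises an error. Wrapper parents are passed through unchanged. When allow_symbolic_mint is False, a bare SymbolicDim raises instead of minting.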

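Parameters:

  • parent – The dimension to split.
  • mesh – The device mesh.
  • mesh_axis – The mesh axis along which parent is split.
  • allow_symbolic_mint (bool) – When False, a bare SymbolicDim parent raises instead of minting per-shard cells. Defaults to True.
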
Return type:

Dim
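The type-based dispatch described above can be sketched in plain Python. This is a hypothetical model, not the MAX implementation: ints stand in for StaticDim and strings for SymbolicDim, and two details are assumptions, namely that the divmod remainder goes to the leading shards and how minted per-shard names are formed.

```python
# Hypothetical sketch of local_dim's dispatch, not the MAX implementation.
# ints stand in for StaticDim, strs for SymbolicDim. Assumed: the divmod
# remainder goes to the leading shards, and minted names use a "_shard{i}" suffix.

def local_extents(parent, n, allow_symbolic_mint=True):
    """Split `parent` into n per-shard extents."""
    if isinstance(parent, int):
        # Static extent: divmod gives a base size q and a remainder r.
        q, r = divmod(parent, n)
        # Assumed: the first r shards each take one extra element.
        return [q + 1 if i < r else q for i in range(n)]
    if isinstance(parent, str):
        # Symbolic extent: each shard gets a fresh, independent cell.
        if not allow_symbolic_mint:
            raise ValueError(f"cannot mint per-shard cells for {parent!r}")
        return [f"{parent}_shard{i}" for i in range(n)]
    # Anything else (e.g. an algebraic expression) is rejected in this sketch.
    raise TypeError(f"unsupported parent: {parent!r}")


shards = local_extents(10, 4)
print(shards)       # [3, 3, 2, 2]
print(sum(shards))  # 10
print(local_extents("batch", 2))  # ['batch_shard0', 'batch_shard1']
```

Summing the static shards recovers the parent extent, which is the relationship global_dim relies on when it sums per-shard cells.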

localized_axis()

localized_axis()

Returns the tensor axis that this Sharded placement localizes, that is, the axis along which it splits the tensor.

Return type:

int | None

transition_to()

transition_to(other)

Returns the collective that moves data from this placement to other: Sharded to Replicated is an allgather, and Sharded to Sharded is an all-to-all.

Parameters:

other (Placement)

Return type:

Collective
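The two documented transitions can be pictured with a small stand-in model. The Replicated and Sharded classes below are hypothetical dataclasses that only mirror the names on this page; they are not imported from max.experimental.sharding, they return collective names as strings rather than Collective objects, and they ignore any special cases the real method may handle (for example, a transition to the same axis).

```python
# Hypothetical stand-ins, not the MAX classes. Only the two transitions
# documented for Sharded are modeled, and strings replace Collective objects.
from dataclasses import dataclass


@dataclass(frozen=True)
class Replicated:
    pass


@dataclass(frozen=True)
class Sharded:
    axis: int
    even: bool = True

    def transition_to(self, other):
        """Name the collective that moves data from this placement to `other`."""
        if isinstance(other, Replicated):
            # Every device needs the full tensor: gather all slices everywhere.
            return "allgather"
        if isinstance(other, Sharded):
            # Re-slice along a different axis: each device exchanges pieces with all others.
            return "all-to-all"
        raise NotImplementedError(f"no transition to {other!r} in this sketch")


print(Sharded(axis=0).transition_to(Replicated()))     # allgather
print(Sharded(axis=0).transition_to(Sharded(axis=1)))  # all-to-all
```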