
Python function

local_shard_shape_from_global()

max.experimental.sharding.local_shard_shape_from_global(global_shape, mesh, placements)

Maps a global tensor shape to each device’s local shard shape.

For each Sharded placement, the corresponding tensor dimension is split across mesh.mesh_shape[mesh_axis] ranks along that mesh axis. The local extent is computed by dispatching on the parent dim's subtype (mirroring sharding.types.DistributedTensorType._local_shard_shape()):

  • StaticDim: standard strided decomposition. When the parent size is not divisible by the mesh axis size, shard extents differ by at most one element.
  • SymbolicDim: emits a fresh symbolic dim named "{name}_{axis_name}"; every shard gets the same symbolic size.
  • AlgebraicDim (and any other non-static case): divides symbolically via parent // mesh_axis_size. The resulting AlgebraicDim folds eagerly when all operands are static.
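The StaticDim behavior can be illustrated with a small stand-alone sketch. This is not the MAX implementation; it assumes the common balanced-split convention in which the first `size % num_shards` shards each take one extra element, which satisfies the "differ by at most one element" property stated above:

```python
def balanced_split(size: int, num_shards: int) -> list[int]:
    """Split a static dimension of `size` across `num_shards` ranks.

    Each shard gets size // num_shards elements; the first
    size % num_shards shards get one extra, so any two shard
    extents differ by at most one element.
    """
    base, rem = divmod(size, num_shards)
    return [base + 1 if i < rem else base for i in range(num_shards)]


# A size-10 dim sharded across a mesh axis of size 4:
print(balanced_split(10, 4))  # [3, 3, 2, 2]
```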

Device flat indices follow the same row-major order as DeviceMesh.devices.
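For purely static shapes, the whole mapping can be sketched end to end. The helper below is hypothetical (it is not the MAX API); it represents the mesh as a plain shape tuple and placements as a per-mesh-axis tensor-dim index (or None for replication), and relies on itertools.product yielding mesh coordinates in row-major order, matching flat device indices:

```python
from itertools import product


def local_shard_shapes(global_shape, mesh_shape, placements):
    """Return one local shard shape per device, in row-major mesh order.

    placements[mesh_axis] is the tensor dim sharded along that mesh
    axis, or None if that axis replicates. Assumes each tensor dim is
    sharded by at most one mesh axis.
    """
    shapes = []
    # product(...) iterates coordinates in row-major order, so the
    # k-th result corresponds to flat device index k.
    for coords in product(*(range(n) for n in mesh_shape)):
        shape = list(global_shape)
        for mesh_axis, dim in enumerate(placements):
            if dim is None:
                continue
            n = mesh_shape[mesh_axis]
            rank = coords[mesh_axis]
            base, rem = divmod(shape[dim], n)
            shape[dim] = base + 1 if rank < rem else base
        shapes.append(tuple(shape))
    return shapes


# A (5, 6) tensor on a 2x2 mesh, dim 0 sharded on mesh axis 0,
# dim 1 sharded on mesh axis 1:
print(local_shard_shapes((5, 6), (2, 2), (0, 1)))
# [(3, 3), (3, 3), (2, 3), (2, 3)]
```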

Parameters:

  • global_shape – the global (unsharded) tensor shape to map to local shards.
  • mesh – the DeviceMesh over which the tensor is sharded.
  • placements – the placements (e.g. Sharded) describing how tensor dimensions map to mesh axes.

Returns:

One Shape per device, in row-major mesh order.

Raises:

Return type:

list[Shape]