Python function
local_shard_shape_from_global()
max.experimental.sharding.local_shard_shape_from_global(global_shape, mesh, placements)
Maps a global tensor shape to each device’s local shard shape.
For each Sharded placement, the corresponding tensor dimension is
split across mesh.mesh_shape[mesh_axis] ranks along that mesh axis.
Dispatch on the parent dim subtype (mirrors
sharding.types.DistributedTensorType._local_shard_shape()):
- StaticDim: standard strided decomposition. When the parent size is not divisible by the mesh axis size, shard extents differ by at most one element.
- SymbolicDim: emit a fresh named symbolic dim "{name}_{axis_name}", where every shard gets the same symbolic size.
- AlgebraicDim (and any other non-static case): divide symbolically via parent // mesh_axis_size. The resulting AlgebraicDim folds eagerly when operands are static.
Device flat indices follow the same row-major order as
DeviceMesh.devices.
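The StaticDim strided decomposition described above can be sketched in plain Python. This is an illustrative model, not the library's implementation; in particular, the choice that the lowest-ranked shards absorb the remainder is an assumption.

```python
# Illustrative sketch (not the library implementation): split a static
# dimension of global size `n` across `k` ranks along one mesh axis.
# Shard extents differ by at most one element when n is not divisible by k.
def strided_shard_extents(n: int, k: int) -> list[int]:
    base, rem = divmod(n, k)
    # Assumption: the first `rem` ranks each receive one extra element.
    return [base + 1 if r < rem else base for r in range(k)]

print(strided_shard_extents(10, 4))  # -> [3, 3, 2, 2]
```

Note that the extents always sum to the global size, so no padding is implied by uneven division.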
Parameters:

- global_shape (Shape) – The global tensor shape.
- mesh (DeviceMesh) – The device mesh.
- placements (Sequence[Placement]) – One Placement per mesh axis.
Returns:

One Shape per device, in row-major mesh order.
Raises:

- ValueError – If placements length does not match mesh.ndim or a sharded axis is out of range.
- NotImplementedError – If placements contains a placement type other than Sharded, Replicated, or Partial.
Return type:

list[Shape]
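As a rough end-to-end model of the mapping, the following self-contained sketch computes a local shard shape per device for static dimensions only. The tuple/string placement encoding here is illustrative and is not the library's Placement API.

```python
# Hypothetical pure-Python model of local_shard_shape_from_global for
# static dims. placements[axis] is either ("sharded", tensor_dim) or
# "replicated" (an illustrative encoding, not the real Placement types).
from itertools import product


def local_shard_shapes(global_shape, mesh_shape, placements):
    if len(placements) != len(mesh_shape):
        raise ValueError("expected one placement per mesh axis")
    shapes = []
    # Enumerate device coordinates in row-major order, matching the
    # row-major flat ordering of DeviceMesh.devices.
    for coords in product(*(range(k) for k in mesh_shape)):
        shape = list(global_shape)
        for axis, p in enumerate(placements):
            if p == "replicated":
                continue
            _, dim = p
            n, k, r = shape[dim], mesh_shape[axis], coords[axis]
            base, rem = divmod(n, k)
            # Assumption: the first `rem` ranks along this axis get
            # one extra element, so extents differ by at most one.
            shape[dim] = base + (1 if r < rem else 0)
        shapes.append(tuple(shape))
    return shapes


# Example: shard dim 0 of a (7,) tensor across a 1-D mesh of 3 devices.
print(local_shard_shapes((7,), (3,), [("sharded", 0)]))  # -> [(3,), (2,), (2,)]
```

The nested loop mirrors the documented behavior: each Sharded placement divides one tensor dimension across that mesh axis, Replicated leaves the shape untouched, and results are emitted once per device in row-major mesh order.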