shard_shape
max.experimental.sharding.shard_shape(global_shape, placements, mesh_shape)
Computes the per-shard shape from a global shape and placements.
For each Sharded(axis=k) placement on mesh axis i, dimension
k is divided by mesh_shape[i]. Replicated and
Partial placements leave the shape unchanged. Works with both
static and symbolic dimensions.
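The rule above can be illustrated with a minimal pure-Python sketch that handles static (integer) dimensions only. The `Sharded`, `Replicated`, and `Partial` dataclasses and the function name `shard_shape_sketch` are hypothetical stand-ins written for this example. They are not the MAX classes or the MAX implementation. The sketch also assumes that `placements[i]` describes mesh axis `i` and that every sharded dimension divides evenly. The description above does not say how uneven splits are handled, so the sketch simply raises an error for them.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence, Union


# Hypothetical stand-ins for the placement types named above (not the MAX classes).
@dataclass(frozen=True)
class Sharded:
    axis: int  # tensor dimension that is split across this mesh axis


@dataclass(frozen=True)
class Replicated:
    pass


@dataclass(frozen=True)
class Partial:
    pass


Placement = Union[Sharded, Replicated, Partial]


def shard_shape_sketch(
    global_shape: Sequence[int],
    placements: Sequence[Placement],
    mesh_shape: Sequence[int],
) -> list[int]:
    """Static-dimension sketch: placements[i] applies to mesh axis i."""
    if len(placements) != len(mesh_shape):
        raise ValueError("expected one placement per mesh axis")
    dims = list(global_shape)
    for mesh_axis, placement in enumerate(placements):
        if isinstance(placement, Sharded):
            k = placement.axis
            n = mesh_shape[mesh_axis]
            # Assumption: this sketch requires an even split.
            if dims[k] % n != 0:
                raise ValueError(f"dimension {k} ({dims[k]}) not divisible by {n}")
            dims[k] //= n
        # Replicated and Partial placements leave the shape unchanged.
    return dims


# On a 2 x 4 mesh, shard dim 0 over mesh axis 0 and dim 1 over mesh axis 1.
print(shard_shape_sketch([8, 16], [Sharded(0), Sharded(1)], [2, 4]))  # [4, 4]
```

When two mesh axes both shard the same tensor dimension, the divisions compound. For example, `[Sharded(1), Sharded(1)]` on a `(2, 4)` mesh turns `(8, 16)` into `(8, 2)`.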
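Parameters:
- global_shape: The global (unsharded) shape.
- placements: One placement per mesh axis (Sharded, Replicated, or Partial); the placement at index i applies to mesh axis i.
- mesh_shape: The size of each mesh axis.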
Returns:
- The per-shard dimensions after applying every sharded placement.
Return type: