
Python function

shard_shape

shard_shape()

max.experimental.sharding.shard_shape(global_shape, placements, mesh_shape)


Computes the per-shard shape from a global shape and placements.

For each Sharded(axis=k) placement on mesh axis i, dimension k of the global shape is divided by mesh_shape[i]. Replicated and Partial placements leave the shape unchanged. The function works with both static and symbolic dimensions.
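The rule above can be illustrated with a minimal pure-Python sketch. It does not call the library: the Sharded, Replicated, and Partial classes below are hypothetical stand-ins for the real placement types, shard_shape_sketch is an illustrative name, and only static integer dimensions that divide evenly are handled. How the real function treats uneven division is not stated here.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the library's placement types; the real
# classes in max.experimental.sharding may differ.
@dataclass(frozen=True)
class Sharded:
    axis: int  # tensor dimension split across this mesh axis


@dataclass(frozen=True)
class Replicated:
    pass


@dataclass(frozen=True)
class Partial:
    pass


def shard_shape_sketch(global_shape, placements, mesh_shape):
    """Illustrative version of the rule for static integer dimensions only."""
    shape = list(global_shape)
    for mesh_axis, placement in enumerate(placements):
        if isinstance(placement, Sharded):
            # Assumption: the dimension divides evenly by the mesh axis size.
            shape[placement.axis] //= mesh_shape[mesh_axis]
        # Replicated and Partial leave the shape unchanged.
    return shape


# An 8x16 tensor on a 2x4 mesh: dimension 0 is split over mesh axis 0
# and the tensor is replicated over mesh axis 1.
print(shard_shape_sketch([8, 16], [Sharded(axis=0), Replicated()], [2, 4]))
# → [4, 16]
```

In this sketch, sharding the same dimension on two mesh axes divides it twice, so [8, 16] with Sharded(axis=0) on both axes of a 2x4 mesh gives [1, 16].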

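Parameters:

global_shape – The global (unsharded) shape.
placements – The placement for each mesh axis: Sharded(axis=k), Replicated, or Partial.
mesh_shape – The size of each mesh axis.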

Returns:

The per-shard dimensions after applying every sharded placement.

Return type:

list[Dim]