sharded_symbolic_dim
max.experimental.sharding.sharded_symbolic_dim(parent, mesh, mesh_axis, device_idx)
Returns the per-device symbolic dim name for a sharded parent.
The naming convention "{parent}_{axis_name}_{coord}" keeps each
device’s shard symbolically distinct. Without a per-device suffix,
every shard would share one SymbolicDim, and the graph’s
"same name = same size" rule would force all shards to have equal
runtime sizes. That is wrong for uneven data parallelism: for example,
10 rows split across 4 devices cannot give every device the same number of rows.
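To see why the per-device suffix matters, the sketch below models the graph's "same name = same size" rule with a hypothetical bind_sizes helper and splits a global batch of 10 rows over 4 data-parallel devices. It is not the MAX API: bind_sizes and uneven_split are illustrative helpers, and the split rule (the first total % num_shards shards get one extra row) is an assumption chosen only to produce unequal shard sizes.

```python
def uneven_split(total: int, num_shards: int) -> list[int]:
    """Split `total` rows into `num_shards` sizes; the first `total % num_shards`
    shards get one extra row (an assumed rule, for illustration only)."""
    base, rem = divmod(total, num_shards)
    return [base + 1 if i < rem else base for i in range(num_shards)]


def bind_sizes(names: list[str], sizes: list[int]) -> dict[str, int]:
    """Hypothetical model of the graph rule: one symbolic name binds one size."""
    bound: dict[str, int] = {}
    for name, size in zip(names, sizes):
        # setdefault returns the existing binding if the name was seen before.
        if bound.setdefault(name, size) != size:
            raise ValueError(f"{name!r} bound to both {bound[name]} and {size}")
    return bound


sizes = uneven_split(10, 4)
print(sizes)  # [3, 3, 2, 2]

# Distinct per-device names: every shard keeps its own runtime size.
per_device = [f"batch_dp_{i}" for i in range(4)]
print(bind_sizes(per_device, sizes))
# {'batch_dp_0': 3, 'batch_dp_1': 3, 'batch_dp_2': 2, 'batch_dp_3': 2}

# One shared name: the rule forces equal sizes, so binding fails.
try:
    bind_sizes(["batch"] * 4, sizes)
except ValueError as err:
    print(err)  # 'batch' bound to both 3 and 2
```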
Parameters:
- parent (SymbolicDim) – The global symbolic dim being sharded.
- mesh (DeviceMesh) – The device mesh.
- mesh_axis (int) – The mesh axis index along which parent is sharded.
- device_idx (int) – The flat device index in row-major order.

Returns:
SymbolicDim("{parent.name}_{axis_name}_{coord}"), where coord is this device’s coordinate along mesh_axis.
Return type:
SymbolicDim
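A minimal, self-contained sketch of how the returned name could be derived from the parameters is shown below. It is not the MAX implementation: SymbolicDim and DeviceMesh are hypothetical stand-in dataclasses, mesh_coords and sharded_symbolic_dim_sketch are hypothetical helpers, and the sketch assumes that axis_name is the name of mesh axis mesh_axis and that device_idx is unraveled into mesh coordinates in row-major order.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymbolicDim:
    """Hypothetical stand-in for a named symbolic dimension."""
    name: str


@dataclass(frozen=True)
class DeviceMesh:
    """Hypothetical stand-in: a mesh shape plus one name per axis."""
    shape: tuple[int, ...]
    axis_names: tuple[str, ...]


def mesh_coords(mesh: DeviceMesh, device_idx: int) -> tuple[int, ...]:
    """Unravel a flat device index into mesh coordinates (row-major:
    the last axis varies fastest)."""
    coords = []
    for size in reversed(mesh.shape):
        coords.append(device_idx % size)
        device_idx //= size
    return tuple(reversed(coords))


def sharded_symbolic_dim_sketch(
    parent: SymbolicDim, mesh: DeviceMesh, mesh_axis: int, device_idx: int
) -> SymbolicDim:
    """Build "{parent.name}_{axis_name}_{coord}" for one device."""
    coord = mesh_coords(mesh, device_idx)[mesh_axis]
    axis_name = mesh.axis_names[mesh_axis]
    return SymbolicDim(f"{parent.name}_{axis_name}_{coord}")


mesh = DeviceMesh(shape=(2, 4), axis_names=("dp", "tp"))
batch = SymbolicDim("batch")
print(mesh_coords(mesh, 6))  # (1, 2)
print(sharded_symbolic_dim_sketch(batch, mesh, 0, 6).name)  # batch_dp_1
print(sharded_symbolic_dim_sketch(batch, mesh, 1, 6).name)  # batch_tp_2
```

In this sketch, devices 2 and 6 differ only in their "dp" coordinate, so sharding along axis 1 gives both the name batch_tp_2. That matches the intent that only devices holding different shards along mesh_axis need distinct names.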