
Python function

sharded_symbolic_dim


max.experimental.sharding.sharded_symbolic_dim(parent, mesh, mesh_axis, device_idx)


Returns the per-device symbolic dim name for a sharded parent.

The naming convention "{parent}_{axis_name}_{coord}" keeps each device's shard symbolically distinct. Without a per-device suffix, every shard would share a single SymbolicDim. The graph's rule that dims with the same name have the same size would then force all shards to the same runtime size. That is wrong for uneven data parallelism, where shards can hold different numbers of elements.
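The convention can be illustrated with plain strings. The sketch below does not call the MAX API. per_device_dim_name is a hypothetical helper that only mirrors the documented "{parent}_{axis_name}_{coord}" format. The 3/3/2/2 split of a batch of 10 is an assumed example of uneven sharding, not necessarily how MAX distributes rows.

```python
def per_device_dim_name(parent: str, axis_name: str, coord: int) -> str:
    """Hypothetical helper mirroring the "{parent}_{axis_name}_{coord}" convention."""
    return f"{parent}_{axis_name}_{coord}"


# Assumed setup: a global dim "batch" sharded over 4 devices on a mesh axis named "dp".
num_shards = 4
names = [per_device_dim_name("batch", "dp", c) for c in range(num_shards)]
print(names)  # ['batch_dp_0', 'batch_dp_1', 'batch_dp_2', 'batch_dp_3']

# Assumed uneven split of a batch of 10: the first `rem` shards take one extra row.
global_size = 10
base, rem = divmod(global_size, num_shards)
sizes = [base + (1 if c < rem else 0) for c in range(num_shards)]

# Distinct names let each shard bind its own runtime size.
bindings = dict(zip(names, sizes))
print(bindings)  # {'batch_dp_0': 3, 'batch_dp_1': 3, 'batch_dp_2': 2, 'batch_dp_3': 2}
```

If all four shards were named "batch", one binding would have to cover both 3 and 2 at once, which is the conflict the per-device suffix avoids.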

Parameters:

  • parent (SymbolicDim) – The global symbolic dim being sharded.
  • mesh (DeviceMesh) – The device mesh.
  • mesh_axis (int) – The mesh axis index along which parent is sharded.
  • device_idx (int) – The flat device index in row-major order.

Returns:

SymbolicDim("{parent.name}_{axis_name}_{coord}"), where coord is this device’s coordinate along mesh_axis.

Return type:

SymbolicDim
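To show how device_idx maps to coord, the following sketch recomputes the coordinate from a flat row-major device index. It assumes the mesh can be described by a shape tuple and a tuple of axis names. coord_along_axis and sharded_dim_name are hypothetical helpers written for illustration; they are not part of max.experimental.sharding and do not construct a real SymbolicDim.

```python
def coord_along_axis(mesh_shape: tuple[int, ...], mesh_axis: int, device_idx: int) -> int:
    """Coordinate of a flat, row-major device index along one mesh axis."""
    total = 1
    for size in mesh_shape:
        total *= size
    if not 0 <= device_idx < total:
        raise ValueError(f"device_idx {device_idx} out of range for mesh {mesh_shape}")
    # In row-major order the last axis varies fastest, so the stride of
    # mesh_axis is the product of the sizes of all axes after it.
    stride = 1
    for size in mesh_shape[mesh_axis + 1:]:
        stride *= size
    return (device_idx // stride) % mesh_shape[mesh_axis]


def sharded_dim_name(parent_name: str, axis_names: tuple[str, ...],
                     mesh_shape: tuple[int, ...], mesh_axis: int, device_idx: int) -> str:
    """Hypothetical re-creation of the "{parent}_{axis_name}_{coord}" name."""
    coord = coord_along_axis(mesh_shape, mesh_axis, device_idx)
    return f"{parent_name}_{axis_names[mesh_axis]}_{coord}"


# Assumed 2 x 4 mesh with axes named "dp" and "tp"; devices are numbered 0..7 row-major.
mesh_shape = (2, 4)
axis_names = ("dp", "tp")
for idx in range(8):
    # Sharding along axis 0: devices 0-3 -> batch_dp_0, devices 4-7 -> batch_dp_1
    print(idx, sharded_dim_name("batch", axis_names, mesh_shape, 0, idx))
```

Because the name depends only on the coordinate along mesh_axis, devices that differ only along other mesh axes (here, devices 0 through 3) receive the same name, presumably because they hold the same slice of parent along that axis.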