Python class
DistributedTensorType
class max.experimental.sharding.DistributedTensorType(dtype, shape, mesh, placements)
Bases: DistributedType[TensorType]
A symbolic type for a tensor distributed across a device mesh.
Analogous to TensorType for single-device tensors.
Derives per-device TensorType objects via local_types.
When a SymbolicDim is sharded along a mesh axis, the
local shard dimension becomes a new SymbolicDim named
"{original}_{axis_name}". This keeps symbolic names short and
debuggable while ensuring that sharding the same global dim on different
axes produces distinct names.
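The naming rule above can be illustrated with a minimal pure-Python sketch. It does not call the MAX API: `local_dim_name` is a hypothetical helper introduced only for illustration, and the dimension name "batch" and axis names "dp" and "tp" are assumed examples.

```python
def local_dim_name(global_name: str, axis_name: str) -> str:
    """Hypothetical helper: apply the "{original}_{axis_name}" rule."""
    return f"{global_name}_{axis_name}"


# Sharding the same global dim on two different mesh axes
# (axis names "dp" and "tp" are assumptions) gives distinct local names.
name_on_dp = local_dim_name("batch", "dp")
name_on_tp = local_dim_name("batch", "tp")
print(name_on_dp, name_on_tp)
```

Because the axis name is part of the local name, two shards of the same global dimension cannot be confused with each other when they come from different mesh axes.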
Parameters:
- dtype
- shape
- mesh
- placements
local_types
property local_types: list[TensorType]
The per-device TensorType objects in mesh order.