Python class
DeviceMesh
class max.experimental.sharding.DeviceMesh(devices, mesh_shape, axis_names)
Bases: object
An N-dimensional logical grid of devices.
Parameters:
axis_names
The human-readable name for each mesh axis.
axis_size()
axis_size(axis)
Returns the size of a mesh axis by name or index.
Parameters:
axis (str | int) – The mesh axis to look up, either by name or by integer index.
Returns:
The number of devices along the requested axis.
Raises:
- ValueError – If axis is a name that doesn't exist on the mesh.
- IndexError – If axis is an integer outside [0, ndim).
Return type:
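The lookup rules above can be sketched in plain Python. This is a minimal stand-in, not the library's implementation: the helper name lookup_axis_size and its explicit mesh_shape and axis_names arguments are hypothetical, chosen only to make the documented name-or-index behavior and the two error cases concrete.

```python
from typing import Sequence, Union


def lookup_axis_size(
    mesh_shape: Sequence[int],
    axis_names: Sequence[str],
    axis: Union[str, int],
) -> int:
    """Hypothetical stand-in that mirrors the documented axis_size() contract."""
    if isinstance(axis, str):
        # Name lookup: an unknown name is a ValueError.
        if axis not in axis_names:
            raise ValueError(f"unknown mesh axis name: {axis!r}")
        return mesh_shape[list(axis_names).index(axis)]
    # Integer lookup: anything outside [0, ndim) is an IndexError,
    # so negative indices are rejected rather than wrapped.
    if not 0 <= axis < len(mesh_shape):
        raise IndexError(f"mesh axis index out of range: {axis}")
    return mesh_shape[axis]


shape, names = (2, 4), ("dp", "tp")
print(lookup_axis_size(shape, names, "tp"))  # 4
print(lookup_axis_size(shape, names, 0))     # 2
```

The axis names "dp" and "tp" are illustrative labels only; any strings can name the axes.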
default()
static default()
Returns a single-device mesh on the default device (CPU).
Return type:
devices
The devices in the mesh, in row-major order.
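To make "row-major order" concrete, here is a small sketch that maps between grid coordinates and positions in a flat device list. It uses only the standard library; the helper names row_major_coords and flat_index are hypothetical and are not part of DeviceMesh.

```python
from itertools import product
from typing import List, Sequence, Tuple


def row_major_coords(mesh_shape: Sequence[int]) -> List[Tuple[int, ...]]:
    """Lists grid coordinates in row-major order (last axis varies fastest)."""
    return list(product(*(range(n) for n in mesh_shape)))


def flat_index(coord: Sequence[int], mesh_shape: Sequence[int]) -> int:
    """Converts a grid coordinate to its position in the flat device list."""
    idx = 0
    for c, n in zip(coord, mesh_shape):
        idx = idx * n + c
    return idx


# For a 2 x 3 mesh, flat position 5 holds the device at coordinate (1, 2).
print(row_major_coords((2, 3)))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
print(flat_index((1, 2), (2, 3)))  # 5
```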
is_simulated
property is_simulated: bool
Returns True if all mesh slots reference the same device.
A simulated mesh uses graph-level ops (add, concat, split) to emulate multi-device collectives on a single CPU or GPU.
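As a rough illustration of what emulating a collective with add and concat can mean, the sketch below runs an all-reduce and an all-gather over per-slot shards held as plain Python lists. It is a conceptual model only, under the assumption that each mesh slot holds one shard; the function names are hypothetical, and the actual graph ops used by the library may differ.

```python
from typing import List


def emulated_all_reduce(shards: List[List[float]]) -> List[List[float]]:
    """Sums shards elementwise (an 'add') and hands every slot the total."""
    total = [sum(values) for values in zip(*shards)]
    return [list(total) for _ in shards]


def emulated_all_gather(shards: List[List[float]]) -> List[List[float]]:
    """Joins shards end to end (a 'concat') and hands every slot the result."""
    gathered = [x for shard in shards for x in shard]
    return [list(gathered) for _ in shards]


# Two mesh slots, both living on the same physical device.
shards = [[1, 2], [3, 4]]
print(emulated_all_reduce(shards))  # [[4, 6], [4, 6]]
print(emulated_all_gather(shards))  # [[1, 2, 3, 4], [1, 2, 3, 4]]
```

Because every slot lives on one device, no cross-device transfer is needed: the collective reduces to ordinary arithmetic on the shards.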
is_single
property is_single: bool
Returns True if this mesh contains exactly one device.
mesh_shape
The shape of the logical grid.
ndim
property ndim: int
The number of mesh dimensions.
num_devices
property num_devices: int
The total number of devices.
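The relationship between mesh_shape, ndim, and num_devices can be sketched as follows, assuming that num_devices counts grid slots, so that it equals the product of the axis sizes. The helper names are hypothetical stand-ins, not library functions.

```python
import math
from typing import Sequence


def mesh_ndim(mesh_shape: Sequence[int]) -> int:
    """Number of mesh dimensions: one per entry in the shape."""
    return len(mesh_shape)


def mesh_num_devices(mesh_shape: Sequence[int]) -> int:
    """Total slot count, assumed here to be the product of the axis sizes."""
    return math.prod(mesh_shape)


print(mesh_ndim((2, 4)))         # 2
print(mesh_num_devices((2, 4)))  # 8
print(mesh_num_devices((1,)))    # 1
```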
single()
static single(device)
Creates a trivial single-device mesh.
Parameters:
device (Device) – The single device the mesh wraps.
Return type: