Python class

DeviceMesh

class max.experimental.sharding.DeviceMesh(devices, mesh_shape, axis_names)

Bases: object

An N-dimensional logical grid of devices.

Parameters:

  • devices (tuple[Device, ...]) – A flat tuple of devices in row-major order.
  • mesh_shape (tuple[int, ...]) – The shape of the logical grid. For example, (2, 4) describes 2-way data parallelism (DP) by 4-way tensor parallelism (TP).
  • axis_names (tuple[str, ...]) – A human-readable name for each mesh axis (for example, ("dp", "tp")).
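
To make the row-major layout concrete, the following sketch maps grid coordinates to positions in the flat devices tuple. It does not use the MAX API: the device labels are plain strings standing in for Device objects, and coords_to_flat is a hypothetical helper introduced only for illustration.

```python
from itertools import product

# Stand-in device labels; a real mesh would hold Device objects instead.
devices = tuple(f"dev{i}" for i in range(8))
mesh_shape = (2, 4)
axis_names = ("dp", "tp")


def coords_to_flat(coords, shape):
    """Convert grid coordinates to a flat row-major index (hypothetical helper)."""
    flat = 0
    for coord, size in zip(coords, shape):
        flat = flat * size + coord
    return flat


# Walk the grid in row-major order: the last axis ("tp") varies fastest.
layout = {
    coords: devices[coords_to_flat(coords, mesh_shape)]
    for coords in product(*(range(n) for n in mesh_shape))
}

print(layout[(0, 0)])  # dev0
print(layout[(0, 3)])  # dev3
print(layout[(1, 0)])  # dev4
print(layout[(1, 3)])  # dev7
```

Under this layout, devices that share a "dp" coordinate sit next to each other in the flat tuple, so each contiguous run of four entries forms one "tp" group.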

axis_names

axis_names: tuple[str, ...]

The human-readable name for each mesh axis.

axis_size()

axis_size(axis)

Returns the size of a mesh axis by name or index.

Parameters:

axis (str | int) – The mesh axis to look up, either by name or by integer index.

Returns:

The number of devices along the requested axis.

Raises:

  • ValueError – If axis is a name that doesn’t exist on the mesh.
  • IndexError – If axis is an integer outside [0, ndim).

Return type:

int
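
The lookup and error behavior described above can be sketched in plain Python. This is a minimal stand-in that assumes the documented contract and nothing more. It is not the library's implementation, and the free function below takes the shape and names as arguments only so that it runs without MAX installed.

```python
def axis_size(mesh_shape, axis_names, axis):
    """Hypothetical stand-in for DeviceMesh.axis_size."""
    if isinstance(axis, str):
        # Name lookup: unknown names raise ValueError, as documented.
        if axis not in axis_names:
            raise ValueError(
                f"unknown mesh axis {axis!r}; expected one of {axis_names}"
            )
        return mesh_shape[axis_names.index(axis)]
    # Index lookup: reject anything outside [0, ndim), including negatives.
    if not 0 <= axis < len(mesh_shape):
        raise IndexError(f"axis index {axis} is outside [0, {len(mesh_shape)})")
    return mesh_shape[axis]


shape, names = (2, 4), ("dp", "tp")
print(axis_size(shape, names, "tp"))  # 4
print(axis_size(shape, names, 0))     # 2
```

Note that the explicit range check rejects negative indices, which plain tuple indexing would otherwise accept.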

default()

static default()

Returns a single-device mesh on the default device (CPU).

Return type:

DeviceMesh

devices

devices: tuple[Device, ...]

The devices in the mesh, in row-major order.

is_simulated

property is_simulated: bool

Returns True if all mesh slots reference the same device.

A simulated mesh uses graph-level ops (add, concat, split) to emulate multi-device collectives on a single CPU or GPU.
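
As a rough illustration of the idea, the sketch below checks the "all slots reference the same device" condition and emulates a sum all-reduce with ordinary additions. It is written under assumptions: devices are compared by equality, plain strings stand in for Device objects, and simulated_all_reduce is a hypothetical helper operating on Python lists rather than the graph-level ops the library uses.

```python
def is_simulated(devices):
    """True when every mesh slot refers to the same device (sketch)."""
    return all(d == devices[0] for d in devices)


def simulated_all_reduce(per_slot_values):
    """Emulate a sum all-reduce on one device (hypothetical helper).

    Adds the slots' values elementwise, then gives every slot its own
    copy of the reduced result.
    """
    reduced = [sum(column) for column in zip(*per_slot_values)]
    return [list(reduced) for _ in per_slot_values]


simulated = ("cpu:0",) * 4
real = ("gpu:0", "gpu:1", "gpu:2", "gpu:3")
print(is_simulated(simulated))  # True
print(is_simulated(real))       # False
print(simulated_all_reduce([[1, 2], [3, 4]]))  # [[4, 6], [4, 6]]
```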

is_single

property is_single: bool

Returns True if this mesh contains exactly one device.

mesh_shape

mesh_shape: tuple[int, ...]

The shape of the logical grid.

ndim

property ndim: int

The number of mesh dimensions.

num_devices

property num_devices: int

The total number of devices.
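
The relationship between ndim, num_devices, and the constructor arguments can be sketched as follows. This assumes the flat devices tuple fills the grid exactly, so the product of mesh_shape equals the number of devices. The page does not state that requirement explicitly, so treat it as an assumption. The variables below are plain Python values, not DeviceMesh attributes.

```python
import math

mesh_shape = (2, 4)
# Stand-in device labels; a real mesh would hold Device objects instead.
devices = tuple(f"dev{i}" for i in range(8))

# One mesh dimension per entry in mesh_shape.
ndim = len(mesh_shape)

# Assumed: the grid is fully populated, so its volume is the device count.
num_devices = math.prod(mesh_shape)
assert num_devices == len(devices)

print(ndim, num_devices)  # 2 8
```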

single()

static single(device)

Creates a trivial single-device mesh.

Parameters:

device (Device) – The single device the mesh wraps.

Return type:

DeviceMesh