IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python class

ShardingStrategy

class max.graph.ShardingStrategy(num_devices, shard)

Bases: object

Specifies how a Weight should be sharded across multiple devices.

This class pairs a sharding function with the number of devices to shard across. It provides static methods for common sharding patterns, such as row-wise, column-wise, and replicated sharding.

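Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • shard (Callable[[Weight, int, int], TensorValue]) – A callable that takes a Weight, a device index, and the total number of devices, and returns the sharded TensorValue for that device.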

axiswise()

static axiswise(axis, num_devices)

Creates an axis-wise sharding strategy.

This strategy shards the weight along a given axis.

Parameters:

  • axis (int) – The axis along which to shard the weight.
  • num_devices (int) – The number of devices to shard the weight across.

Returns:

A ShardingStrategy instance configured for axis-wise sharding.

Return type:

ShardingStrategy
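As a rough illustration of what axis-wise sharding does, the pure-Python sketch below slices a 2-D weight (modeled as nested lists) along axis 0 or 1. It does not call MAX: `split_bounds` and `axiswise_shard` are hypothetical helpers, and handing leftover elements to the lowest-indexed devices when the axis does not divide evenly is an assumption, not documented behavior.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def split_bounds(size: int, num_devices: int, device_idx: int) -> Tuple[int, int]:
    """Half-open [start, end) range owned by one device.

    Assumption: leftover elements go to the lowest-indexed devices.
    """
    base, rem = divmod(size, num_devices)
    start = device_idx * base + min(device_idx, rem)
    return start, start + base + (1 if device_idx < rem else 0)


def axiswise_shard(weight: Matrix, axis: int, device_idx: int, num_devices: int) -> Matrix:
    """Hypothetical model of axis-wise sharding for a 2-D weight."""
    if axis == 0:
        # Split the rows.
        start, end = split_bounds(len(weight), num_devices, device_idx)
        return weight[start:end]
    if axis == 1:
        # Split the columns of every row.
        start, end = split_bounds(len(weight[0]), num_devices, device_idx)
        return [row[start:end] for row in weight]
    raise ValueError("this sketch only handles 2-D weights (axis 0 or 1)")


w = [[1, 2, 3, 4], [5, 6, 7, 8]]
print(axiswise_shard(w, axis=1, device_idx=0, num_devices=2))  # [[1, 2], [5, 6]]
```

Under this assumed policy, splitting 10 rows across 4 devices yields ranges of 3, 3, 2, and 2 rows.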

columnwise()

static columnwise(num_devices)

Creates a column-wise sharding strategy.

This strategy shards the weight along its second axis (axis=1).

Parameters:

num_devices (int) – The number of devices to shard the weight across.

Returns:

A ShardingStrategy instance configured for column-wise sharding.

Return type:

ShardingStrategy
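Column-wise sharding is easiest to reason about in a matrix-vector product. The pure-Python sketch below assumes the weight is laid out as [out_features, in_features] (as in y = W·x); under that assumption, splitting axis 1 splits the input features, so each device produces a partial result that must be summed across devices. The helpers are hypothetical, do not call MAX, and assume the column count divides evenly.

```python
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """y = W @ x for a weight laid out as [out_features, in_features]."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def column_shard(w: Matrix, device_idx: int, num_devices: int) -> Matrix:
    """Take this device's contiguous block of columns (axis=1); assumes even division."""
    cols = len(w[0]) // num_devices
    return [row[device_idx * cols:(device_idx + 1) * cols] for row in w]


w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
x = [1.0, 1.0, 2.0, 2.0]
num_devices = 2
cols = len(x) // num_devices

# Each device multiplies its column block by the matching slice of x ...
partials = [matvec(column_shard(w, d, num_devices), x[d * cols:(d + 1) * cols])
            for d in range(num_devices)]
# ... and the partial results are summed (an all-reduce on real hardware).
y = [sum(vals) for vals in zip(*partials)]
print(y == matvec(w, x))  # True
```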

expert_parallel()

static expert_parallel(num_devices)

Creates an expert parallel sharding strategy.

This strategy is designed for a Module that has multiple weights, where a single per-weight sharding strategy is not sufficient. It is a placeholder and should not be called directly; Modules are expected to implement their own expert parallel sharding strategy.

Parameters:

num_devices (int) – The number of devices to shard the Module across.

Returns:

A ShardingStrategy instance configured for expert parallel sharding.

Return type:

ShardingStrategy
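Expert parallelism places whole experts on different devices instead of slicing every expert's weights. Because the placeholder above leaves the actual layout to each Module, the pure-Python sketch below shows only one hypothetical placement policy (contiguous blocks of expert ids, with extra experts on the lowest-indexed devices); it is not a MAX API, and real MoE modules may assign experts differently.

```python
from typing import Dict, List


def assign_experts(num_experts: int, num_devices: int) -> Dict[int, List[int]]:
    """Hypothetical placement: map each device index to a contiguous block of expert ids."""
    base, rem = divmod(num_experts, num_devices)
    placement: Dict[int, List[int]] = {}
    next_expert = 0
    for d in range(num_devices):
        count = base + (1 if d < rem else 0)  # assumption: extra experts go to the first devices
        placement[d] = list(range(next_expert, next_expert + count))
        next_expert += count
    return placement


print(assign_experts(num_experts=8, num_devices=4))
# {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
```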

gate_up()

static gate_up(num_devices, axis=2)

Creates a gate_up sharding strategy.

This strategy shards weights in which the gate and up projections are concatenated along a specific axis.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • axis (int) – The axis along which the gate and up projections are concatenated. Defaults to 2 (common for MoE experts).

Returns:

A ShardingStrategy instance configured for gate_up sharding.

Return type:

ShardingStrategy
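A dedicated strategy is needed because a plain contiguous split of a concatenated [gate | up] axis would give some devices only gate columns and others only up columns. The pure-Python sketch below illustrates one plausible slicing, which is an assumption rather than the documented implementation: each device takes matching slices of the gate and up halves and re-concatenates them. It models only the concatenation axis as a flat list and assumes each half divides evenly by num_devices.

```python
from typing import List


def gate_up_shard(row: List[str], device_idx: int, num_devices: int) -> List[str]:
    """Hypothetical gate_up slicing of one [gate | up] row along the concatenated axis.

    Assumes the gate and up halves have equal length, divisible by num_devices.
    """
    half = len(row) // 2
    gate, up = row[:half], row[half:]
    chunk = half // num_devices
    lo, hi = device_idx * chunk, (device_idx + 1) * chunk
    # Keep matching gate and up slices, since the MLP combines them element-wise.
    return gate[lo:hi] + up[lo:hi]


row = ["g0", "g1", "g2", "g3", "u0", "u1", "u2", "u3"]
print(gate_up_shard(row, 0, 2))  # ['g0', 'g1', 'u0', 'u1']
print(gate_up_shard(row, 1, 2))  # ['g2', 'g3', 'u2', 'u3']
```

A plain contiguous split would instead give device 0 ['g0', 'g1', 'g2', 'g3'], that is, gate columns only.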

head_aware_columnwise()

static head_aware_columnwise(num_devices, num_heads, head_dim)

Creates a head-aware column-wise sharding strategy for attention output projection.

This strategy shards weight columns according to the distribution of attention heads across devices, and correctly handles cases where num_heads is not evenly divisible by num_devices. It is designed for the output projection weights of attention layers, where devices may be assigned different numbers of heads.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • num_heads (int) – Total number of attention heads.
  • head_dim (int) – Dimension per attention head.

Returns:

A ShardingStrategy instance configured for head-aware column sharding.

Return type:

ShardingStrategy
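The head-aware split can be worked out by hand: assign whole heads to devices, then convert head counts to column ranges of width head_dim. The sketch below is a hypothetical helper, not a MAX API, and assumes that extra heads go to the lowest-indexed devices and that the projection's input columns are laid out head by head.

```python
from typing import List, Tuple


def head_column_ranges(num_devices: int, num_heads: int, head_dim: int) -> List[Tuple[int, int]]:
    """Hypothetical [start, end) column range per device, in whole heads."""
    base, rem = divmod(num_heads, num_devices)
    ranges: List[Tuple[int, int]] = []
    start = 0
    for d in range(num_devices):
        heads = base + (1 if d < rem else 0)  # assumption: extra heads go to the first devices
        end = start + heads * head_dim
        ranges.append((start, end))
        start = end
    return ranges


print(head_column_ranges(num_devices=4, num_heads=10, head_dim=8))
# [(0, 24), (24, 48), (48, 64), (64, 80)]
```

A naive even split of the 80 columns would give each device 20 columns, which cuts heads apart because 20 is not a multiple of head_dim = 8.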

is_colwise

property is_colwise: bool

Whether the sharding strategy is column-wise.

is_expert_parallel

property is_expert_parallel: bool

Whether the sharding strategy is expert parallel.

is_gate_up

property is_gate_up: bool

Whether the sharding strategy is gate_up.

is_head_aware_colwise

property is_head_aware_colwise: bool

Whether the sharding strategy is head-aware column-wise.

is_replicate

property is_replicate: bool

Whether the sharding strategy is replication.

is_rowwise

property is_rowwise: bool

Whether the sharding strategy is row-wise.

is_stacked_qkv

property is_stacked_qkv: bool

Whether the sharding strategy is stacked QKV.

is_tensor_parallel

property is_tensor_parallel: bool

Whether the sharding strategy is tensor parallel.

num_devices

num_devices: int

The number of devices to shard the weight across.

replicate()

static replicate(num_devices)

Creates a replication strategy.

This strategy replicates the entire weight on each device.

Parameters:

num_devices (int) – The number of devices (primarily for consistency, as the weight is replicated).

Returns:

A ShardingStrategy instance configured for replication.

Return type:

ShardingStrategy
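Replication is the simplest case: the shard function ignores the device index and returns the whole weight, so total memory use grows linearly with the number of devices. The sketch below models this in pure Python with a hypothetical `replicate_shard` that uses the same (weight, device index, number of devices) argument order as the `shard` callable; it does not call MAX.

```python
from typing import List

Matrix = List[List[float]]


def replicate_shard(weight: Matrix, device_idx: int, num_devices: int) -> Matrix:
    """Hypothetical replicate shard function: every device gets the whole weight.

    device_idx and num_devices are accepted only to match the shard signature.
    """
    return [row[:] for row in weight]  # an independent copy for each device


w = [[1.0, 2.0], [3.0, 4.0]]
copies = [replicate_shard(w, d, 4) for d in range(4)]
print(all(c == w for c in copies))  # True
# Four devices hold four full copies: 4 * 4 = 16 elements in total.
print(sum(len(c) * len(c[0]) for c in copies))  # 16
```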

rowwise()

static rowwise(num_devices)

Creates a row-wise sharding strategy.

This strategy shards the weight along its first axis (axis=0).

Parameters:

num_devices (int) – The number of devices to shard the weight across.

Returns:

A ShardingStrategy instance configured for row-wise sharding.

Return type:

ShardingStrategy
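Row-wise sharding is the counterpart of the column-wise case. Assuming again an [out_features, in_features] weight layout, splitting axis 0 splits the output features: every device needs the full input, and each produces a disjoint slice of the output that can simply be concatenated. The helpers below are hypothetical, do not call MAX, and assume the row count divides evenly.

```python
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """y = W @ x for a weight laid out as [out_features, in_features]."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def row_shard(w: Matrix, device_idx: int, num_devices: int) -> Matrix:
    """Take this device's contiguous block of rows (axis=0); assumes even division."""
    rows = len(w) // num_devices
    return w[device_idx * rows:(device_idx + 1) * rows]


w = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
x = [1.0, 1.0]
num_devices = 2

# Every device sees the full input x and produces a disjoint slice of the outputs ...
outs = [matvec(row_shard(w, d, num_devices), x) for d in range(num_devices)]
# ... so concatenating the slices (an all-gather on real hardware) recovers y.
y = [v for out in outs for v in out]
print(y)  # [3.0, 7.0, 11.0, 15.0]
```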

shard

shard: Callable[[Weight, int, int], TensorValue]

A callable that takes a Weight, a device index, and the total number of devices, and returns the sharded TensorValue for that device.
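To make the `shard` contract concrete, the sketch below pairs a hypothetical shard function with a `ToyStrategy` dataclass that mirrors the two documented fields, `num_devices` and `shard`. It is plain Python, not MAX: weights are nested lists rather than Weight or TensorValue objects, `shard_all` is an illustrative helper, and giving leftover rows to the first devices is an assumption.

```python
from dataclasses import dataclass
from typing import Callable, List

Matrix = List[List[float]]


def row_shard(weight: Matrix, device_idx: int, num_devices: int) -> Matrix:
    """Hypothetical shard function: return the rows owned by device_idx."""
    base, rem = divmod(len(weight), num_devices)
    # Assumption: the first `rem` devices each receive one extra row.
    start = device_idx * base + min(device_idx, rem)
    end = start + base + (1 if device_idx < rem else 0)
    return weight[start:end]


@dataclass(frozen=True)
class ToyStrategy:
    """Hypothetical stand-in mirroring ShardingStrategy(num_devices, shard)."""

    num_devices: int
    shard: Callable[[Matrix, int, int], Matrix]

    def shard_all(self, weight: Matrix) -> List[Matrix]:
        # Call the shard function once per device index.
        return [self.shard(weight, i, self.num_devices) for i in range(self.num_devices)]


w = [[float(2 * r + c) for c in range(2)] for r in range(5)]  # a 5 x 2 weight
shards = ToyStrategy(num_devices=2, shard=row_shard).shard_all(w)
print([len(s) for s in shards])  # [3, 2]
```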

stacked_qkv()

static stacked_qkv(num_devices, num_heads, head_dim)

Creates a stacked QKV sharding strategy for tensor parallel attention.

This strategy is designed for weights with shape [3 * hidden_size, hidden_size] where Q, K, and V weights are stacked together. It shards each section separately by attention heads, properly handling cases where num_heads is not evenly divisible by num_devices.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • num_heads (int) – Total number of attention heads.
  • head_dim (int) – Dimension per attention head.

Returns:

A ShardingStrategy instance configured for stacked QKV sharding.

Return type:

ShardingStrategy
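Because Q, K, and V are stacked along the first axis, a device's shard is not one contiguous block but three: the rows for its heads within each section. The hypothetical helper below (not a MAX API) computes those row ranges for the documented [3 * hidden_size, hidden_size] layout, assuming hidden_size = num_heads * head_dim, equal head counts for Q, K, and V, and extra heads on the lowest-indexed devices.

```python
from typing import List, Tuple


def stacked_qkv_rows(device_idx: int, num_devices: int,
                     num_heads: int, head_dim: int) -> List[Tuple[int, int]]:
    """Hypothetical [start, end) row ranges of one device's Q, K, and V slices."""
    hidden = num_heads * head_dim
    base, rem = divmod(num_heads, num_devices)
    # Assumption: the first `rem` devices each own one extra head.
    first_head = device_idx * base + min(device_idx, rem)
    heads = base + (1 if device_idx < rem else 0)
    ranges = []
    for section in range(3):  # 0 = Q, 1 = K, 2 = V
        start = section * hidden + first_head * head_dim
        ranges.append((start, start + heads * head_dim))
    return ranges


print(stacked_qkv_rows(device_idx=1, num_devices=2, num_heads=3, head_dim=4))
# [(8, 12), (20, 24), (32, 36)]
```

For the same 3 heads on 2 devices, device 0 would own rows (0, 8), (12, 20), and (24, 32), covering two heads in each section.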

tensor_parallel()

static tensor_parallel(num_devices)

Creates a tensor parallel sharding strategy.

This strategy is designed for a Module that has multiple weights, where a single per-weight sharding strategy is not sufficient. It is a placeholder and should not be called directly; Modules are expected to implement their own tensor parallel sharding strategy.

Parameters:

num_devices (int) – The number of devices to shard the Module across.

Returns:

A ShardingStrategy instance configured for tensor parallel sharding.

Return type:

ShardingStrategy
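To see why tensor parallelism is a Module-level decision, consider a two-layer MLP. The pure-Python sketch below (no MAX APIs; all names hypothetical) assumes [out_features, in_features] weight layouts, shards the first weight row-wise and the second column-wise, and checks that the per-device results sum to the unsharded output. This pairing keeps the hidden activation sharded and needs only one reduction at the end. It is one common pattern, not necessarily how MAX modules implement it.

```python
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """y = W @ x for a weight laid out as [out_features, in_features]."""
    return [sum(a * b for a, b in zip(row, x)) for row in w]


def relu(v: List[float]) -> List[float]:
    # Element-wise, so it can be applied to each device's hidden slice independently.
    return [max(0.0, a) for a in v]


w1 = [[1.0, -1.0], [2.0, 0.0], [-3.0, 1.0], [0.5, 0.5]]  # [hidden=4, in=2]
w2 = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]       # [out=2, hidden=4]
x = [1.0, 2.0]
num_devices = 2
h = len(w1) // num_devices  # hidden units per device (assumes even division)

partials = []
for d in range(num_devices):
    w1_d = w1[d * h:(d + 1) * h]                   # row-wise slice of w1
    w2_d = [row[d * h:(d + 1) * h] for row in w2]  # matching column-wise slice of w2
    partials.append(matvec(w2_d, relu(matvec(w1_d, x))))

y_sharded = [sum(vals) for vals in zip(*partials)]  # one reduction at the end
y_full = matvec(w2, relu(matvec(w1, x)))
print(y_sharded == y_full)  # True
```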