ShardingStrategy

class max.graph.ShardingStrategy(num_devices, shard)

Bases: object

Specifies how a Weight should be sharded across multiple devices.

This class encapsulates a sharding function and the number of devices over which to shard. It provides static methods for common sharding patterns like row-wise, column-wise, and replication.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • shard (Callable[[Weight, int, int], TensorValue]) – A callable that takes a Weight, a device index, and the total number of devices, and returns the sharded TensorValue for that device.

axiswise()

static axiswise(axis, num_devices)

Creates an axis-wise sharding strategy.

This strategy shards the weight along a given axis.

Parameters:

  • axis (int) – The axis along which to shard the weight.
  • num_devices (int) – The number of devices to shard the weight across.

Returns:

A ShardingStrategy instance configured for axis-wise sharding.

Return type:

ShardingStrategy
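As an illustration of the semantics only, the sketch below mimics axis-wise sharding on a 2-D weight represented as nested Python lists; the real strategy operates on max.graph Weight/TensorValue objects, and `shard_axiswise` is a hypothetical name, not part of the API.

```python
# Hypothetical sketch of axis-wise sharding for a 2-D "weight" (nested lists).
def shard_axiswise(weight, axis, device_idx, num_devices):
    if axis == 0:  # slice rows into num_devices equal blocks
        n = len(weight) // num_devices
        return weight[device_idx * n:(device_idx + 1) * n]
    elif axis == 1:  # slice columns within each row
        n = len(weight[0]) // num_devices
        return [row[device_idx * n:(device_idx + 1) * n] for row in weight]
    raise ValueError("sketch only handles axis 0 or 1")

w = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
```

With `axis=0` each device receives a block of rows; with `axis=1` each device receives a block of columns, which matches the row-wise and column-wise strategies below.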

columnwise()

static columnwise(num_devices)

Creates a column-wise sharding strategy.

This strategy shards the weight along its second axis (axis=1).

Parameters:

num_devices (int) – The number of devices to shard the weight across.

Returns:

A ShardingStrategy instance configured for column-wise sharding.

Return type:

ShardingStrategy
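To make the shape effect concrete, here is a hedged sketch of column-wise slicing on nested lists (not the actual library code): a weight of shape [rows, cols] becomes [rows, cols / num_devices] on each device.

```python
# Sketch: column-wise sharding slices axis 1 into num_devices equal blocks.
def shard_columnwise(weight, device_idx, num_devices):
    cols = len(weight[0]) // num_devices
    return [row[device_idx * cols:(device_idx + 1) * cols] for row in weight]

w = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
```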

expert_parallel()

static expert_parallel(num_devices)

Creates an expert parallel sharding strategy.

This strategy is designed for a Module that has multiple weights, where a single weight-sharding strategy is not sufficient. It is a placeholder and should not be called directly; Modules are expected to implement their own expert parallel sharding strategy.

Parameters:

num_devices (int) – The number of devices to shard the Module across.

Returns:

A ShardingStrategy instance configured for expert parallel sharding.

Return type:

ShardingStrategy

gate_up()

static gate_up(num_devices, axis=2)

Creates a gate_up sharding strategy.

This strategy shards weights where gate and up projections are concatenated along a specific axis.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • axis (int) – The axis along which the gate and up projections are concatenated. Defaults to 2 (common for MoE experts).

Returns:

A ShardingStrategy instance configured for gate_up sharding.

Return type:

ShardingStrategy
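One plausible reading of this pattern, sketched below with a 1-D list standing in for the concatenation axis (hypothetical code, not the library's implementation): because gate and up halves are concatenated, each device must receive matching slices from both halves, so the two halves are sharded independently and re-concatenated.

```python
# Hypothetical sketch: the first half of `weight` is the gate projection,
# the second half the up projection, concatenated along one axis.
def shard_gate_up(weight, device_idx, num_devices):
    half = len(weight) // 2
    gate, up = weight[:half], weight[half:]
    n = half // num_devices
    g = gate[device_idx * n:(device_idx + 1) * n]
    u = up[device_idx * n:(device_idx + 1) * n]
    return g + u  # each device keeps aligned gate and up slices

w = list(range(8))  # gate = [0..3], up = [4..7]
```

Naively slicing the concatenated weight as a whole would give some devices only gate rows and others only up rows, which is why a dedicated strategy exists.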

head_aware_columnwise()

static head_aware_columnwise(num_devices, num_heads, head_dim)

Creates a head-aware column-wise sharding strategy for attention output projection.

This strategy shards weight columns according to attention head distribution, properly handling cases where num_heads is not evenly divisible by num_devices. It’s designed for output projection weights in attention layers where each device processes a different number of heads.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • num_heads (int) – Total number of attention heads.
  • head_dim (int) – Dimension per attention head.

Returns:

A ShardingStrategy instance configured for head-aware column sharding.

Return type:

ShardingStrategy
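The uneven case can be sketched as follows (an illustration of one plausible distribution policy, assuming the first `num_heads % num_devices` devices take one extra head; the library's exact assignment may differ):

```python
# Sketch: distribute num_heads across num_devices, then size each device's
# column shard as heads_on_device * head_dim columns.
def heads_per_device(num_heads, num_devices):
    base, extra = divmod(num_heads, num_devices)
    return [base + (1 if d < extra else 0) for d in range(num_devices)]

def shard_width(num_heads, head_dim, num_devices, device_idx):
    return heads_per_device(num_heads, num_devices)[device_idx] * head_dim
```

For example, 10 heads on 4 devices gives head counts [3, 3, 2, 2] rather than a fractional split, and the shard widths still sum to the full `num_heads * head_dim` columns.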

is_colwise

property is_colwise: bool

Whether the sharding strategy is column-wise.

is_expert_parallel

property is_expert_parallel: bool

Whether the sharding strategy is expert parallel.

is_gate_up

property is_gate_up: bool

Whether the sharding strategy is gate_up.

is_head_aware_colwise

property is_head_aware_colwise: bool

Whether the sharding strategy is head-aware column-wise.

is_replicate

property is_replicate: bool

Whether the sharding strategy is replicate.

is_rowwise

property is_rowwise: bool

Whether the sharding strategy is row-wise.

is_stacked_qkv

property is_stacked_qkv: bool

Whether the sharding strategy is stacked QKV.

is_tensor_parallel

property is_tensor_parallel: bool

Whether the sharding strategy is tensor parallel.

num_devices

num_devices: int

The number of devices to shard the weight across.

replicate()

static replicate(num_devices)

Creates a replication strategy.

This strategy replicates the entire weight on each device.

Parameters:

num_devices (int) – The number of devices (primarily for consistency, as the weight is replicated).

Returns:

A ShardingStrategy instance configured for replication.

Return type:

ShardingStrategy

rowwise()

static rowwise(num_devices)

Creates a row-wise sharding strategy.

This strategy shards the weight along its first axis (axis=0).

Parameters:

num_devices (int) – The number of devices to shard the weight across.

Returns:

A ShardingStrategy instance configured for row-wise sharding.

Return type:

ShardingStrategy
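As with column-wise sharding, the shape effect can be sketched on nested lists (hypothetical code, not the library's implementation): a weight of shape [rows, cols] becomes [rows / num_devices, cols] on each device.

```python
# Sketch: row-wise sharding slices axis 0 into num_devices equal blocks.
def shard_rowwise(weight, device_idx, num_devices):
    rows = len(weight) // num_devices
    return weight[device_idx * rows:(device_idx + 1) * rows]

w = [[1, 2], [3, 4], [5, 6], [7, 8]]
```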

shard

shard: Callable[[Weight, int, int], TensorValue]

A callable that takes a Weight, a device index, and the total number of devices, and returns the sharded TensorValue for that device.
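The callable's contract can be sketched with a minimal stand-in (the `SimpleStrategy` and `replicate_shard` names below are hypothetical, and plain lists stand in for Weight/TensorValue):

```python
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class SimpleStrategy:  # hypothetical stand-in for max.graph.ShardingStrategy
    num_devices: int
    # shard(weight, device_idx, num_devices) -> the shard for that device
    shard: Callable[[List[List[int]], int, int], List[List[int]]]

def replicate_shard(weight, device_idx, num_devices):
    # Replication: every device receives the full weight unchanged.
    return weight

strategy = SimpleStrategy(num_devices=4, shard=replicate_shard)
w = [[1, 2], [3, 4]]
shards = [strategy.shard(w, d, strategy.num_devices) for d in range(strategy.num_devices)]
```

Passing both the device index and the total device count lets a single callable compute any device's slice without holding per-device state.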

stacked_qkv()

static stacked_qkv(num_devices, num_heads, head_dim)

Creates a stacked QKV sharding strategy for tensor parallel attention.

This strategy is designed for weights with shape [3 * hidden_size, hidden_size] where Q, K, and V weights are stacked together. It shards each section separately by attention heads, properly handling cases where num_heads is not evenly divisible by num_devices.

Parameters:

  • num_devices (int) – The number of devices to shard the weight across.
  • num_heads (int) – Total number of attention heads.
  • head_dim (int) – Dimension per attention head.

Returns:

A ShardingStrategy instance configured for stacked QKV sharding.

Return type:

ShardingStrategy
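The section-wise splitting can be sketched on a 1-D list standing in for the stacked row axis (hypothetical code, assuming the first `num_heads % num_devices` devices take one extra head; the library's exact assignment may differ):

```python
# Hypothetical sketch: a stacked QKV "weight" is three consecutive row
# blocks (Q, K, V), each of size hidden = num_heads * head_dim. Each block
# is sharded by heads separately, then the slices are restacked so every
# device holds matching Q, K, and V rows.
def shard_stacked_qkv(weight, device_idx, num_devices, num_heads, head_dim):
    hidden = num_heads * head_dim
    q, k, v = weight[:hidden], weight[hidden:2 * hidden], weight[2 * hidden:]
    base, extra = divmod(num_heads, num_devices)
    start_head = device_idx * base + min(device_idx, extra)
    n_heads = base + (1 if device_idx < extra else 0)
    s, e = start_head * head_dim, (start_head + n_heads) * head_dim
    return q[s:e] + k[s:e] + v[s:e]

w = list(range(6))  # hidden = 2 (2 heads x head_dim 1): Q=[0,1], K=[2,3], V=[4,5]
```

Slicing the stacked weight as one block would mix rows from different projections, which is why each Q, K, and V section is sharded separately.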

tensor_parallel()

static tensor_parallel(num_devices)

Creates a tensor parallel sharding strategy.

This strategy is designed for a Module that has multiple weights, where a single weight-sharding strategy is not sufficient. It is a placeholder and should not be called directly; Modules are expected to implement their own tensor parallel sharding strategy.

Parameters:

num_devices (int) – The number of devices to shard the Module across.

Returns:

A ShardingStrategy instance configured for tensor parallel sharding.

Return type:

ShardingStrategy