Python class
ShardingStrategy
class max.graph.ShardingStrategy(num_devices, shard)
Bases: object
Specifies how a Weight should be sharded across multiple devices.
This class encapsulates a sharding function and the number of devices over which to shard. It provides static methods for common sharding patterns like row-wise, column-wise, and replication.
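The shape of the class can be illustrated with a plain-Python sketch (a mock, not the real max.graph implementation; nested lists stand in for Weight/TensorValue):

```python
from dataclasses import dataclass
from typing import Callable, List

Matrix = List[List[float]]  # stand-in for a Weight / TensorValue


@dataclass
class StrategySketch:
    """Pairs a device count with a shard callable, like ShardingStrategy."""
    num_devices: int
    shard: Callable[[Matrix, int, int], Matrix]

    @staticmethod
    def replicate(num_devices: int) -> "StrategySketch":
        # Every device receives a full copy of the weight.
        return StrategySketch(num_devices,
                              lambda w, device, n: [row[:] for row in w])


strategy = StrategySketch.replicate(4)
weight = [[1.0, 2.0], [3.0, 4.0]]
copies = [strategy.shard(weight, d, strategy.num_devices)
          for d in range(strategy.num_devices)]
```

The static constructors on the real class follow the same pattern: each one packages a specific shard callable together with `num_devices`.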
axiswise()
static axiswise(axis, num_devices)
Creates an axis-wise sharding strategy.
This strategy shards the weight along a given axis.
Parameters:
- axis (int) – The axis along which to shard the weight.
- num_devices (int) – The number of devices to shard the weight across.

Return type:
ShardingStrategy
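The axis-wise split can be illustrated in plain Python (a sketch, not the real implementation; the floor-division remainder policy shown here is an assumption about how uneven sizes are spread):

```python
def axis_slice_bounds(dim_size: int, num_devices: int, device: int):
    """Contiguous [start, end) bounds for one device's shard along an axis.

    Floor division spreads an uneven dimension across devices so that
    shard sizes differ by at most one element.
    """
    start = device * dim_size // num_devices
    end = (device + 1) * dim_size // num_devices
    return start, end


# A dimension of size 10 split across 3 devices:
bounds = [axis_slice_bounds(10, 3, d) for d in range(3)]
```

Every element of the axis is covered exactly once across the devices.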
columnwise()
static columnwise(num_devices)
Creates a column-wise sharding strategy.
This strategy shards the weight along its second axis (axis=1).
Parameters:
- num_devices (int) – The number of devices to shard the weight across.

Returns:
A ShardingStrategy instance configured for column-wise sharding.

Return type:
ShardingStrategy
expert_parallel()
static expert_parallel(num_devices)
Creates an expert parallel sharding strategy.
This strategy is designed for a Module that has multiple weights, for which a single weight sharding strategy is not sufficient. This strategy is a placeholder and should not be called directly. Modules are expected to implement their own expert parallel sharding strategy.

Parameters:
- num_devices (int) – The number of devices to shard the Module across.

Returns:
A ShardingStrategy instance configured for expert parallel sharding.

Return type:
ShardingStrategy
gate_up()
static gate_up(num_devices, axis=2)
Creates a gate_up sharding strategy.
This strategy shards weights where gate and up projections are concatenated along a specific axis.
Parameters:
- num_devices (int) – The number of devices to shard the weight across.
- axis (int) – The axis along which the gate and up projections are concatenated. Defaults to 2.

Returns:
A ShardingStrategy instance configured for gate_up sharding.

Return type:
ShardingStrategy
head_aware_columnwise()
static head_aware_columnwise(num_devices, num_heads, head_dim)
Creates a head-aware column-wise sharding strategy for attention output projection.
This strategy shards weight columns according to attention head distribution, properly handling cases where num_heads is not evenly divisible by num_devices. It’s designed for output projection weights in attention layers where each device processes a different number of heads.
Parameters:
- num_devices (int) – The number of devices to shard the weight across.
- num_heads (int) – The total number of attention heads.
- head_dim (int) – The dimension of each attention head.

Returns:
A ShardingStrategy instance configured for head-aware column sharding.

Return type:
ShardingStrategy
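The head-aware distribution can be sketched as follows (hypothetical helpers; the remainder policy shown, where the lowest-indexed devices take the extra heads, is an assumption not confirmed by this page):

```python
def heads_per_device(num_heads: int, num_devices: int):
    """Distribute attention heads across devices when division is uneven."""
    base, rem = divmod(num_heads, num_devices)
    # Assumption: the first `rem` devices each take one extra head.
    return [base + (1 if d < rem else 0) for d in range(num_devices)]


def column_range(num_heads: int, head_dim: int, num_devices: int, device: int):
    """Column slice [start, end) owned by `device` in an output projection."""
    counts = heads_per_device(num_heads, num_devices)
    first_head = sum(counts[:device])
    start = first_head * head_dim
    end = start + counts[device] * head_dim
    return start, end


# 10 heads of dim 64 over 4 devices: the head counts come out uneven.
counts = heads_per_device(10, 4)
rng0 = column_range(10, 64, 4, 0)
```

A plain column-wise split at arbitrary offsets would cut through the middle of a head; slicing on head boundaries keeps each head's columns on a single device.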
is_colwise
property is_colwise: bool
Whether the sharding strategy is column-wise.
is_expert_parallel
property is_expert_parallel: bool
Whether the sharding strategy is expert parallel.
is_gate_up
property is_gate_up: bool
Whether the sharding strategy is gate_up.
is_head_aware_colwise
property is_head_aware_colwise: bool
Whether the sharding strategy is head-aware column-wise.
is_replicate
property is_replicate: bool
Whether the sharding strategy is replicate.
is_rowwise
property is_rowwise: bool
Whether the sharding strategy is row-wise.
is_stacked_qkv
property is_stacked_qkv: bool
Whether the sharding strategy is stacked QKV.
is_tensor_parallel
property is_tensor_parallel: bool
Whether the sharding strategy is tensor parallel.
num_devices
num_devices: int
The number of devices to shard the weight across.
replicate()
static replicate(num_devices)
Creates a replication strategy.
This strategy replicates the entire weight on each device.
Parameters:
- num_devices (int) – The number of devices (primarily for consistency, as the weight is replicated).

Returns:
A ShardingStrategy instance configured for replication.

Return type:
ShardingStrategy
rowwise()
static rowwise(num_devices)
Creates a row-wise sharding strategy.
This strategy shards the weight along its first axis (axis=0).
Parameters:
- num_devices (int) – The number of devices to shard the weight across.

Returns:
A ShardingStrategy instance configured for row-wise sharding.

Return type:
ShardingStrategy
shard
shard: Callable[[Weight, int, int], TensorValue]
A callable that takes a Weight, a device index, and the total
number of devices, and returns the sharded TensorValue for that
device.
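A custom shard callable with the documented (Weight, device index, total devices) signature might look like the following plain-Python sketch (lists stand in for Weight and TensorValue):

```python
from typing import List

Matrix = List[List[float]]  # stand-in for a Weight / TensorValue


def shard_colwise(weight: Matrix, device: int, num_devices: int) -> Matrix:
    """Slice columns (axis=1) for one device; mirrors a column-wise strategy."""
    cols = len(weight[0])
    start = device * cols // num_devices
    end = (device + 1) * cols // num_devices
    return [row[start:end] for row in weight]


w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]
left = shard_colwise(w, 0, 2)   # columns 0-1
right = shard_colwise(w, 1, 2)  # columns 2-3
```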
stacked_qkv()
static stacked_qkv(num_devices, num_heads, head_dim)
Creates a stacked QKV sharding strategy for tensor parallel attention.
This strategy is designed for weights with shape [3 * hidden_size, hidden_size] where Q, K, and V weights are stacked together. It shards each section separately by attention heads, properly handling cases where num_heads is not evenly divisible by num_devices.
Parameters:
- num_devices (int) – The number of devices to shard the weight across.
- num_heads (int) – The total number of attention heads.
- head_dim (int) – The dimension of each attention head.

Returns:
A ShardingStrategy instance configured for stacked QKV sharding.

Return type:
ShardingStrategy
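The stacked layout can be sketched by computing each section's rows separately (hypothetical helper; the head-major row layout and the remainder policy of giving extra heads to the lowest-indexed devices are assumptions):

```python
def qkv_row_ranges(hidden_size: int, num_heads: int, head_dim: int,
                   num_devices: int, device: int):
    """Row ranges owned by `device` in a stacked [3*hidden, hidden] QKV weight.

    Each Q/K/V section of `hidden_size` rows is sharded separately by head.
    Assumption: when num_heads % num_devices != 0, extra heads go to the
    lowest-indexed devices.
    """
    base, rem = divmod(num_heads, num_devices)
    counts = [base + (1 if d < rem else 0) for d in range(num_devices)]
    first_head = sum(counts[:device])
    start = first_head * head_dim
    end = start + counts[device] * head_dim
    # One (start, end) row range per section: Q, K, then V.
    return [(s * hidden_size + start, s * hidden_size + end) for s in range(3)]


# hidden_size = 6 heads * 64 dim = 384; 2 devices each own 3 heads per section.
ranges = qkv_row_ranges(384, 6, 64, 2, 1)
```

Sharding each section separately is what distinguishes this from a plain row-wise split: a naive axis-0 split of the stacked weight would give some devices only Q rows and others only V rows.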
tensor_parallel()
static tensor_parallel(num_devices)
Creates a tensor parallel sharding strategy.
This strategy is designed for a Module that has multiple weights, for which a single weight sharding strategy is not sufficient. This strategy is a placeholder and should not be called directly. Modules are expected to implement their own tensor parallel sharding strategy.

Parameters:
- num_devices (int) – The number of devices to shard the Module across.

Returns:
A ShardingStrategy instance configured for tensor parallel sharding.

Return type:
ShardingStrategy