StackedLinear
class max.nn.StackedLinear(in_dim, out_dims, names, dtype, device, stacked=False, has_bias=False, linear_cls=<class 'max.nn.linear.Linear'>, quant_config=None, clip_weight=None, _is_sharding=False)
Bases: Module
A module that manages multiple linear projections as a stacked weight.
Supports two modes:
- Stacked (stacked=True): Holds a single pre-stacked weight tensor. Use when the checkpoint already stores a fused weight (e.g. qkv_proj.weight).
- Unfused (stacked=False): Holds N child Linear modules whose weights are concatenated at graph-build time. Use when the checkpoint stores separate projections (e.g. q_proj, k_proj, v_proj).
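The two modes are numerically equivalent. A minimal NumPy sketch (illustrative only, not the MAX implementation) shows why: concatenating N projection weights along the output dimension and applying them in one matmul matches applying each projection separately and concatenating the results.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim = 8
out_dims = [16, 4, 4]  # e.g. q, k, v with grouped KV heads

# Unfused mode: N separate (out_dim, in_dim) weights, as in q_proj/k_proj/v_proj.
child_weights = [rng.standard_normal((d, in_dim)) for d in out_dims]

# Stacked mode: one fused (sum(out_dims), in_dim) weight, as in qkv_proj.weight.
stacked_weight = np.concatenate(child_weights, axis=0)

x = rng.standard_normal((2, in_dim))  # batch of 2 inputs

# One fused matmul equals per-projection matmuls concatenated on the output axis.
fused_out = x @ stacked_weight.T
unfused_out = np.concatenate([x @ w.T for w in child_weights], axis=1)
assert np.allclose(fused_out, unfused_out)
```

This equivalence is what lets unfused mode defer the concatenation to graph-build time while still producing a single stacked matmul.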
In unfused mode (stacked=False), the module sets
_omit_module_attr_name: its own attribute name
(typically qkv_proj) is omitted from the FQN of its child weights.
The child names supplied via the names argument therefore double
as the external (checkpoint) names. For QKV stacking that means using
names=["q_proj", "k_proj", "v_proj"] so that
self.qkv_proj = StackedLinear(...) exposes weights at
self_attn.q_proj.weight rather than
self_attn.qkv_proj.q_proj.weight. This removes the need for
per-architecture q_proj -> qkv_proj.q mapping in weight adapters.
In stacked mode (stacked=True), the attribute name is not
omitted: the single fused weight/bias would otherwise lose
all namespace context and collide with sibling attributes.
Stacked-mode weights remain at <attr>.weight / <attr>.bias
(e.g. self_attn.qkv_proj.weight) and weight adapters must
continue to map fused checkpoint names into that namespace.
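The naming behavior above can be illustrated with a stand-alone toy sketch. The real logic lives inside Module's FQN construction; build_fqn below is a hypothetical helper, not a MAX API.

```python
def build_fqn(parent: str, attr: str, child: str, weight: str,
              omit_attr: bool) -> str:
    """Toy illustration of FQN construction with/without the attribute name."""
    parts = [parent]
    if not omit_attr:
        parts.append(attr)  # stacked mode keeps the attribute's namespace
    if child:
        parts.append(child)  # unfused mode exposes the child's name directly
    parts.append(weight)
    return ".".join(parts)

# Unfused mode: attribute name omitted, child names double as checkpoint names.
assert build_fqn("self_attn", "qkv_proj", "q_proj", "weight",
                 omit_attr=True) == "self_attn.q_proj.weight"

# Stacked mode: attribute name kept so the fused weight retains its namespace.
assert build_fqn("self_attn", "qkv_proj", "", "weight",
                 omit_attr=False) == "self_attn.qkv_proj.weight"
```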
Initializes the stacked linear layer.
Parameters:
- in_dim (int) – The input dimension shared by all projections.
- out_dims (Sequence[int]) – Output dimension for each projection.
- names (Sequence[str]) – Attribute name for each child (e.g. ["q_proj", "k_proj", "v_proj"]). In unfused mode these names are also the FQNs the children's weights are exposed under (see the class docstring on _omit_module_attr_name), so they should match the corresponding checkpoint names.
- dtype (DType) – Data type for all weights.
- device (DeviceRef) – Device for weight placement.
- stacked (bool) – When True, create a single pre-stacked weight instead of N child Linear modules.
- has_bias (bool) – Whether each projection has a bias vector.
- linear_cls (Callable[..., Linear]) – Linear class to use for each projection.
- quant_config (QuantConfig | None) – Optional quantization config.
- clip_weight (float | None) – Optional weight clipping threshold.
- _is_sharding (bool)
shard()
shard(devices)
Create sharded copies across devices.
For stacked mode, shards the single weight. For unfused mode, shards each child Linear and reassembles.
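One common layout is tensor-parallel output-row sharding; a hedged NumPy sketch (the actual splitting is governed by the module's ShardingStrategy, and shard_stacked is a hypothetical helper) shows why each projection must be split individually and then reassembled per device, rather than naively slicing the fused weight:

```python
import numpy as np

def shard_stacked(child_weights, n_devices):
    """Illustrative tensor-parallel sharding: split each projection's output
    rows across devices, then reassemble one stacked weight per device."""
    per_device = []
    for d in range(n_devices):
        # Take device d's slice of EVERY projection, not of the fused weight,
        # so each device holds a consistent q/k/v fraction.
        shards = [np.array_split(w, n_devices, axis=0)[d] for w in child_weights]
        per_device.append(np.concatenate(shards, axis=0))
    return per_device

rng = np.random.default_rng(0)
children = [rng.standard_normal((8, 4)), rng.standard_normal((4, 4))]
shards = shard_stacked(children, n_devices=2)
# Each device holds half of every projection's output rows: (8 + 4) / 2 = 6.
assert all(s.shape == (6, 4) for s in shards)
```

Slicing the fused (12, 4) weight into two 6-row halves directly would instead give device 0 only rows of the first projection, which is why unfused mode shards each child Linear before reassembling.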
Parameters:
- devices
Return type:
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the sharding strategy.
stacked_bias
property stacked_bias: TensorValue | None
Returns the concatenated bias vector, or None.
stacked_input_scale
property stacked_input_scale: TensorValue | None
Returns the max of per-projection input scales, or None.
stacked_weight
property stacked_weight: TensorValue
Returns the stacked weight tensor.
For stacked mode, returns the single weight directly.
For unfused mode, delegates to _concat_child_weights().
stacked_weight_scale
property stacked_weight_scale: TensorValue | None
Returns the combined weight scale for quantized matmul.
stacked_weight_scale_2
property stacked_weight_scale_2: TensorValue | None
Returns the max of per-projection weight_scale_2 values (used for NVFP4 quantization), or None.