Python class

StackedLinear

class max.nn.StackedLinear(in_dim, out_dims, names, dtype, device, stacked=False, has_bias=False, linear_cls=<class 'max.nn.linear.Linear'>, quant_config=None, clip_weight=None, _is_sharding=False)

Bases: Module

A module that manages multiple linear projections as a stacked weight.

Supports two modes:

  • Stacked (stacked=True): Holds a single pre-stacked weight tensor. Use when the checkpoint already stores a fused weight (e.g. qkv_proj.weight).
  • Unfused (stacked=False): Holds N child Linear modules whose weights are concatenated at graph-build time. Use when the checkpoint stores separate projections (e.g. q_proj, k_proj, v_proj).

In unfused mode (stacked=False), the module sets _omit_module_attr_name: its own attribute name (typically qkv_proj) is omitted from the FQN of its child weights. The child names supplied via the names argument therefore double as the external (checkpoint) names. For QKV stacking that means using names=["q_proj", "k_proj", "v_proj"] so that self.qkv_proj = StackedLinear(...) exposes weights at self_attn.q_proj.weight rather than self_attn.qkv_proj.q_proj.weight. This removes the need for per-architecture q_proj -> qkv_proj.q mapping in weight adapters.

In stacked mode (stacked=True), the attribute name is not omitted: the single fused weight/bias would otherwise lose all namespace context and collide with sibling attributes. Stacked-mode weights remain at <attr>.weight / <attr>.bias (e.g. self_attn.qkv_proj.weight) and weight adapters must continue to map fused checkpoint names into that namespace.
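
For illustration, here is a minimal sketch of the unfused mode. The import paths, the Module subclassing pattern, and the dimensions below are assumptions for this example, not taken from this page:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import Module, StackedLinear

class MyAttention(Module):
    def __init__(self) -> None:
        super().__init__()
        # Unfused mode: the checkpoint stores q_proj/k_proj/v_proj
        # separately, so the child weights surface directly as
        # <parent>.q_proj.weight, <parent>.k_proj.weight, and
        # <parent>.v_proj.weight (the qkv_proj attribute name is
        # omitted from the FQN).
        self.qkv_proj = StackedLinear(
            in_dim=4096,
            out_dims=[4096, 1024, 1024],  # hypothetical q, k, v widths
            names=["q_proj", "k_proj", "v_proj"],
            dtype=DType.bfloat16,
            device=DeviceRef.GPU(),
            stacked=False,
        )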

Initializes the stacked linear layer.

Parameters:

  • in_dim (int) – The input dimension shared by all projections.
  • out_dims (Sequence[int]) – Output dimension for each projection.
  • names (Sequence[str]) – Attribute name for each child (e.g. ["q_proj", "k_proj", "v_proj"]). In unfused mode these names are also the FQNs the children’s weights are exposed under (see class docstring on _omit_module_attr_name), so they should match the corresponding checkpoint names.
  • dtype (DType) – Data type for all weights.
  • device (DeviceRef) – Device for weight placement.
  • stacked (bool) – When True, create a single pre-stacked weight instead of N child Linear modules.
  • has_bias (bool) – Whether each projection has a bias vector.
  • linear_cls (Callable[..., Linear]) – Linear class to use for each projection.
  • quant_config (QuantConfig | None) – Optional quantization config.
  • clip_weight (float | None) – Optional weight clipping threshold.
  • _is_sharding (bool)
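
The stacked-mode counterpart, continuing the hypothetical sketch above (sizes remain illustrative):

        # Stacked mode: the checkpoint already stores a fused weight,
        # so a single tensor is created and exposed at
        # <parent>.qkv_proj.weight; weight adapters must map the fused
        # checkpoint name into that namespace.
        self.qkv_proj = StackedLinear(
            in_dim=4096,
            out_dims=[4096, 1024, 1024],
            names=["q_proj", "k_proj", "v_proj"],
            dtype=DType.bfloat16,
            device=DeviceRef.GPU(),
            stacked=True,
        )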

shard()

shard(devices)

Create sharded copies across devices.

For stacked mode, shards the single weight. For unfused mode, shards each child Linear and reassembles.

Parameters:

devices (Iterable[DeviceRef]) – The devices to create sharded copies for.

Return type:

list[StackedLinear]
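
A usage sketch. It assumes, by analogy with other MAX layers, that a sharding strategy is assigned via the sharding_strategy property before shard() is called, and that ShardingStrategy is importable from max.nn; verify both against your MAX version:

from max.graph import DeviceRef
from max.nn import ShardingStrategy  # assumed import path

# layer is a StackedLinear built as in the examples above.
devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]
# Rowwise is illustrative only; choose the strategy your layer needs.
layer.sharding_strategy = ShardingStrategy.rowwise(len(devices))
shards = layer.shard(devices)  # one StackedLinear per device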

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the sharding strategy.

stacked_bias

property stacked_bias: TensorValue | None

Returns the concatenated bias vector, or None.

stacked_input_scale

property stacked_input_scale: TensorValue | None

Returns the max of per-projection input scales, or None.

stacked_weight

property stacked_weight: TensorValue

Returns the stacked weight tensor.

For stacked mode, returns the single weight directly. For unfused mode, delegates to _concat_child_weights().
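
A consumer might use this property at graph-build time roughly as follows. This is a sketch: it assumes TensorValue supports the @ operator and .T and that ops.split is available, as in recent MAX releases; the helper itself is hypothetical:

from max.graph import TensorValue, ops

def fused_qkv_matmul(
    x: TensorValue, layer: StackedLinear, out_dims: list[int]
) -> list[TensorValue]:
    # One matmul covers all projections; the result is then split
    # back into per-projection activations along the last axis.
    w = layer.stacked_weight  # shape [sum(out_dims), in_dim]
    qkv = x @ w.T
    return ops.split(qkv, out_dims, axis=-1)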

stacked_weight_scale

property stacked_weight_scale: TensorValue | None

Returns the combined weight scale for quantized matmul.

stacked_weight_scale_2

property stacked_weight_scale_2: TensorValue | None

Returns the max of the per-projection weight_scale_2 values (NVFP4), or None.
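
Together, the stacked_* properties give a quantized consumer everything it needs from one place. A sketch (which values are non-None depends on quant_config and has_bias; the downstream kernel call is not shown):

w = layer.stacked_weight                   # stacked weight tensor
bias = layer.stacked_bias                  # concatenated bias, or None
in_scale = layer.stacked_input_scale       # max of per-projection input scales
w_scale = layer.stacked_weight_scale       # combined scale for quantized matmul
w_scale_2 = layer.stacked_weight_scale_2   # max of per-projection weight_scale_2 (NVFP4)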