Python module
distributed_transformer
DistributedLogitsPostprocessMixin
class max.nn.transformer.distributed_transformer.DistributedLogitsPostprocessMixin
Mixin providing logits postprocessing for multi-device sharded models.
Requires: self.norm_shards, self.lm_head, self.return_logits, self.devices. Optional: self.return_hidden_states, self.logits_scaling.
devices
lm_head
lm_head: Callable[[list[TensorValue], Sequence[BufferValue]], Sequence[TensorValue]]
logits_scaling
logits_scaling: float = 1.0
norm_shards
norm_shards: Sequence[Callable[[TensorValue], TensorValue]]
return_hidden_states
return_hidden_states: ReturnHiddenStates = 'none'
return_logits
return_logits: ReturnLogits
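The mixin expects its host class to define the attributes listed above. A schematic, framework-free sketch of that contract (plain Python callables and strings stand in for MAX's `TensorValue`, `BufferValue`, `ReturnLogits`, and `DeviceRef` types; the class name is illustrative, not part of the API):

```python
class LogitsPostprocessHost:
    """Illustrative host class providing the attributes the mixin requires."""

    def __init__(self, num_devices: int):
        # Required: one normalization callable per device shard.
        self.norm_shards = [lambda x: x for _ in range(num_devices)]
        # Required: lm_head takes per-device inputs plus signal buffers
        # and returns per-device outputs.
        self.lm_head = lambda inputs, signal_buffers: inputs
        # Required: which logits to return (string stands in for ReturnLogits).
        self.return_logits = "last_token"
        # Required: the devices the model is sharded across.
        self.devices = [f"gpu:{i}" for i in range(num_devices)]
        # Optional attributes, shown with their documented defaults.
        self.return_hidden_states = "none"
        self.logits_scaling = 1.0

host = LogitsPostprocessHost(num_devices=2)
```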
DistributedTransformer
class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, devices, rope, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False, subgraph_layer_groups=None, logits_scaling=1.0)
Transformer model consisting of DistributedTransformerBlock layers.
-
Parameters:
-
- dim (int)
- n_heads (int)
- layers (list[DistributedTransformerBlock])
- norm (ShardableCallable)
- output (ColumnParallelLinear)
- embedding (VocabParallelEmbedding)
- kv_params (KVCacheParams)
- devices (list[DeviceRef])
- rope (RotaryEmbedding)
- return_logits (ReturnLogits)
- use_subgraphs (bool)
- subgraph_layer_groups (list[list[int]] | None)
- logits_scaling (float)
DistributedTransformerBlock
class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices)
Stack of Attention, FeedForward, and RMSNorm layers.
-
Parameters:
-
- attention (Module)
- mlp (ShardableCallable)
- attention_norm (ShardableCallable)
- mlp_norm (ShardableCallable)
- devices (list[DeviceRef])
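As a rough illustration of how these four components typically compose, here is a framework-free pre-norm residual sketch. The pre-norm ordering (normalize, apply the sublayer, add the residual) is an assumption about this block, not something the signature above confirms, and plain numbers stand in for per-device sharded tensors:

```python
def transformer_block(x, attention, mlp, attention_norm, mlp_norm):
    """Pre-norm residual pattern commonly used by blocks like this one.

    Assumed ordering: normalize, apply the sublayer, then add the residual.
    """
    h = x + attention(attention_norm(x))
    return h + mlp(mlp_norm(h))

# Toy usage with identity norms and simple arithmetic sublayers.
out = transformer_block(
    1.0,
    attention=lambda v: v * 2,   # stand-in for the Attention module
    mlp=lambda v: v + 1,         # stand-in for the FeedForward module
    attention_norm=lambda v: v,  # stand-in for RMSNorm
    mlp_norm=lambda v: v,
)
```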
ShardableCallable
class max.nn.transformer.distributed_transformer.ShardableCallable(*args, **kwargs)
Protocol for layer callables that can be sharded across devices.
distributed_logits_postprocess()
max.nn.transformer.distributed_transformer.distributed_logits_postprocess(h, input_row_offsets, return_n_logits, norm_shards, lm_head, signal_buffers, return_logits, device, return_hidden_states=ReturnHiddenStates.NONE, logits_scaling=1.0)
Common logits postprocessing for multi-device sharded models.
Handles last-token gathering, logits computation (VARIABLE/ALL/LAST_TOKEN), logits scaling, and hidden states return for models that use per-device sharded hidden states.
-
Parameters:
-
- h (Sequence[TensorValue]) – Per-device hidden states from the final transformer layer.
- input_row_offsets (Sequence[TensorValue]) – Per-device row offsets for ragged batching.
- return_n_logits (TensorValue) – Number of logits to return per sequence.
- norm_shards (Sequence[Callable[[TensorValue], TensorValue]]) – Per-device normalization functions.
- lm_head (Callable[[list[TensorValue], Sequence[BufferValue]], Sequence[TensorValue]]) – Language model head (takes per-device inputs + signal buffers).
- signal_buffers (Sequence[BufferValue]) – Signal buffers for collective operations.
- return_logits (ReturnLogits) – Which logits to return.
- device (DeviceRef) – Primary device for scalar ops (e.g. ops.range).
- return_hidden_states (ReturnHiddenStates) – Which hidden states to return.
- logits_scaling (float) – Scaling factor for logits.
-
Returns:
-
Tuple of (last_logits, [logits, offsets], [hidden_states]).
-
Return type:
-
tuple[TensorValue, ...]
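One piece of this postprocessing, last-token gathering from a ragged batch, can be sketched without MAX types: with row offsets delimiting each sequence in a flat token buffer, the last token of sequence `i` sits at index `offsets[i + 1] - 1`. A minimal pure-Python illustration (list indexing stands in for tensor gather ops):

```python
def gather_last_tokens(hidden, row_offsets):
    """Pick each sequence's final hidden state from a flat ragged buffer.

    hidden: per-token hidden states, all sequences concatenated.
    row_offsets: row_offsets[i] is where sequence i starts; length is
        number of sequences + 1, with the last entry being the total length.
    """
    return [hidden[end - 1] for end in row_offsets[1:]]

# Two sequences of lengths 3 and 2, packed into one flat buffer.
hidden = ["t0", "t1", "t2", "u0", "u1"]
offsets = [0, 3, 5]
last = gather_last_tokens(hidden, offsets)
```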
forward_sharded_layers()
max.nn.transformer.distributed_transformer.forward_sharded_layers(layers, xs)
Forward pass through sharded layers.
-
Parameters:
-
- layers (Sequence[Callable[[TensorValue], TensorValue]]) – Sequence of callable layers that return TensorValue
- xs (Sequence[TensorValue]) – Input tensors, one per layer
-
Returns:
-
List of output tensors from each layer
-
Raises:
-
AssertionError – If the number of layers and input tensors don’t match
-
Return type:
-
list[TensorValue]
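Behaviorally, this helper amounts to a zip-apply with a length check; a framework-free equivalent sketch:

```python
def forward_sharded_layers(layers, xs):
    """Apply layers[i] to xs[i]: one input tensor per sharded layer."""
    assert len(layers) == len(xs), "number of layers and inputs must match"
    return [layer(x) for layer, x in zip(layers, xs)]

# One layer per device shard, one input per layer.
outputs = forward_sharded_layers([lambda v: v + 1, lambda v: v * 3], [1, 2])
```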
take()
max.nn.transformer.distributed_transformer.take(it, n)
Return the next n items from it as a list.
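This utility is equivalent to slicing an iterator; a minimal sketch using `itertools.islice`:

```python
from itertools import islice

def take(it, n):
    """Return the next n items from the iterator it as a list."""
    return list(islice(it, n))

it = iter(range(10))
first = take(it, 3)   # consumes the first three items
second = take(it, 2)  # continues from where the iterator left off
```

Because the iterator is advanced in place, successive calls partition its output, which is how a layer list can be consumed group by group.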