Python class
DistributedTransformer
class max.nn.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, devices, rope, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False, subgraph_layer_groups=None, logits_scaling=1.0)
Bases: DistributedLogitsPostprocessMixin, Module
Transformer model consisting of DistributedTransformerBlock layers.

Parameters:
- dim (int)
- n_heads (int)
- layers (list[DistributedTransformerBlock])
- norm (ShardableCallable)
- output (ColumnParallelLinear)
- embedding (VocabParallelEmbedding)
- kv_params (KVCacheParams)
- devices (list[DeviceRef])
- rope (RotaryEmbedding)
- return_logits (ReturnLogits)
- use_subgraphs (bool)
- subgraph_layer_groups (list[list[int]] | None)
- logits_scaling (float)
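The constructor wires these components into a standard decoder pipeline: token embeddings flow through each DistributedTransformerBlock, then the final norm and the output projection, with the logits scaled by logits_scaling. A minimal pure-Python sketch of that control flow, using hypothetical toy stand-ins for the real MAX modules (this is not the MAX implementation, and the exact scaling convention should be checked against the source):

```python
from dataclasses import dataclass


@dataclass
class ToyTransformer:
    """Toy stand-in illustrating the DistributedTransformer pipeline."""
    layers: list          # stand-ins for DistributedTransformerBlock
    norm: callable        # stand-in for the shardable norm
    output: callable      # stand-in for ColumnParallelLinear
    embedding: callable   # stand-in for VocabParallelEmbedding
    logits_scaling: float = 1.0

    def __call__(self, tokens):
        h = self.embedding(tokens)
        for layer in self.layers:   # each block transforms the hidden state
            h = layer(h)
        h = self.norm(h)
        logits = self.output(h)
        # Assumed here that logits_scaling divides the logits; verify
        # the direction of the scaling against the MAX source.
        return [x / self.logits_scaling for x in logits]


model = ToyTransformer(
    layers=[lambda h: [x + 1.0 for x in h]] * 2,
    norm=lambda h: h,
    output=lambda h: h,
    embedding=lambda toks: [float(t) for t in toks],
    logits_scaling=2.0,
)
print(model([1, 2]))  # [1.5, 2.0]: embed, two +1 layers, then /2 scaling
```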