Python module
distributed_transformer
DistributedTransformer
class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, devices, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False)
Transformer model consisting of DistributedTransformerBlock layers.
Parameters:
- dim (int)
- n_heads (int)
- layers (list[DistributedTransformerBlock])
- norm (DistributedRMSNorm)
- output (ColumnParallelLinear)
- embedding (VocabParallelEmbedding)
- kv_params (KVCacheParams)
- kv_collection_constructor (FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection)
- devices (list[DeviceRef])
- return_logits (ReturnLogits)
- use_subgraphs (bool)
DistributedTransformerBlock
class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices)
Stack of Attention, FeedForward, and RMSNorm layers.
Parameters:
- attention (Module)
- mlp (Module)
- attention_norm (DistributedRMSNorm)
- mlp_norm (DistributedRMSNorm)
- devices (list[DeviceRef])
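The parameters above suggest the common pre-norm residual layout: normalize, apply the sublayer, then add the result back to the input. A minimal NumPy sketch of that pattern, assuming pre-norm ordering (the actual MAX implementation may differ):

```python
import numpy as np

def block_forward(x, attention, mlp, attention_norm, mlp_norm):
    """Pre-norm residual pattern: norm -> sublayer -> residual add.

    Hypothetical illustration; not the MAX implementation itself.
    """
    h = x + attention(attention_norm(x))  # attention sublayer with residual
    return h + mlp(mlp_norm(h))           # feed-forward sublayer with residual

# With zero-valued sublayers and identity norms, the block is the identity:
x = np.ones((2, 4))
y = block_forward(x, np.zeros_like, np.zeros_like, lambda t: t, lambda t: t)
```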
distribute_value()
max.nn.transformer.distributed_transformer.distribute_value(v, devices)
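No docstring is given; a plausible reading of the name and signature is that it produces one replica of a value per device. An illustrative stand-in (hypothetical behavior, no MAX dependency):

```python
def distribute_value(v, devices):
    # Hypothetical sketch: one replica of `v` per entry in `devices`.
    # The real MAX function would place each copy in that device's memory.
    return [v for _ in devices]

replicas = distribute_value(3.14, ["gpu:0", "gpu:1"])  # two replicas
```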
take()
max.nn.transformer.distributed_transformer.take(it, n)
Return the next n items from it as a list.
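A functionally equivalent sketch using `itertools.islice` (the MAX implementation may differ in detail, but the documented contract is the same):

```python
from itertools import islice

def take(it, n):
    """Return the next n items from iterator `it` as a list."""
    return list(islice(it, n))

nums = iter(range(10))
first = take(nums, 3)  # [0, 1, 2]
rest = take(nums, 2)   # [3, 4] -- resumes where the iterator left off
```

Because the iterator is consumed in place, repeated calls yield successive chunks; passing an exhausted iterator simply returns an empty list.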