Python module: max.nn.transformer.distributed_transformer
DistributedTransformer
class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, devices, return_logits=ReturnLogits.LAST_TOKEN)
Transformer model consisting of DistributedTransformerBlock layers.
Parameters:

- dim (int)
- n_heads (int)
- layers (list[DistributedTransformerBlock])
- norm (DistributedRMSNorm)
- output (ColumnParallelLinear)
- embedding (VocabParallelEmbedding)
- kv_params (KVCacheParams)
- kv_collection_constructor (FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection)
- devices (list[DeviceRef])
- return_logits (ReturnLogits)
DistributedTransformerBlock
class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices, use_subgraph=False)
Stack of Attention, FeedForward, and RMSNorm layers.
Parameters:

- attention (Module)
- mlp (Module)
- attention_norm (DistributedRMSNorm)
- mlp_norm (DistributedRMSNorm)
- devices (list[DeviceRef])
- use_subgraph (bool)
build_subgraph()
build_subgraph(name)
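As a rough illustration of what a stack of attention, feed-forward, and norm layers computes, here is a minimal pure-Python sketch of the common pre-norm residual pattern. This is an assumption about the block's structure, not the actual MAX implementation: the real norm placement and module interfaces are not specified on this page, and `block_forward`, `identity`, and `double` are hypothetical stand-ins.

```python
def block_forward(x, attention, mlp, attention_norm, mlp_norm):
    # Pre-norm residual pattern (assumed): normalize, transform, add back.
    x = x + attention(attention_norm(x))
    x = x + mlp(mlp_norm(x))
    return x

# Toy callables standing in for the real modules (hypothetical):
identity = lambda v: v   # stands in for a norm layer
double = lambda v: 2 * v  # stands in for attention / mlp

out = block_forward(1.0, double, double, identity, identity)
# 1 + 2*1 = 3, then 3 + 2*3 = 9
```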
distribute_value()
max.nn.transformer.distributed_transformer.distribute_value(v, devices)
Parameters:

- devices (list[DeviceRef])
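Conceptually, distributing a value means producing one replica per device. The following is a hedged sketch of that idea only; `FakeValue` and its `to()` method are hypothetical stand-ins, not the MAX `Value` API.

```python
class FakeValue:
    """Hypothetical stand-in for a graph value that can be moved to a device."""
    def __init__(self, data):
        self.data = data

    def to(self, device):
        # Pretend transfer: return the data tagged with its target device.
        return (self.data, device)

def distribute_value(v, devices):
    # One replica of v per device, in device order.
    return [v.to(d) for d in devices]

replicas = distribute_value(FakeValue(42), ["gpu:0", "gpu:1"])
```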
take()
max.nn.transformer.distributed_transformer.take(it, n)
Return the next n items from it as a list.
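This behavior matches the classic itertools recipe for consuming a fixed number of items from an iterator; a minimal sketch:

```python
from itertools import islice

def take(it, n):
    # Return the next n items from the iterator `it` as a list.
    # Consumes the iterator, so successive calls continue where the last left off.
    return list(islice(it, n))

nums = iter(range(10))
first = take(nums, 3)   # [0, 1, 2]
second = take(nums, 2)  # [3, 4]
```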