Python module
distributed_transformer
DistributedTransformer
class max.pipelines.nn.transformer.distributed_transformer.DistributedTransformer(dim: int, n_heads: int, layers: list[max.pipelines.nn.transformer.distributed_transformer.DistributedTransformerBlock], norm: RMSNorm | LPLayerNorm, output: LinearV2, embedding: Embedding, kv_params: KVCacheParams, kv_collection_constructor: FetchContinuousBatchingKVCacheCollection, devices: List[DeviceRef], all_logits: bool = False)
Transformer model consisting of DistributedTransformerBlock layers.
all_logits
all_logits: bool = False
devices
devices: List[DeviceRef]
dim
dim: int
embedding
embedding: Embedding
kv_collection_constructor
kv_collection_constructor: FetchContinuousBatchingKVCacheCollection
kv_params
kv_params: KVCacheParams
layers
layers: list[max.pipelines.nn.transformer.distributed_transformer.DistributedTransformerBlock]
n_heads
n_heads: int
norm
norm: RMSNorm | LPLayerNorm
output
output: LinearV2
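The attribute list above mirrors the constructor arguments one-for-one. As a hedged sketch, the snippet below shows how the pieces might be wired together; the placeholder sub-modules and the hyperparameter values are illustrative assumptions, not verbatim MAX code:

```python
from max.pipelines.nn.transformer.distributed_transformer import (
    DistributedTransformer,
)

# Hypothetical placeholders: in a real pipeline these are constructed from
# model weights and config (their constructors are outside this page's scope).
devices = ...          # List[DeviceRef], one per participating device
blocks = ...           # list[DistributedTransformerBlock], one per layer
final_norm = ...       # RMSNorm | LPLayerNorm over the final hidden state
lm_head = ...          # LinearV2 projecting hidden states to vocab logits
token_embedding = ...  # Embedding mapping token ids to vectors
kv_params = ...        # KVCacheParams describing the KV cache layout
kv_ctor = ...          # FetchContinuousBatchingKVCacheCollection

model = DistributedTransformer(
    dim=4096,          # hidden dimension (illustrative value)
    n_heads=32,        # attention heads per layer (illustrative value)
    layers=blocks,
    norm=final_norm,
    output=lm_head,
    embedding=token_embedding,
    kv_params=kv_params,
    kv_collection_constructor=kv_ctor,
    devices=devices,
    all_logits=False,  # keep logits for the last token only
)
```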
DistributedTransformerBlock
class max.pipelines.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention: DistributedAttentionImpl, mlp: DistributedMLP, attention_norm: DistributedRMSNorm, mlp_norm: DistributedRMSNorm, devices: List[DeviceRef])
Stack of Attention, FeedForward, and RMSNorm layers.
attention
attention: DistributedAttentionImpl
attention_norm
attention_norm: DistributedRMSNorm
devices
devices: List[DeviceRef]
mlp
mlp: DistributedMLP
mlp_norm
mlp_norm: DistributedRMSNorm
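A minimal construction sketch, assuming the attention, MLP, and norm sub-modules have already been built (the placeholders below are assumptions standing in for pre-built modules):

```python
from max.pipelines.nn.transformer.distributed_transformer import (
    DistributedTransformerBlock,
)

# Hypothetical placeholders standing in for pre-built sub-modules.
attn = ...       # DistributedAttentionImpl
mlp = ...        # DistributedMLP
attn_norm = ...  # DistributedRMSNorm applied before attention
mlp_norm = ...   # DistributedRMSNorm applied before the MLP
devices = ...    # List[DeviceRef], same devices as the parent transformer

block = DistributedTransformerBlock(
    attention=attn,
    mlp=mlp,
    attention_norm=attn_norm,
    mlp_norm=mlp_norm,
    devices=devices,
)
```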
distribute_value()
max.pipelines.nn.transformer.distributed_transformer.distribute_value(v, devices: List[DeviceRef])
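No docstring is given for this helper. Judging from the signature alone, a plausible reading is that it replicates a value onto each device; the sketch below is an assumption based on that reading, assuming v supports a .to(device) transfer:

```python
from typing import List

from max.graph import DeviceRef


# Assumed behavior (not confirmed by this page): return one copy of v per
# device, transferring the value with v.to(device) for each DeviceRef.
def distribute_value(v, devices: List[DeviceRef]):
    return [v.to(device) for device in devices]
```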