Python module

distributed_transformer

DistributedTransformer

class max.pipelines.nn.transformer.distributed_transformer.DistributedTransformer(dim: int, n_heads: int, layers: list[max.pipelines.nn.transformer.distributed_transformer.DistributedTransformerBlock], norm: RMSNorm | LPLayerNorm, output: LinearV2, embedding: Embedding, kv_params: KVCacheParams, kv_collection_constructor: FetchContinuousBatchingKVCacheCollection, devices: List[DeviceRef], all_logits: bool = False)

Transformer model consisting of DistributedTransformerBlock layers.

all_logits

all_logits: bool = False

devices

devices: List[DeviceRef]

dim

dim: int

embedding

embedding: Embedding

kv_collection_constructor

kv_collection_constructor: FetchContinuousBatchingKVCacheCollection

kv_params

kv_params: KVCacheParams

layers

layers: list[max.pipelines.nn.transformer.distributed_transformer.DistributedTransformerBlock]

n_heads

n_heads: int

norm

norm: RMSNorm | LPLayerNorm

output

output: LinearV2
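
A minimal construction sketch, assuming the collaborating components (blocks, norm, output projection, embedding, and KV-cache objects) have already been built. The keyword arguments mirror the constructor signature above; the example values for `dim` and `n_heads` are illustrative assumptions, and only the `distributed_transformer` import path is documented on this page:

```python
# Sketch: wiring a DistributedTransformer from pre-built components.
# Only the constructor signature is taken from this page; everything
# else (values, pre-built objects) is an assumption.
from max.pipelines.nn.transformer.distributed_transformer import (
    DistributedTransformer,
)

def build_model(blocks, norm, output, embedding, kv_params,
                kv_collection_constructor, devices):
    # `blocks` is a list[DistributedTransformerBlock], one entry per layer.
    return DistributedTransformer(
        dim=4096,                 # hidden size (illustrative)
        n_heads=32,               # attention heads (illustrative)
        layers=blocks,
        norm=norm,                # RMSNorm | LPLayerNorm
        output=output,            # LinearV2 projection to logits
        embedding=embedding,      # Embedding
        kv_params=kv_params,      # KVCacheParams
        kv_collection_constructor=kv_collection_constructor,
        devices=devices,          # List[DeviceRef], one per device
        all_logits=False,         # default: logits for the last token only
    )
```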

DistributedTransformerBlock

class max.pipelines.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention: DistributedAttentionImpl, mlp: DistributedMLP, attention_norm: DistributedRMSNorm, mlp_norm: DistributedRMSNorm, devices: List[DeviceRef])

Stack of Attention, FeedForward, and RMSNorm layers.

attention

attention: DistributedAttentionImpl

attention_norm

attention_norm: DistributedRMSNorm

devices

devices: List[DeviceRef]

mlp

mlp: DistributedMLP

mlp_norm

mlp_norm: DistributedRMSNorm
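
A minimal sketch of assembling one block, assuming the per-device attention, MLP, and norm layers already exist; only the constructor signature above is taken from this page:

```python
# Sketch: one transformer layer replicated across devices.
# `attention`, `mlp`, `attention_norm`, `mlp_norm`, and `devices`
# are assumed to be pre-built with the types listed above.
from max.pipelines.nn.transformer.distributed_transformer import (
    DistributedTransformerBlock,
)

block = DistributedTransformerBlock(
    attention=attention,            # DistributedAttentionImpl
    mlp=mlp,                        # DistributedMLP
    attention_norm=attention_norm,  # DistributedRMSNorm
    mlp_norm=mlp_norm,              # DistributedRMSNorm
    devices=devices,                # same List[DeviceRef] as the model
)
```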

distribute_value()

max.pipelines.nn.transformer.distributed_transformer.distribute_value(v, devices: List[DeviceRef])
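
Presumably a helper that replicates a single graph value across every device in `devices`, returning one copy per `DeviceRef`; that return shape is an assumption, since this page documents only the signature. A hedged usage sketch:

```python
# Sketch: replicate a value across devices before a sharded layer
# consumes it. The one-copy-per-device return is an assumption.
from max.pipelines.nn.transformer.distributed_transformer import distribute_value

def replicate(v, devices):
    # devices: List[DeviceRef]; v: a graph value to distribute.
    return distribute_value(v, devices)
```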