Python class

DistributedTransformer

class max.nn.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, devices, rope, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False, subgraph_layer_groups=None, logits_scaling=1.0)

Bases: DistributedLogitsPostprocessMixin, Module

Transformer model consisting of TransformerBlock layers.

Parameters: