Transformer

class max.nn.Transformer(dim, n_heads, layers, norm, output, embedding, kv_params, rope, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, embedding_multiplier=1.0, logits_scaling=1.0)

Bases: LogitsPostprocessMixin, Module

A transformer model consisting of TransformerBlock layers.

Parameters:

  • dim (int) – The model dimension.
  • n_heads (int) – The number of attention heads.
  • layers (list[Block]) – The list of transformer blocks.
  • norm (Callable[[TensorValue], TensorValue]) – The normalization layer applied before the language model head.
  • output (Linear) – The language model head projection.
  • embedding (Embedding) – The token embedding layer.
  • kv_params (KVCacheParams) – The key-value cache parameters.
  • rope (RotaryEmbedding) – The rotary position embedding.
  • return_logits (ReturnLogits) – Which logits to return. Defaults to ReturnLogits.LAST_TOKEN.
  • return_hidden_states (ReturnHiddenStates) – Which hidden states to return. Defaults to ReturnHiddenStates.NONE.
  • embedding_multiplier (float) – A scalar applied to token embeddings after lookup. Defaults to 1.0 (no scaling).
  • logits_scaling (float) – A divisor applied to the logits after the output projection, before they are returned. Defaults to 1.0 (no scaling).
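To illustrate how embedding_multiplier and logits_scaling enter the forward pass, here is a minimal NumPy sketch. This is a hypothetical toy model, not the MAX API: embeddings are scaled immediately after lookup, and logits are divided by logits_scaling after the language model head projection, matching the parameter descriptions above.

```python
import numpy as np

def toy_forward(token_ids, embed_table, lm_head,
                embedding_multiplier=1.0, logits_scaling=1.0):
    """Toy decoder forward pass showing where the two scalars apply.

    This is an illustration only; the real Transformer runs its
    TransformerBlock layers and norm between these two steps.
    """
    # Token embedding lookup, scaled by embedding_multiplier.
    h = embed_table[np.asarray(token_ids)] * embedding_multiplier
    # (Transformer blocks and the final norm would run here; omitted.)
    # Language model head projection, then division by logits_scaling.
    logits = (h @ lm_head.T) / logits_scaling
    return logits

rng = np.random.default_rng(0)
embed = rng.normal(size=(10, 4))  # vocab_size=10, dim=4
head = rng.normal(size=(10, 4))   # lm head projection weights

base = toy_forward([1, 2], embed, head)
scaled = toy_forward([1, 2], embed, head, logits_scaling=2.0)
assert np.allclose(scaled, base / 2.0)  # logits are divided by the scalar
```

With logits_scaling=1.0 and embedding_multiplier=1.0 (the defaults) both operations are no-ops, which is why most models leave them unset.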