Transformer
class max.nn.Transformer(dim, n_heads, layers, norm, output, embedding, kv_params, rope, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, embedding_multiplier=1.0, logits_scaling=1.0)
Bases: LogitsPostprocessMixin, Module
A transformer model consisting of TransformerBlock layers.
Parameters:
- dim (int) – The model dimension.
- n_heads (int) – The number of attention heads.
- layers (list[Block]) – The list of transformer blocks.
- norm (Callable[[TensorValue], TensorValue]) – The normalization layer applied before the language model head.
- output (Linear) – The language model head projection.
- embedding (Embedding) – The token embedding layer.
- kv_params (KVCacheParams) – The key-value cache parameters.
- rope (RotaryEmbedding) – The rotary position embedding.
- return_logits (ReturnLogits) – Which logits to return. Defaults to ReturnLogits.LAST_TOKEN.
- return_hidden_states (ReturnHiddenStates) – Which hidden states to return. Defaults to ReturnHiddenStates.NONE.
- embedding_multiplier (float) – A scalar applied to token embeddings after lookup. Defaults to 1.0 (no scaling).
- logits_scaling (float) – A divisor applied to logits after projection. Logits are divided by this value before returning. Defaults to 1.0 (no scaling).
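A model might be assembled as follows. This is a minimal sketch: the construction of the component layers (the TransformerBlock list, Linear, Embedding, RotaryEmbedding, and KVCacheParams) is elided, the hyperparameter values are illustrative only, and the import location of ReturnLogits is assumed to sit alongside Transformer.

```python
from max.nn import ReturnLogits, Transformer  # ReturnLogits location assumed

# Illustrative hyperparameters; real values come from the model config.
DIM = 4096
N_HEADS = 32

# The component layers are assumed to be built elsewhere per their own docs:
#   blocks          -- list[TransformerBlock], one per layer
#   final_norm      -- Callable[[TensorValue], TensorValue], e.g. an RMSNorm
#   lm_head         -- a Linear projecting hidden states to vocab logits
#   token_embedding -- an Embedding over the vocabulary
#   kv_params       -- KVCacheParams describing the KV cache layout
#   rope            -- a RotaryEmbedding
model = Transformer(
    dim=DIM,
    n_heads=N_HEADS,
    layers=blocks,
    norm=final_norm,
    output=lm_head,
    embedding=token_embedding,
    kv_params=kv_params,
    rope=rope,
    return_logits=ReturnLogits.LAST_TOKEN,  # return only the final token's logits
    embedding_multiplier=1.0,  # token embeddings are multiplied by this after lookup
    logits_scaling=1.0,        # logits are divided by this before returning
)
```

With the defaults shown, embeddings and logits pass through unscaled; setting logits_scaling to a value greater than 1.0 softens the output distribution, since the projected logits are divided by that value before being returned.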