Python module
transformer
ReturnHiddenStates
class max.nn.legacy.transformer.transformer.ReturnHiddenStates(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
ALL
ALL = 'all'
ALL_NORMALIZED
ALL_NORMALIZED = 'all_normalized'
LAST
LAST = 'last'
LAST_NORMALIZED
LAST_NORMALIZED = 'last_normalized'
NONE
NONE = 'none'
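Each member carries a plain string value, so a mode can be recovered from a configuration string via standard Enum lookup. A minimal sketch:

```python
from max.nn.legacy.transformer.transformer import ReturnHiddenStates

# Look up the member from its string value (standard Enum behavior).
mode = ReturnHiddenStates("all_normalized")
assert mode is ReturnHiddenStates.ALL_NORMALIZED
```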
ReturnLogits
class max.nn.legacy.transformer.transformer.ReturnLogits(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
ALL
ALL = 'all'
LAST_TOKEN
LAST_TOKEN = 'last_token'
VARIABLE
VARIABLE = 'variable'
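The members enumerate the supported logits-return modes. A minimal sketch showing how the valid string values could be collected, e.g. to validate a configuration field:

```python
from max.nn.legacy.transformer.transformer import ReturnLogits

# Collect every documented mode's string value (standard Enum iteration).
valid_modes = {mode.value for mode in ReturnLogits}
assert valid_modes == {"all", "last_token", "variable"}
```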
Transformer
class max.nn.legacy.transformer.transformer.Transformer(dim, n_heads, layers, norm, output, embedding, kv_params, rope, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, embedding_multiplier=1.0, logits_scaling=1.0)
Transformer model consisting of TransformerBlock layers.
Parameters:
- dim (int)
- n_heads (int)
- layers (list[Block])
- norm (Layer)
- output (Linear)
- embedding (Embedding)
- kv_params (KVCacheParams)
- rope (RotaryEmbedding)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
- embedding_multiplier (float)
- logits_scaling (float)
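A minimal construction sketch. The sub-layer variables below (blocks, final_norm, lm_head, token_embedding, kv_params, rope) are hypothetical placeholders assumed to be built elsewhere with the appropriate max.nn constructors; only the wiring of this class and its documented defaults is shown.

```python
from max.nn.legacy.transformer.transformer import (
    ReturnHiddenStates,
    ReturnLogits,
    Transformer,
)

# blocks, final_norm, lm_head, token_embedding, kv_params, and rope are
# placeholders for objects constructed elsewhere with max.nn APIs.
model = Transformer(
    dim=4096,                    # hidden size (example value)
    n_heads=32,                  # attention heads (example value)
    layers=blocks,               # list[Block], e.g. TransformerBlock instances
    norm=final_norm,             # Layer applied to the final hidden states
    output=lm_head,              # Linear projection to vocabulary logits
    embedding=token_embedding,   # Embedding for input token ids
    kv_params=kv_params,         # KVCacheParams describing the KV cache
    rope=rope,                   # RotaryEmbedding shared across layers
    return_logits=ReturnLogits.LAST_TOKEN,
    return_hidden_states=ReturnHiddenStates.NONE,
    embedding_multiplier=1.0,
    logits_scaling=1.0,
)
```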
TransformerBlock
class max.nn.legacy.transformer.transformer.TransformerBlock(attention, mlp, attention_norm, mlp_norm, residual_multiplier=1.0)
Stack of Attention, FeedForward, and RMSNorm layers.
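A minimal wiring sketch, assuming the attention, feed-forward, and norm sub-layers have already been constructed; the variables below are hypothetical placeholders for those objects.

```python
from max.nn.legacy.transformer.transformer import TransformerBlock

# attention, mlp, attention_norm, and mlp_norm are placeholders for
# already-constructed Attention, FeedForward, and RMSNorm layers.
block = TransformerBlock(
    attention=attention,
    mlp=mlp,
    attention_norm=attention_norm,
    mlp_norm=mlp_norm,
    residual_multiplier=1.0,  # documented default
)
```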