Python module: transformer
ReturnLogits
class max.nn.transformer.transformer.ReturnLogits(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Enum specifying which tokens' logits the model returns: logits for all tokens, for the last token only, or for a variable number of tokens.
ALL = 'all'
LAST_TOKEN = 'last_token'
VARIABLE = 'variable'
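A minimal usage sketch, assuming ReturnLogits behaves like a standard Python Enum with the string values listed above:

```python
from max.nn.transformer.transformer import ReturnLogits

# Standard Enum lookup-by-value: the raw strings match the documented members.
mode = ReturnLogits("last_token")
assert mode is ReturnLogits.LAST_TOKEN
print(mode.value)  # "last_token"
```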
Transformer
class max.nn.transformer.transformer.Transformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, return_logits=ReturnLogits.LAST_TOKEN, embedding_multiplier=1.0, logits_postprocessor=None)
Transformer model consisting of TransformerBlock layers.
Parameters:

- dim (int)
- n_heads (int)
- layers (list[Block])
- norm (Layer)
- output (LinearV1 | Linear)
- embedding (EmbeddingV1 | Embedding)
- kv_params (KVCacheParams)
- kv_collection_constructor (FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection)
- return_logits (ReturnLogits)
- embedding_multiplier (float)
- logits_postprocessor (Callable[[TensorValue], TensorValue] | None)
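A hedged construction sketch. The component objects are assumed to be built elsewhere with the types listed above; build_components() is a hypothetical stand-in for that code, and the numeric values are illustrative only:

```python
from max.nn.transformer.transformer import ReturnLogits, Transformer

# Hypothetical helper: stands in for whatever code constructs the blocks,
# norms, embedding, and KV-cache plumbing (documented on their own pages).
layers, norm, output, embedding, kv_params, kv_collection_ctor = build_components()

model = Transformer(
    dim=4096,                 # hidden size; illustrative value
    n_heads=32,               # attention heads; illustrative value
    layers=layers,            # list[Block], e.g. TransformerBlock instances
    norm=norm,                # final normalization Layer
    output=output,            # Linear projection from hidden states to logits
    embedding=embedding,      # token Embedding
    kv_params=kv_params,      # KVCacheParams
    kv_collection_constructor=kv_collection_ctor,
    return_logits=ReturnLogits.LAST_TOKEN,  # the default
    embedding_multiplier=1.0,               # scale applied to embeddings
)
```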
TransformerBlock
class max.nn.transformer.transformer.TransformerBlock(attention, mlp, attention_norm, mlp_norm, residual_multiplier=1.0)
Stack of Attention, FeedForward, and RMSNorm layers.
Parameters:

- attention (AttentionImpl | AttentionImplQKV | Module)
- mlp (Layer)
- attention_norm (Layer)
- mlp_norm (Layer)
- residual_multiplier (float)
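A construction sketch under the same caveats: make_attention(), make_mlp(), and make_norm() are hypothetical helpers standing in for the real sublayer constructors, which are documented elsewhere:

```python
from max.nn.transformer.transformer import TransformerBlock

# attention must be an AttentionImpl, AttentionImplQKV, or Module; the MLP
# and both norms are Layers. All helpers below are hypothetical stand-ins.
block = TransformerBlock(
    attention=make_attention(),
    mlp=make_mlp(),
    attention_norm=make_norm(),  # normalization for the attention sublayer
    mlp_norm=make_norm(),        # normalization for the MLP sublayer
    residual_multiplier=1.0,     # scales residual contributions; default 1.0
)
```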