
Python module

transformer

ReturnLogits

class max.nn.transformer.transformer.ReturnLogits(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Enum selecting which token positions' logits the model returns: all positions, only the last token, or a variable number per sequence.

ALL

ALL = 'all'

LAST_TOKEN

LAST_TOKEN = 'last_token'

VARIABLE

VARIABLE = 'variable'
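The members above map to plain string values, so the enum can be sketched as a standard string-valued `Enum`. This is a minimal stand-in for illustration only; the real class lives in `max.nn.transformer.transformer`:

```python
from enum import Enum

# Illustrative stand-in mirroring the members documented above;
# not the actual max.nn implementation.
class ReturnLogits(str, Enum):
    ALL = "all"              # return logits for every token position
    LAST_TOKEN = "last_token"  # return logits for the final token only
    VARIABLE = "variable"    # return a variable number of positions

mode = ReturnLogits.LAST_TOKEN
print(mode.value)  # last_token
```

Because the values are strings, a mode read from a config file (e.g. `"all"`) can be converted with `ReturnLogits("all")`.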

Transformer

class max.nn.transformer.transformer.Transformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, return_logits=ReturnLogits.LAST_TOKEN, embedding_multiplier=1.0, logits_postprocessor=None)

Transformer model consisting of TransformerBlock layers.

Parameters:

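The constructor signature suggests the usual forward pass: scale the token embeddings by `embedding_multiplier`, run them through the stacked layers, apply the final `norm` and `output` projection, and optionally post-process the logits. The sketch below illustrates that data flow with toy scalar callables standing in for the real layers; the function name and all the lambda "layers" are hypothetical, not the max.nn API.

```python
def transformer_forward(tokens, embedding, layers, norm, output,
                        embedding_multiplier=1.0, logits_postprocessor=None):
    # Embed and scale (embedding_multiplier defaults to 1.0 in the docs above).
    h = [embedding_multiplier * embedding(t) for t in tokens]
    # Run the stack of transformer layers.
    for layer in layers:
        h = [layer(v) for v in h]
    # Final norm, then output projection to logits.
    logits = [output(norm(v)) for v in h]
    # Optional hook corresponding to the logits_postprocessor parameter.
    if logits_postprocessor is not None:
        logits = logits_postprocessor(logits)
    return logits

# Toy scalar stand-ins to exercise the wiring.
logits = transformer_forward(
    tokens=[0.0, 1.0],
    embedding=lambda t: t + 1.0,
    layers=[lambda v: v + 1.0, lambda v: v + 1.0],
    norm=lambda v: v,
    output=lambda v: v * 10.0,
    embedding_multiplier=2.0,
)
print(logits)  # [40.0, 60.0]
```

Which positions of `logits` a real model materializes is governed by the `return_logits` parameter (`ReturnLogits.LAST_TOKEN` by default).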
TransformerBlock

class max.nn.transformer.transformer.TransformerBlock(attention, mlp, attention_norm, mlp_norm, residual_multiplier=1.0)

Stack of Attention, FeedForward, and RMSNorm layers.

Parameters:
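The parameter names (pre-attention and pre-MLP norms plus a `residual_multiplier`) suggest the standard pre-norm residual pattern. A minimal sketch of that wiring, assuming `residual_multiplier` scales each residual branch; the function and the toy callables are illustrative stand-ins, not the actual max.nn implementation:

```python
def transformer_block(x, attention, mlp, attention_norm, mlp_norm,
                      residual_multiplier=1.0):
    # Pre-norm attention with a scaled residual connection.
    h = x + residual_multiplier * attention(attention_norm(x))
    # Pre-norm feed-forward (MLP) with a scaled residual connection.
    return h + residual_multiplier * mlp(mlp_norm(h))

# Toy scalar "layers" to show the data flow.
out = transformer_block(
    1.0,
    attention=lambda t: 2.0 * t,
    mlp=lambda t: t + 1.0,
    attention_norm=lambda t: t,
    mlp_norm=lambda t: t,
)
print(out)  # 7.0
```

With `residual_multiplier=1.0` (the documented default) this reduces to the plain residual sum `x + sublayer(norm(x))`.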