Python module

transformer

ReturnHiddenStates

class max.nn.legacy.transformer.transformer.ReturnHiddenStates(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

ALL

ALL = 'all'

ALL_NORMALIZED

ALL_NORMALIZED = 'all_normalized'

LAST

LAST = 'last'

LAST_NORMALIZED

LAST_NORMALIZED = 'last_normalized'

NONE

NONE = 'none'
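
The members are plain string values, so they can be inspected or looked up by value as well as by name. A minimal sketch, assuming only that the enum is importable from this module as documented:

```python
from max.nn.legacy.transformer.transformer import ReturnHiddenStates

# Members are string-valued, so lookup by value works alongside attribute access.
assert ReturnHiddenStates.ALL_NORMALIZED.value == "all_normalized"
assert ReturnHiddenStates("last") is ReturnHiddenStates.LAST

# NONE is the Transformer default (see return_hidden_states below).
print(ReturnHiddenStates.NONE)
```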

ReturnLogits

class max.nn.legacy.transformer.transformer.ReturnLogits(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

ALL

ALL = 'all'

LAST_TOKEN

LAST_TOKEN = 'last_token'

VARIABLE

VARIABLE = 'variable'
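
As with ReturnHiddenStates, the members are string-valued and are typically passed via the return_logits argument of Transformer below. A small illustrative sketch:

```python
from max.nn.legacy.transformer.transformer import ReturnLogits

# LAST_TOKEN is the Transformer default (see return_logits below).
assert ReturnLogits.LAST_TOKEN.value == "last_token"

# Enumerate every mode and its underlying string value.
for mode in ReturnLogits:
    print(mode.name, "->", mode.value)
```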

Transformer

class max.nn.legacy.transformer.transformer.Transformer(dim, n_heads, layers, norm, output, embedding, kv_params, rope, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE, embedding_multiplier=1.0, logits_scaling=1.0)

Transformer model consisting of TransformerBlock layers.

Parameters:
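
A construction sketch based only on the signature above. The component layers (transformer blocks, final norm, output projection, token embedding, KV cache parameters, and rotary embedding) are assumed to be built elsewhere; the lowercase placeholder names are illustrative, not part of this module:

```python
from max.nn.legacy.transformer.transformer import (
    ReturnHiddenStates,
    ReturnLogits,
    Transformer,
)

# blocks, final_norm, lm_head, token_embedding, kv_params, and rope are
# hypothetical placeholders for layers constructed elsewhere.
model = Transformer(
    dim=4096,
    n_heads=32,
    layers=blocks,              # sequence of TransformerBlock instances
    norm=final_norm,            # final normalization layer
    output=lm_head,             # output projection producing logits
    embedding=token_embedding,  # token embedding layer
    kv_params=kv_params,        # KV cache parameters
    rope=rope,                  # rotary position embedding
    return_logits=ReturnLogits.LAST_TOKEN,         # default
    return_hidden_states=ReturnHiddenStates.NONE,  # default
    embedding_multiplier=1.0,
    logits_scaling=1.0,
)
```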

TransformerBlock

class max.nn.legacy.transformer.transformer.TransformerBlock(attention, mlp, attention_norm, mlp_norm, residual_multiplier=1.0)

Stack of Attention, FeedForward, and RMSNorm layers.

Parameters:
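
A similarly hedged sketch of wiring a single block from the signature above; the attention, MLP, and norm layers are hypothetical placeholders assumed to be constructed elsewhere:

```python
from max.nn.legacy.transformer.transformer import TransformerBlock

# attention_layer, mlp_layer, attn_norm, and mlp_norm are hypothetical
# placeholders for an attention layer, a feed-forward layer, and two
# normalization layers built elsewhere.
block = TransformerBlock(
    attention=attention_layer,
    mlp=mlp_layer,
    attention_norm=attn_norm,
    mlp_norm=mlp_norm,
    residual_multiplier=1.0,  # default per the signature above
)
```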
