
Python module

naive_transformer

NaiveTransformer

class max.nn.transformer.naive_transformer.NaiveTransformer(dim: int, n_heads: int, layers: list[max.nn.transformer.naive_transformer.NaiveTransformerBlock], norm: Layer, output: Linear | LinearV2, theta: float, embedding: Embedding | EmbeddingV2, output_type: DType | None = None, embedding_multiplier: float = 1.0, logits_postprocessor: Callable[[TensorValue], TensorValue] | None = None)

Transformer model built only from MAX Graph operations, consisting of a stack of NaiveTransformerBlock layers.
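This page does not document the forward pass, but the constructor arguments suggest how the pieces fit together. The following is a minimal pure-Python sketch under stated assumptions, not MAX code. It assumes that tokens are embedded, scaled by `embedding_multiplier`, passed through each block in `layers` in order, normalized by `norm`, projected by `output`, and then passed to `logits_postprocessor` if one is given. It also assumes that `theta` is the rotary position embedding (RoPE) base and that the per-head dimension is `dim // n_heads`. The functions `rope_inv_freqs` and `naive_transformer_forward` and the toy lambdas are hypothetical stand-ins for the real layers.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
# Hypothetical stand-in for a layer: maps a sequence of vectors to a sequence of vectors.
Layer = Callable[[List[Vector]], List[Vector]]


def rope_inv_freqs(dim: int, n_heads: int, theta: float) -> List[float]:
    """Assumed standard RoPE inverse frequencies: theta ** (-2i / head_dim)."""
    head_dim = dim // n_heads
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def naive_transformer_forward(
    tokens: Sequence[int],
    embedding: Callable[[Sequence[int]], List[Vector]],
    layers: Sequence[Layer],
    norm: Layer,
    output: Layer,
    embedding_multiplier: float = 1.0,
    logits_postprocessor: Optional[Layer] = None,
) -> List[Vector]:
    # Assumption: look up one embedding vector per token, then scale it.
    h = [[embedding_multiplier * x for x in vec] for vec in embedding(tokens)]
    # Apply each transformer block in order.
    for block in layers:
        h = block(h)
    # Final normalization, then projection to logits.
    logits = output(norm(h))
    # Optional postprocessing of the logits.
    if logits_postprocessor is not None:
        logits = logits_postprocessor(logits)
    return logits


# Toy usage with hypothetical stand-in layers.
table = {0: [1.0, 2.0], 1: [3.0, 4.0]}
embed = lambda toks: [list(table[t]) for t in toks]
identity = lambda h: h
double = lambda h: [[2 * x for x in v] for v in h]
print(naive_transformer_forward([0, 1], embed, [double], identity, identity,
                                embedding_multiplier=0.5))
# [[1.0, 2.0], [3.0, 4.0]]
```

Keeping every layer as a plain callable makes the assumed ordering (embed, scale, blocks, norm, output, postprocess) easy to check in isolation. The real class operates on MAX Graph tensor values instead of Python lists.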

NaiveTransformerBlock

class max.nn.transformer.naive_transformer.NaiveTransformerBlock(attention: NaiveAttentionWithRope, mlp: Layer, attention_norm: Layer, mlp_norm: Layer, residual_multiplier: float = 1.0)

Stack of attention, feed-forward (MLP), and RMSNorm layers, built only from MAX Graph operations.
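This page does not say how the sublayers are wired together. A common arrangement for blocks with this shape is pre-norm residual connections. The sketch below assumes that arrangement and also assumes that `residual_multiplier` scales each sublayer's output before it is added back to the residual stream. It includes a plain RMSNorm without a learned weight to match the description above. Everything here is hypothetical pure Python, not the MAX implementation. The names `rms_norm`, `add_scaled`, and `naive_block_forward` are illustrative only.

```python
import math
from typing import Callable, List

Vector = List[float]
# Hypothetical stand-in for a sublayer: maps a sequence of vectors to a sequence of vectors.
SubLayer = Callable[[List[Vector]], List[Vector]]


def rms_norm(h: List[Vector], eps: float = 1e-6) -> List[Vector]:
    """RMSNorm without a learned scale: x / sqrt(mean(x^2) + eps), per vector."""
    out = []
    for v in h:
        rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
        out.append([x / rms for x in v])
    return out


def add_scaled(a: List[Vector], b: List[Vector], scale: float) -> List[Vector]:
    # Element-wise a + scale * b, used for the residual connections.
    return [[x + scale * y for x, y in zip(va, vb)] for va, vb in zip(a, b)]


def naive_block_forward(
    x: List[Vector],
    attention: SubLayer,
    mlp: SubLayer,
    attention_norm: SubLayer,
    mlp_norm: SubLayer,
    residual_multiplier: float = 1.0,
) -> List[Vector]:
    # Assumed pre-norm layout: normalize, run the sublayer, add back its scaled output.
    h = add_scaled(x, attention(attention_norm(x)), residual_multiplier)
    return add_scaled(h, mlp(mlp_norm(h)), residual_multiplier)


# Toy usage: identity sublayers and norms, residual_multiplier = 0.5.
identity = lambda h: h
print(naive_block_forward([[2.0]], identity, identity, identity, identity, 0.5))
# [[4.5]]  (h = 2 + 0.5 * 2 = 3, then 3 + 0.5 * 3 = 4.5)
```

With this layout, a block whose attention and MLP output zeros behaves as the identity. That property is one reason pre-norm residual stacks are easy to train deep. Whether MAX applies `residual_multiplier` exactly this way is an assumption.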