Python module
naive_transformer
NaiveTransformer
class max.pipelines.nn.transformer.naive_transformer.NaiveTransformer(dim: int, n_heads: int, layers: list[max.pipelines.nn.transformer.naive_transformer.NaiveTransformerBlock], norm: RMSNorm | RMSNormV2, output: Linear | LinearV2, theta: float, embedding: Embedding | EmbeddingV2, output_type: DType | None = None)
A MAX Graph-only model built from a list of NaiveTransformerBlock layers.
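The constructor arguments suggest a conventional decoder layout: an embedding lookup, the list of blocks, a final norm, and an output projection. The following is a minimal plain-Python sketch of that data flow, not the MAX implementation. The function `run_model` and the ordering of the steps are assumptions made for illustration; the real class builds a MAX graph and also takes `dim`, `n_heads`, `theta`, and `output_type`, which are omitted here.

```python
from typing import Callable, Sequence

Vector = list[float]
Hidden = list[Vector]  # one vector per token position


def run_model(
    token_ids: Sequence[int],
    embedding: Sequence[Vector],               # hypothetical lookup table: row i embeds token i
    layers: Sequence[Callable[[Hidden], Hidden]],
    norm: Callable[[Vector], Vector],
    output: Callable[[Vector], object],
) -> list:
    """Illustrative forward pass (assumed order): embed, run blocks, norm, project."""
    # Look up one embedding vector per input token.
    hidden: Hidden = [list(embedding[t]) for t in token_ids]
    # Apply each block in sequence; every block maps hidden states to hidden states.
    for layer in layers:
        hidden = layer(hidden)
    # Normalize each position, then apply the output projection to each position.
    return [output(norm(h)) for h in hidden]
```

Passing the blocks in as a plain list mirrors the `layers: list[NaiveTransformerBlock]` parameter in the signature above.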
NaiveTransformerBlock
class max.pipelines.nn.transformer.naive_transformer.NaiveTransformerBlock(attention: NaiveAttentionWithRope, mlp: MLP | MLPV2, attention_norm: RMSNorm | RMSNormV2, mlp_norm: RMSNorm | RMSNormV2)
A MAX Graph-only stack of attention, feed-forward (MLP), and RMSNorm layers.
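To make the roles of the four constructor arguments concrete, here is a minimal plain-Python sketch of a pre-norm transformer block with residual connections. It is an assumption that NaiveTransformerBlock composes its sublayers this way; the page does not state the order. The helpers `rms_norm` and `block_forward` are hypothetical names, and rotary position handling inside the attention layer is not modeled.

```python
import math
from typing import Callable

Vector = list[float]
Hidden = list[Vector]  # one vector per token position


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """Scale x by the reciprocal of its root mean square, then by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def _add(a: Vector, b: Vector) -> Vector:
    # Elementwise sum used for the residual connections.
    return [p + q for p, q in zip(a, b)]


def block_forward(
    xs: Hidden,
    attention: Callable[[Hidden], Hidden],     # mixes information across positions
    mlp: Callable[[Vector], Vector],           # applied to each position independently
    attention_norm: Callable[[Vector], Vector],
    mlp_norm: Callable[[Vector], Vector],
) -> Hidden:
    """Assumed pre-norm layout: x + attn(norm(x)), then h + mlp(norm(h))."""
    attn_out = attention([attention_norm(x) for x in xs])
    hs = [_add(x, a) for x, a in zip(xs, attn_out)]
    return [_add(h, mlp(mlp_norm(h))) for h in hs]
```

Each sublayer sees a normalized input while the residual path carries the unnormalized activations forward, which is the usual reason a block takes separate `attention_norm` and `mlp_norm` layers.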