IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python module

max.pipelines.architectures.laguna

LagunaConfig

class max.pipelines.architectures.laguna.LagunaConfig(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, interleaved_rope_weights, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, return_logits=ReturnLogits.LAST_TOKEN, norm_method='rms_norm', norm_dtype=None, attention_bias=False, rms_norm_eps=None, tie_word_embeddings=False, stacked_mlp=False, stacked_qkv=False, attention_multiplier, embedding_multiplier, residual_multiplier, devices, clip_qkv, quant_config=None, lora_config=None, longrope_scaling_params=None, logits_scaling=1.0, return_hidden_states=ReturnHiddenStates.NONE, target_layer_ids=None, use_subgraphs=True, data_parallel_degree=1, sliding_window=None, num_local_experts=256, num_experts_per_tok=8, moe_intermediate_size=512, shared_expert_intermediate_size=512, moe_routed_scaling_factor=2.5, moe_router_logit_softcapping=0.0, norm_topk_prob=True, correction_bias_dtype=None, gate_dtype=None, attn_dtype=None, ep_config=None, mlp_layer_types=None, intermediate_size_dense=8192, partial_rotary_factor=0.5, gating=True)

Bases: Llama3Config

Configuration for Laguna decoder-only MoE models.

Extends Llama3Config with Laguna-specific fields:

  • Uniform GQA: 64 query heads and 8 KV heads, full causal attention, and a single RoPE table parameterized by rope_theta and partial_rotary_factor (Laguna-M.1 is full-rotary, i.e. partial_rotary_factor = 1.0).
  • Per-layer MLP type: each entry of mlp_layer_types is "dense" (the leading dense layers) or "sparse" (all remaining layers), and the decoder block dispatches to the dense MLP or the sparse MoE block accordingly.
  • Sigmoid routing with a correction bias, not softmax routing. See LagunaTopKRouter for the per-token routing math.
  • A routed scaling factor, shared experts, and a softplus attention-output gate, applied per token or per element as Laguna requires.
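The sigmoid-plus-correction-bias routing can be illustrated for a single token with a minimal pure-Python sketch. It assumes DeepSeek-V3-style semantics that this page does not spell out: the correction bias only affects which experts are selected, while the routing weights come from the unbiased sigmoid scores and are L1-normalized when norm_topk_prob is True. The function name route_token is hypothetical; LagunaTopKRouter remains the authoritative implementation.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def route_token(
    logits: list[float],
    correction_bias: list[float],
    top_k: int,
    norm_topk_prob: bool = True,
) -> list[tuple[int, float]]:
    """Return (expert_index, weight) pairs for one token (illustrative sketch)."""
    # Independent sigmoid score per expert, not a softmax across experts.
    scores = [sigmoid(x) for x in logits]
    # Assumption: the correction bias is added only when ranking experts.
    biased = [s + b for s, b in zip(scores, correction_bias)]
    chosen = sorted(range(len(scores)), key=lambda i: biased[i], reverse=True)[:top_k]
    # Weights use the unbiased scores of the selected experts.
    weights = [scores[i] for i in chosen]
    if norm_topk_prob:
        total = sum(weights)
        weights = [w / total for w in weights]
    return list(zip(chosen, weights))


# Four hypothetical experts with top-2 routing; the bias promotes expert 3.
routes = route_token([2.0, 0.5, 1.0, 0.0], [0.0, 0.0, 0.0, 3.0], top_k=2)
print(routes)
```

In this example expert 3 is selected only because of its bias, yet its weight reflects just its raw sigmoid score, so it receives less weight than expert 0.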

attn_dtype

attn_dtype: DType | None = None

Data type of the attention weights. Detected from the state dict during finalize().

calculate_attention_multiplier()

static calculate_attention_multiplier(huggingface_config)

Computes the attention multiplier from the config’s head_dim.

Parameters:

huggingface_config (AutoConfig) – The HuggingFace configuration object.

Returns:

The attention multiplier value.

Return type:

float
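A minimal sketch of this computation, assuming the multiplier is the standard scaled-dot-product factor 1/sqrt(head_dim) and that head_dim is read directly from the config rather than derived from hidden_size. SimpleNamespace stands in for a HuggingFace AutoConfig, the function name is hypothetical, and the numbers are made up to show a case where the two head sizes differ.

```python
import math
from types import SimpleNamespace


def attention_multiplier(hf_config) -> float:
    # Use the explicit head_dim instead of hidden_size // num_attention_heads.
    return 1.0 / math.sqrt(float(hf_config.head_dim))


# Hypothetical config where head_dim != hidden_size // num_attention_heads.
cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=64, head_dim=128)
print(attention_multiplier(cfg))                   # scale for head_dim = 128
print(cfg.hidden_size // cfg.num_attention_heads)  # 64, the derived value
```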

construct_kv_params()

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Constructs KV cache parameters using the head_dim value stated explicitly in the HuggingFace config.

Parameters:

  • huggingface_config (AutoConfig) – The HuggingFace configuration object.
  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • devices (list[DeviceRef]) – Devices to use for the KV cache.
  • kv_cache_config (KVCacheConfig) – Configuration for KV cache.
  • cache_dtype (DType) – Data type for the cache.

Returns:

A KVCacheParams object whose head_dim is taken from the config.

Return type:

KVCacheParams
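Because the KV cache holds one key vector and one value vector per KV head per layer, head_dim directly determines the cache footprint. A back-of-the-envelope sketch, assuming an unquantized cache with no paging or alignment overhead; the 8 KV heads match the GQA layout listed for LagunaConfig, while the layer count, head_dim, and dtype size are hypothetical.

```python
def kv_cache_bytes_per_token(
    num_layers: int, num_kv_heads: int, head_dim: int, dtype_bytes: int
) -> int:
    # Factor of 2: one key vector and one value vector per KV head per layer.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


# Hypothetical: 40 layers, 8 KV heads, head_dim 128, bfloat16 (2 bytes).
per_token = kv_cache_bytes_per_token(40, 8, 128, 2)
print(per_token)  # 163840 bytes, i.e. 160 KiB per token
```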

correction_bias_dtype

correction_bias_dtype: DType | None = None

Data type of the e_score_correction_bias weight. Detected from the state dict during finalize().

ep_config

ep_config: EPConfig | None = None

Expert parallelism (EP) configuration. None disables EP (single-GPU).

gate_dtype

gate_dtype: DType | None = None

Data type of the routed-expert gate Linear layer. Detected from the state dict during finalize().

gating

gating: bool = True

When True, the attention block scales each head's output by softplus(g_proj(hidden)) before o_proj, i.e. it computes softplus(g_proj(hidden)) * attn_out per head. g_proj has out_features = num_heads, so it produces one scalar per head, which is broadcast over head_dim. Always True for poolside/Laguna-M.1-NVFP4.
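A pure-Python sketch of the per-head gate for one token, assuming attn_out is laid out as num_heads rows of head_dim values and that g_proj has already produced one logit per head. The helper names softplus and gate_heads are hypothetical.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable softplus: log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def gate_heads(attn_out: list[list[float]], gate_logits: list[float]) -> list[list[float]]:
    """Scale each head's output by softplus of its gate logit (before o_proj)."""
    assert len(attn_out) == len(gate_logits), "one gate scalar per head"
    return [
        [softplus(g) * v for v in head]  # broadcast the scalar over head_dim
        for head, g in zip(attn_out, gate_logits)
    ]


# Two hypothetical heads with head_dim = 3.
attn_out = [[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]]
gated = gate_heads(attn_out, gate_logits=[0.0, -20.0])
print(gated)
```

Since softplus is always positive, the gate can shrink a head's contribution toward zero but never flips its sign.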

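initialize()

classmethod initialize(pipeline_config, model_config=None)

Initializes a LagunaConfig from a pipeline configuration.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • model_config (MAXModelConfig | None) – The MAX Engine model configuration.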

Returns:

An initialized LagunaConfig instance.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)

Initializes a LagunaConfig from pipeline and HuggingFace configs.

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration.
  • model_config (MAXModelConfig | None) – The MAX Engine model configuration.

Returns:

An initialized LagunaConfig instance.

Return type:

Self

intermediate_size_dense

intermediate_size_dense: int = 8192

Intermediate dimension of the dense-layer MLPs, i.e. the layer or layers where mlp_layer_types[i] == "dense". Distinct from intermediate_size (which Llama3Config uses for the active branch) and from moe_intermediate_size.

mlp_layer_types

mlp_layer_types: list[str] | None = None

Per-layer MLP type. Each entry is "dense" or "sparse". Length equals num_hidden_layers. Typically layer 0 is dense and the rest are sparse MoE.
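To show how mlp_layer_types, intermediate_size_dense, and moe_intermediate_size fit together, here is a sketch that builds a layer-type list with a dense prefix and picks the MLP width for each layer. The function names and the num_dense_layers argument are hypothetical; the defaults 8192 and 512 are the values documented on this page.

```python
def make_mlp_layer_types(num_hidden_layers: int, num_dense_layers: int = 1) -> list[str]:
    # Dense prefix followed by sparse (MoE) layers.
    return ["dense"] * num_dense_layers + ["sparse"] * (num_hidden_layers - num_dense_layers)


def mlp_width(layer_type: str, intermediate_size_dense: int = 8192,
              moe_intermediate_size: int = 512) -> int:
    # Dense layers use one wide MLP; sparse layers use many narrow experts.
    if layer_type == "dense":
        return intermediate_size_dense
    if layer_type == "sparse":
        return moe_intermediate_size
    raise ValueError(f"unknown MLP layer type: {layer_type!r}")


types = make_mlp_layer_types(num_hidden_layers=4)
print(types)                          # ['dense', 'sparse', 'sparse', 'sparse']
print([mlp_width(t) for t in types])  # [8192, 512, 512, 512]
```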

moe_intermediate_size

moe_intermediate_size: int = 512

Per-expert intermediate dim (each routed expert is a SwiGLU MLP with this intermediate size).

moe_routed_scaling_factor

moe_routed_scaling_factor: float = 2.5

Scalar by which the routed-expert output is multiplied before the shared-expert output is added. Laguna-specific: the donor architectures this implementation borrows from default it to 1.0 when they define it at all.
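The combination step can be sketched for a single token as follows, assuming the routed-expert outputs have already been multiplied by their routing weights and summed. Vectors are plain Python lists, and combine_moe_output is a hypothetical name.

```python
def combine_moe_output(
    routed_sum: list[float],
    shared_out: list[float],
    routed_scaling_factor: float = 2.5,
) -> list[float]:
    # Scale the routed-expert sum, then add the always-on shared expert (plain sum, no gate).
    return [routed_scaling_factor * r + s for r, s in zip(routed_sum, shared_out)]


out = combine_moe_output([0.1, -0.2], [1.0, 1.0])
print(out)
```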

moe_router_logit_softcapping

moe_router_logit_softcapping: float = 0.0

If > 0, router logits are passed through softcap * tanh(logits/softcap) before sigmoid. Disabled (0.0) for poolside/Laguna-M.1-NVFP4.
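A minimal sketch of the softcap transform applied to one router logit. Small logits pass through almost unchanged, while large ones are squashed toward the cap; the cap of 30.0 in the example is hypothetical.

```python
import math


def softcap(logit: float, cap: float) -> float:
    # A cap of 0.0 (the documented default) disables softcapping.
    if cap <= 0.0:
        return logit
    return cap * math.tanh(logit / cap)


print(softcap(100.0, 30.0))  # close to, but below, 30
print(softcap(1.0, 0.0))     # 1.0 (disabled)
```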

norm_topk_prob

norm_topk_prob: bool = True

Whether to L1-normalize the selected top-k routing weights so they sum to 1. True for Laguna, matching the HuggingFace reference implementation.

num_experts_per_tok

num_experts_per_tok: int = 8

Number of experts selected per token (top-k).

num_local_experts

num_local_experts: int = 256

Number of routed experts per sparse layer.

partial_rotary_factor

partial_rotary_factor: float = 0.5

Fraction of head_dim used for rotary embeddings, read from the HF config. poolside/Laguna-M.1-NVFP4 is full-rotary (1.0).
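A sketch of how partial_rotary_factor commonly limits RoPE to a prefix of each head vector. It assumes the usual HuggingFace convention: the first int(head_dim * partial_rotary_factor) entries are rotated using the rotate-half pairing, and the remaining entries pass through unchanged. Whether Laguna pairs dimensions exactly this way is an assumption, the function names are hypothetical, and the rope_theta default of 10000.0 is illustrative only.

```python
import math


def rope_inv_freq(head_dim: int, partial_rotary_factor: float, rope_theta: float) -> list[float]:
    rotary_dim = int(head_dim * partial_rotary_factor)
    # One frequency per rotated pair: theta ** (-2i / rotary_dim).
    return [rope_theta ** (-(2 * i) / rotary_dim) for i in range(rotary_dim // 2)]


def apply_partial_rope(x: list[float], position: int, partial_rotary_factor: float,
                       rope_theta: float = 10000.0) -> list[float]:
    """Rotate the first rotary_dim entries of one head vector (rotate-half layout)."""
    inv_freq = rope_inv_freq(len(x), partial_rotary_factor, rope_theta)
    half = len(inv_freq)
    rotated = list(x)
    for i, f in enumerate(inv_freq):
        angle = position * f
        a, b = x[i], x[i + half]
        rotated[i] = a * math.cos(angle) - b * math.sin(angle)
        rotated[i + half] = b * math.cos(angle) + a * math.sin(angle)
    return rotated  # entries from rotary_dim onward are left unchanged


# head_dim = 8 with factor 0.5: only the first 4 entries are rotated.
y = apply_partial_rope([1.0] * 8, position=3, partial_rotary_factor=0.5)
print(y[4:])  # [1.0, 1.0, 1.0, 1.0]
```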

shared_expert_intermediate_size

shared_expert_intermediate_size: int = 512

Intermediate dimension of the always-on shared-expert MLP. Its output is added to the routed-expert output as a plain sum, without gating.

LagunaModel

class max.pipelines.architectures.laguna.LagunaModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: AlwaysSignalBuffersMixin, LlamaModelBase

Laguna pipeline model for text generation.

Uses AlwaysSignalBuffersMixin since VocabParallelEmbedding and ColumnParallelLinear always require signal buffers for allreduce.

attention_bias

attention_bias: bool = False

Whether to use attention bias.

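estimate_activation_memory()

classmethod estimate_activation_memory(pipeline_config, huggingface_config)

Parameters:

  • pipeline_config (PipelineConfig) – The MAX Engine pipeline configuration.
  • huggingface_config (AutoConfig) – The HuggingFace model configuration.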

Return type:

int

load_model()

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model

model: Model

Compiled and initialized model ready for inference.

model_config_cls

model_config_cls

alias of LagunaConfig

norm_method

norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'

Normalization layer type: "rms_norm" or "layer_norm".

state_dict

state_dict: dict[str, Any]

Weights to load into the model.

LagunaReasoningParser

class max.pipelines.architectures.laguna.LagunaReasoningParser(think_start_token_id, think_end_token_id, tool_call_start_token_id=None)

Bases: ReasoningParser

Laguna reasoning parser for spans framed by <think> and </think>.

Reasoning may begin implicitly, without an explicit <think> token (the chat template appends <think> to the assistant turn), and may end implicitly when a tool call begins.

Parameters:

  • think_start_token_id (int)
  • think_end_token_id (int)
  • tool_call_start_token_id (int | None)

from_tokenizer()

async classmethod from_tokenizer(tokenizer)

Constructs a reasoning parser from a tokenizer.

Parameters:

tokenizer (PipelineTokenizer[Any, Any, Any])

Return type:

LagunaReasoningParser

reasoning_end_token_id()

async classmethod reasoning_end_token_id(tokenizer)

Returns the </think> token id.

Parameters:

tokenizer (PipelineTokenizer[Any, Any, Any])

Return type:

int | None

stream()

stream(delta_token_ids, is_currently_reasoning=True)

Identifies a reasoning span within a streaming delta chunk.

When is_currently_reasoning=False and the chunk contains no <think> opener, returns an empty span so non-reasoning chunks (turns where the chat template prefilled </think>, or any chunk after reasoning ended in a prior chunk) aren’t misclassified as reasoning.

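Parameters:

  • delta_token_ids – Token IDs of the current streaming delta chunk.
  • is_currently_reasoning (bool) – Whether generation was already inside a reasoning span when this chunk began.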

Return type:

ParsedReasoningDelta
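A simplified sketch of span detection over one delta chunk, using placeholder token IDs. It illustrates the implicit start (is_currently_reasoning=True with no opener), the implicit end at a tool-call opener, and the empty span returned for non-reasoning chunks. The real method returns a ParsedReasoningDelta whose fields are not documented here, so find_reasoning_span, its (start, end, still_reasoning) return shape, and the token IDs are all hypothetical.

```python
THINK_START, THINK_END, TOOL_CALL_START = 1, 2, 3  # placeholder token IDs


def find_reasoning_span(delta: list[int], is_currently_reasoning: bool = True) -> tuple[int, int, bool]:
    """Return (start, end, still_reasoning) for one streaming chunk (sketch)."""
    if THINK_START in delta:
        start = delta.index(THINK_START) + 1  # explicit <think> opener
    elif is_currently_reasoning:
        start = 0  # implicit start: reasoning carried over or opened by the template
    else:
        return (0, 0, False)  # no opener and not reasoning: empty span
    for i in range(start, len(delta)):
        if delta[i] in (THINK_END, TOOL_CALL_START):
            return (start, i, False)  # explicit </think> or implicit end at a tool call
    return (start, len(delta), True)  # reasoning continues into the next chunk


print(find_reasoning_span([10, 11, THINK_END, 12]))  # (0, 2, False)
print(find_reasoning_span([10, 11], is_currently_reasoning=False))  # (0, 0, False)
```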

will_reason_after_prompt()

will_reason_after_prompt(prompt_token_ids)

Predicts whether the model will emit reasoning after this prompt.

Only checks for </think>, not the tool-call opener, because the chat template embeds tool-call format tokens in the system prompt when tools are provided, and those tokens must not disable reasoning for the generation that follows.

Parameters:

prompt_token_ids (Sequence[int])

Return type:

bool
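A simplified sketch of this rule with placeholder token IDs. It assumes the prediction depends only on whether a </think> follows the most recent <think> in the prompt; the actual heuristic may differ. The tool-call opener is deliberately ignored, so tool-call tokens in a system prompt do not switch reasoning off. The function name and token IDs are hypothetical.

```python
THINK_START, THINK_END, TOOL_CALL_START = 1, 2, 3  # placeholder token IDs


def will_reason(prompt: list[int]) -> bool:
    # Walk backwards to the most recent think marker; TOOL_CALL_START is ignored.
    for tok in reversed(prompt):
        if tok == THINK_END:
            return False  # the template prefilled </think>: reasoning is closed
        if tok == THINK_START:
            return True
    return True  # no markers: reasoning may still begin implicitly


print(will_reason([5, TOOL_CALL_START, 6, THINK_START]))  # True
print(will_reason([5, THINK_START, THINK_END]))           # False
```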

LagunaToolParser

class max.pipelines.architectures.laguna.LagunaToolParser

Bases: object

Parser for Laguna <tool_call> blocks.

Streaming is emitted at per-call granularity: each complete <tool_call>...</tool_call> block produces one delta carrying the function name and the full JSON arguments. That is coarser than per-token but still valid for OpenAI-style clients, which concatenate argument fragments back into a JSON string.

parse_complete()

parse_complete(response)

Parses a complete model response into tool calls.

Parameters:

response (str)

Return type:

ParsedToolResponse
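A sketch of complete-response parsing. The payload format inside <tool_call> blocks is not documented on this page, so this example assumes each block wraps a JSON object with "name" and "arguments" keys (a Hermes-style convention). It returns plain tuples instead of a ParsedToolResponse, and parse_tool_calls is a hypothetical name.

```python
import json
import re

TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def parse_tool_calls(response: str) -> tuple[str, list[tuple[str, str]]]:
    """Return (content_before_first_call, [(name, arguments_json), ...])."""
    first = response.find("<tool_call>")
    content = response if first == -1 else response[:first]
    calls = []
    for body in TOOL_CALL_RE.findall(response):
        payload = json.loads(body)
        # Re-serialize arguments so each call carries its full JSON string.
        calls.append((payload["name"], json.dumps(payload.get("arguments", {}))))
    return content, calls


text = 'Checking.<tool_call>{"name": "get_weather", "arguments": {"city": "Oslo"}}</tool_call>'
content, calls = parse_tool_calls(text)
print(content)  # Checking.
print(calls)    # [('get_weather', '{"city": "Oslo"}')]
```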

parse_delta()

parse_delta(delta)

Processes one decoded-token delta incrementally.

Buffers until a full <tool_call>...</tool_call> block is available, then emits it as a single tool-call delta. Content before the first block is forwarded; structural text between/after blocks is suppressed.

Parameters:

delta (str)

Return type:

list[ParsedToolCallDelta] | None
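A sketch of the per-call buffering strategy, assuming each block wraps a JSON object with "name" and "arguments" keys. Text before the first <tool_call> is forwarded (holding back a possible partial opener split across deltas), each complete block is emitted once as (name, arguments_json), and text between or after blocks is dropped. The class StreamingToolCallBuffer and its feed method are hypothetical, not the LagunaToolParser API.

```python
import json

OPEN, CLOSE = "<tool_call>", "</tool_call>"


def _partial_open_len(text: str) -> int:
    # Length of the longest suffix of `text` that could be the start of OPEN.
    for k in range(min(len(text), len(OPEN) - 1), 0, -1):
        if text.endswith(OPEN[:k]):
            return k
    return 0


class StreamingToolCallBuffer:
    """Hypothetical per-call streaming tool-call parser (illustrative sketch)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear state before a new streaming session.
        self._buffer = ""
        self._seen_call = False

    def feed(self, delta: str) -> tuple[str, list[tuple[str, str]]]:
        """Return (content_to_forward, completed_calls) for one decoded delta."""
        self._buffer += delta
        content = ""
        calls: list[tuple[str, str]] = []
        while True:
            start = self._buffer.find(OPEN)
            if start == -1:
                # Hold back a possible partial opener; forward the rest only
                # before the first call, otherwise suppress it.
                keep = _partial_open_len(self._buffer)
                if not self._seen_call:
                    content += self._buffer[: len(self._buffer) - keep]
                self._buffer = self._buffer[len(self._buffer) - keep:]
                break
            if not self._seen_call:
                content += self._buffer[:start]  # prose before the first block
            end = self._buffer.find(CLOSE, start + len(OPEN))
            if end == -1:
                self._buffer = self._buffer[start:]  # wait for the rest of the block
                break
            payload = json.loads(self._buffer[start + len(OPEN):end])
            calls.append((payload["name"], json.dumps(payload.get("arguments", {}))))
            self._seen_call = True
            self._buffer = self._buffer[end + len(CLOSE):]
        return content, calls


parser = StreamingToolCallBuffer()
chunks = ["Sure", " thing.<tool", '_call>{"name": "f", ', '"arguments": {"x": 1}}</tool_call>',
          '\n<tool_call>{"name": "g", "arguments": {}}</tool_call>']
forwarded, emitted = "", []
for chunk in chunks:
    text, new_calls = parser.feed(chunk)
    forwarded += text
    emitted.extend(new_calls)
print(forwarded)  # Sure thing.
print(emitted)    # [('f', '{"x": 1}'), ('g', '{}')]
```

Emitting whole calls keeps the parser simple, and as noted for LagunaToolParser, OpenAI-style clients that concatenate argument fragments still reconstruct the same JSON string.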

reset()

reset()

Resets internal state for a new streaming session.

Return type:

None