Python module

max.pipelines.architectures.nemotron_h

NemotronHConfig​

class max.pipelines.architectures.nemotron_h.NemotronHConfig(*, hidden_size, vocab_size, num_hidden_layers, layer_norm_epsilon, max_seq_len, dtype, devices, tie_word_embeddings=False, layer_kinds=<factory>, num_attention_heads, num_key_value_heads, attention_head_dim, attention_bias=False, intermediate_size, mlp_hidden_act='relu2', mlp_bias=False, mamba_num_heads, mamba_head_dim, n_groups, ssm_state_size, conv_kernel, chunk_size, use_conv_bias=True, mamba_proj_bias=False, time_step_limit=(0.0, inf), kv_params, fp8_mamba_layers=<factory>, fp8_mlp_layers=<factory>, is_fp8=False)

Bases: ArchConfigWithStoredKVParams, ArchConfigWithKVCache

Configuration for a Nemotron-H (nemotron_h) hybrid decoder.

Nemotron-H interleaves Mamba-2 mixers, grouped-query attention (GQA) without positional embeddings (NoPE), and non-gated relu2 MLP blocks, in the order given by hybrid_override_pattern. There is no rotary embedding: position information flows through the SSM. FP8 (modelopt per-tensor static quantization) is applied per module to the Mamba in/out projections and the MLP up/down projections, honoring the checkpoint’s exclude_modules list. Attention, conv1d, norms, and lm_head stay in bf16.
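
As a rough illustration of how a pattern string could expand into per-layer kinds, the sketch below assumes the character convention of the Hugging Face Nemotron-H configuration (`M` for Mamba-2, `*` for attention, `-` for MLP). The kind names, the helpers `parse_hybrid_pattern` and `indices_of`, and the example pattern are all hypothetical; the strings MAX actually stores in layer_kinds are not shown on this page.

```python
# Hypothetical mapping from pattern characters to layer kinds.
# Assumption: "M" = Mamba-2 mixer, "*" = attention, "-" = MLP.
PATTERN_TO_KIND = {"M": "mamba", "*": "attention", "-": "mlp"}


def parse_hybrid_pattern(pattern: str) -> list[str]:
    """Expand a pattern string into one layer kind per decoder block."""
    kinds = []
    for position, char in enumerate(pattern):
        if char not in PATTERN_TO_KIND:
            raise ValueError(f"unknown layer code {char!r} at position {position}")
        kinds.append(PATTERN_TO_KIND[char])
    return kinds


def indices_of(kinds: list[str], kind: str) -> list[int]:
    """Return the absolute layer indices whose kind equals `kind`."""
    return [i for i, k in enumerate(kinds) if k == kind]


# Made-up seven-layer pattern, not taken from any checkpoint.
layer_kinds = parse_hybrid_pattern("M-M*-M-")
attention_layers = indices_of(layer_kinds, "attention")
mamba_layers = indices_of(layer_kinds, "mamba")
```

Under these assumptions, properties such as attention_layer_indices and mamba_layer_indices would correspond to filtering the kind list by value.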

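Parameters:

hidden_size (int)
vocab_size (int)
num_hidden_layers (int)
layer_norm_epsilon (float)
max_seq_len (int)
dtype (DType)
devices (list[DeviceRef])
tie_word_embeddings (bool)
layer_kinds (list[str])
num_attention_heads (int)
num_key_value_heads (int)
attention_head_dim (int)
attention_bias (bool)
intermediate_size (int)
mlp_hidden_act (str)
mlp_bias (bool)
mamba_num_heads (int)
mamba_head_dim (int)
n_groups (int)
ssm_state_size (int)
conv_kernel (int)
chunk_size (int)
use_conv_bias (bool)
mamba_proj_bias (bool)
time_step_limit (tuple[float, float])
kv_params (KVCacheParams)
fp8_mamba_layers (set[int])
fp8_mlp_layers (set[int])
is_fp8 (bool)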

attention_bias​

attention_bias: bool = False

attention_head_dim​

attention_head_dim: int

attention_kv_dim​

property attention_kv_dim: int

attention_layer_indices​

property attention_layer_indices: list[int]

attention_q_dim​

property attention_q_dim: int
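
The query and key/value projection widths are not spelled out on this page. A minimal sketch, assuming the usual grouped-query attention convention (heads times head dimension), might look like the following; the helper name and the example sizes are hypothetical.

```python
def attention_projection_dims(
    num_attention_heads: int,
    num_key_value_heads: int,
    attention_head_dim: int,
) -> tuple[int, int]:
    """Return (q_dim, kv_dim) under the common GQA convention (assumed)."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("query heads must be a multiple of key/value heads")
    q_dim = num_attention_heads * attention_head_dim
    kv_dim = num_key_value_heads * attention_head_dim
    return q_dim, kv_dim


# Made-up sizes: 32 query heads sharing 8 key/value heads of width 128.
q_dim, kv_dim = attention_projection_dims(32, 8, 128)
```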

calculate_max_seq_len()​

classmethod calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

Bounds max_length by max_position_embeddings.

Parameters:

pipeline_config
huggingface_config
model_config

Return type:

int
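
One plausible reading of "bounds" is a simple clamp, sketched below. This is an assumption: the real method may instead raise when the requested length exceeds the model limit, and the function and argument names here are hypothetical.

```python
from typing import Optional


def bounded_max_seq_len(
    requested_max_length: Optional[int],
    max_position_embeddings: int,
) -> int:
    """Clamp a requested length to the model's position limit (assumed behavior)."""
    if requested_max_length is None:
        # Assumption: with no explicit request, fall back to the model limit.
        return max_position_embeddings
    return min(requested_max_length, max_position_embeddings)


short = bounded_max_seq_len(4096, 8192)
clamped = bounded_max_seq_len(16384, 8192)
default = bounded_max_seq_len(None, 8192)
```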

chunk_size​

chunk_size: int

construct_kv_params()​

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Allocate KV cache only for the (4) full-attention layers.

The forward pass maps each attention layer to a sequential KV cache index (0, 1, 2, …), independent of the absolute layer index.
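
The sequential indexing described above can be sketched as a dense renumbering of the attention layers. The helper name and the layer positions in the example are hypothetical and only illustrate the idea.

```python
def kv_cache_index_by_layer(attention_layer_indices: list[int]) -> dict[int, int]:
    """Map each absolute attention layer index to a dense KV cache index."""
    ordered = sorted(attention_layer_indices)
    return {layer: cache_idx for cache_idx, layer in enumerate(ordered)}


# Made-up example: attention sits at absolute layers 7, 16, 25 and 34.
mapping = kv_cache_index_by_layer([7, 16, 25, 34])
```

With this mapping the cache only needs as many layer slots as there are attention layers, regardless of how many Mamba-2 or MLP blocks sit between them.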

Parameters:

huggingface_config
pipeline_config
devices
kv_cache_config
cache_dtype

Return type:

KVCacheParams

conv_dim​

property conv_dim: int

conv_kernel​

conv_kernel: int

devices​

devices: list[DeviceRef]

dtype​

dtype: DType

fp8_mamba_layers​

fp8_mamba_layers: set[int]

fp8_mlp_layers​

fp8_mlp_layers: set[int]

from_hf()​

classmethod from_hf(pipeline_config, huggingface_config, dtype, kv_params, devices)

Parameters:

pipeline_config
huggingface_config
dtype
kv_params
devices

Return type:

NemotronHConfig

get_num_layers()​

static get_num_layers(huggingface_config)

Layer count for the decoder stack (override when HF uses a different field).

Parameters:

huggingface_config (AutoConfig)

Return type:

int

hidden_size​

hidden_size: int

initialize()​

classmethod initialize(pipeline_config, model_config=None)

ArchConfig protocol entry point.

Derives dtype, devices, and KV params from model_config and delegates to from_hf(). model.py calls from_hf() directly during graph build because it already has the resolved KV params. This method exists so that the pipeline's config-resolution and memory-estimation path can build the arch config from a PipelineConfig alone.

Parameters:

pipeline_config
model_config

Return type:

NemotronHConfig

intermediate_size​

intermediate_size: int

is_fp8​

is_fp8: bool = False

kv_params​

kv_params: KVCacheParams

layer_kinds​

layer_kinds: list[str]

layer_norm_epsilon​

layer_norm_epsilon: float

mamba_head_dim​

mamba_head_dim: int

mamba_in_proj_out​

property mamba_in_proj_out: int

mamba_intermediate_size​

property mamba_intermediate_size: int
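
The formulas behind the derived Mamba sizes (mamba_intermediate_size, conv_dim, mamba_in_proj_out) are not given on this page. The sketch below assumes they follow the conventions of the reference Mamba-2 layer; the function names are hypothetical and the example sizes are made up.

```python
def mamba_intermediate_size(mamba_num_heads: int, mamba_head_dim: int) -> int:
    """Width of the SSM input x (assumed: heads times head dimension)."""
    return mamba_num_heads * mamba_head_dim


def conv_dim(intermediate: int, n_groups: int, ssm_state_size: int) -> int:
    """Channels passed through the causal conv: x plus grouped B and C (assumed)."""
    return intermediate + 2 * n_groups * ssm_state_size


def mamba_in_proj_out(intermediate: int, conv_channels: int, mamba_num_heads: int) -> int:
    """Input projection width: gate z, conv input, and one dt per head (assumed)."""
    return intermediate + conv_channels + mamba_num_heads


# Made-up sizes for illustration only.
inter = mamba_intermediate_size(mamba_num_heads=8, mamba_head_dim=64)
conv = conv_dim(inter, n_groups=2, ssm_state_size=128)
proj = mamba_in_proj_out(inter, conv, mamba_num_heads=8)
```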

mamba_layer_indices​

property mamba_layer_indices: list[int]

mamba_num_heads​

mamba_num_heads: int

mamba_proj_bias​

mamba_proj_bias: bool = False

max_seq_len​

max_seq_len: int

mlp_bias​

mlp_bias: bool = False

mlp_hidden_act​

mlp_hidden_act: str = 'relu2'
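
For readers unfamiliar with the name, relu2 conventionally denotes squared ReLU. The pure-Python sketch below shows a non-gated MLP of the form down(relu2(up(x))) on toy weights; it is an illustration under that assumption, not the MAX implementation, and all names and values are hypothetical.

```python
def relu2(x: float) -> float:
    """Squared ReLU: max(x, 0) ** 2."""
    return max(x, 0.0) ** 2


def matvec(weight: list[list[float]], vec: list[float]) -> list[float]:
    """Multiply a row-major matrix by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def relu2_mlp(
    x: list[float],
    up_weight: list[list[float]],
    down_weight: list[list[float]],
) -> list[float]:
    """Non-gated MLP: a single up projection, relu2, then a down projection."""
    hidden = [relu2(h) for h in matvec(up_weight, x)]
    return matvec(down_weight, hidden)


# Toy weights: hidden size 2, intermediate size 3.
up = [[1.0, 0.0], [0.0, -1.0], [1.0, 1.0]]
down = [[1.0, 1.0, 1.0], [0.5, 0.0, 0.0]]
y = relu2_mlp([2.0, 1.0], up, down)
```

Unlike a gated MLP, there is no second "gate" projection multiplied into the hidden activations, which is why only up and down projections appear.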

n_groups​

n_groups: int

num_attention_heads​

num_attention_heads: int

num_hidden_layers​

num_hidden_layers: int

num_key_value_heads​

num_key_value_heads: int

populate_fp8_layers()​

populate_fp8_layers(state_dict)

Record which mamba/MLP layers are FP8 from the checkpoint.

A Linear layer is FP8 if and only if its weight_scale is present in the checkpoint (the exact inverse of the modelopt exclude_modules list). Names are the post-adapter MAX names: blocks.{i}.mixer.{in_proj,out_proj,up_proj,down_proj}.weight_scale.

Parameters:

state_dict (Mapping[str, WeightData])

Return type:

None
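
A minimal sketch of the detection rule, working on weight names only: it uses the name pattern quoted above and assumes that a layer counts as FP8 as soon as any one of its projections has a weight_scale entry. The real method may apply a stricter rule; `split_fp8_layers` and the sample names are hypothetical.

```python
import re
from collections.abc import Iterable

# Name pattern from the text: blocks.{i}.mixer.<projection>.weight_scale
_SCALE_KEY = re.compile(
    r"^blocks\.(\d+)\.mixer\.(in_proj|out_proj|up_proj|down_proj)\.weight_scale$"
)
_MAMBA_PROJECTIONS = {"in_proj", "out_proj"}


def split_fp8_layers(weight_names: Iterable[str]) -> tuple[set[int], set[int]]:
    """Return (fp8_mamba_layers, fp8_mlp_layers) found among the weight names."""
    fp8_mamba: set[int] = set()
    fp8_mlp: set[int] = set()
    for name in weight_names:
        match = _SCALE_KEY.match(name)
        if match is None:
            continue  # not a scale tensor for a tracked projection
        layer = int(match.group(1))
        if match.group(2) in _MAMBA_PROJECTIONS:
            fp8_mamba.add(layer)
        else:
            fp8_mlp.add(layer)
    return fp8_mamba, fp8_mlp


names = [
    "blocks.0.mixer.in_proj.weight",
    "blocks.0.mixer.in_proj.weight_scale",
    "blocks.0.mixer.out_proj.weight_scale",
    "blocks.1.mixer.up_proj.weight_scale",
    "blocks.1.mixer.down_proj.weight_scale",
    "blocks.2.mixer.in_proj.weight",  # no scale entry, so layer 2 is not FP8
]
fp8_mamba_layers, fp8_mlp_layers = split_fp8_layers(names)
```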

ssm_state_size​

ssm_state_size: int

tie_word_embeddings​

tie_word_embeddings: bool = False

time_step_limit​

time_step_limit: tuple[float, float] = (0.0, inf)
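
In the reference Mamba-2 layer, the time step is passed through softplus and then clamped to this range. The sketch below assumes the same use here; the function names are hypothetical, and the default limit (0.0, inf) leaves the softplus output unchanged.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus: log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def clamp_time_step(
    raw_dt: float,
    dt_bias: float,
    limit: tuple[float, float] = (0.0, math.inf),
) -> float:
    """Apply softplus to the biased time step, then clamp it to `limit` (assumed)."""
    low, high = limit
    return min(max(softplus(raw_dt + dt_bias), low), high)


unclamped = clamp_time_step(0.0, 0.0)
clamped = clamp_time_step(0.0, 0.0, limit=(0.0, 0.5))
```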

use_conv_bias​

use_conv_bias: bool = True

vocab_size​

vocab_size: int

NemotronHInputs​

class max.pipelines.architectures.nemotron_h.NemotronHInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, data_parallel_splits=None, slot_idx=None, conv_pools=None, ssm_pools=None, has_initial_state=None, request_ids=None, *, kv_cache_inputs=None, lora=None, hidden_states=None)

Bases: Llama3Inputs

Inputs for Nemotron-H: ragged tokens + hybrid SSM/conv state.

Beyond the standard Llama3 ragged inputs, it carries a uint32 slot_idx into the per-request pools, the per-Mamba-layer conv pools (bf16, mutated in place by causal_conv1d_varlen_fwd), the per-Mamba-layer SSM pools (fp32, mutated in place by mamba2_ssd_chunk_scan_varlen_fwd_inplace), and a has_initial_state buffer of shape [batch]. Slots are zeroed on claim, so a fresh request’s zeroed initial state matches a from-scratch prefill.
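
The zero-on-claim behavior can be pictured with a toy pool. This is only a sketch in plain Python: the `SlotPool` class, its methods, and the request IDs are hypothetical, and real pools are device buffers updated by kernels rather than Python lists.

```python
class SlotPool:
    """Toy per-request state pool with one fixed-size state row per slot."""

    def __init__(self, num_slots: int, state_size: int) -> None:
        self.states = [[0.0] * state_size for _ in range(num_slots)]
        self.free_slots = list(range(num_slots - 1, -1, -1))
        self.slot_by_request: dict[str, int] = {}

    def claim(self, request_id: str) -> int:
        """Assign a free slot to the request and zero its state in place."""
        slot = self.free_slots.pop()
        row = self.states[slot]
        row[:] = [0.0] * len(row)  # zero on claim, so no stale state leaks in
        self.slot_by_request[request_id] = slot
        return slot

    def release(self, request_id: str) -> None:
        """Return the request's slot to the free list."""
        self.free_slots.append(self.slot_by_request.pop(request_id))


pool = SlotPool(num_slots=2, state_size=3)
slot_a = pool.claim("req-a")
pool.states[slot_a][0] = 1.5  # stands in for an in-place kernel update
pool.release("req-a")
slot_b = pool.claim("req-b")  # reuses the same slot, zeroed again
```

Because the slot is cleared when claimed rather than when released, a new request always starts from an all-zero state, which is what makes it equivalent to a from-scratch prefill.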

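Parameters:

tokens
input_row_offsets
signal_buffers
return_n_logits
data_parallel_splits
slot_idx (Buffer | None)
conv_pools (list[Buffer] | None)
ssm_pools (list[Buffer] | None)
has_initial_state (Buffer | None)
request_ids (list[RequestID] | None)
kv_cache_inputs
lora
hidden_states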

buffers​

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.

conv_pools​

conv_pools: list[Buffer] | None = None

has_initial_state​

has_initial_state: Buffer | None = None

request_ids​

request_ids: list[RequestID] | None = None

slot_idx​

slot_idx: Buffer | None = None

ssm_pools​

ssm_pools: list[Buffer] | None = None

NemotronHModel​

class max.pipelines.architectures.nemotron_h.NemotronHModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: LlamaModelBase

Nemotron-H pipeline model (hybrid Mamba-2 + attention).

Parameters:

pipeline_config
session
devices
kv_cache_config
weights
adapter
return_logits
return_hidden_states

batch_processor_cls​

batch_processor_cls: ClassVar[Any] = None

Optional batch processor class for input/output handling.

execute()​

execute(model_inputs)

Executes the graph with the given inputs.

Parameters:

model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.

Returns:

ModelOutputs containing the pipeline’s output tensors.

Return type:

ModelOutputs

This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.

load_model()​

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model_config_cls​

model_config_cls

alias of NemotronHConfig

norm_method​

norm_method: Literal['rms_norm', 'layer_norm'] = 'rms_norm'

Normalization layer.

prepare_initial_token_inputs()​

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Delegates to the batch processor and narrows to Llama3Inputs.

Parameters:

replica_batches
kv_cache_inputs
return_n_logits

Return type:

NemotronHInputs

release()​

release(request_id)

Parameters:

request_id (RequestID)

Return type:

None