Python module
max.pipelines.architectures.nemotron_h
NemotronHConfig
class max.pipelines.architectures.nemotron_h.NemotronHConfig(*, hidden_size, vocab_size, num_hidden_layers, layer_norm_epsilon, max_seq_len, dtype, devices, tie_word_embeddings=False, layer_kinds=<factory>, num_attention_heads, num_key_value_heads, attention_head_dim, attention_bias=False, intermediate_size, mlp_hidden_act='relu2', mlp_bias=False, mamba_num_heads, mamba_head_dim, n_groups, ssm_state_size, conv_kernel, chunk_size, use_conv_bias=True, mamba_proj_bias=False, time_step_limit=(0.0, inf), kv_params, fp8_mamba_layers=<factory>, fp8_mlp_layers=<factory>, is_fp8=False)
Bases: ArchConfigWithStoredKVParams, ArchConfigWithKVCache
Configuration for a Nemotron-H (nemotron_h) hybrid decoder.
Nemotron-H interleaves Mamba-2 mixers, NoPE GQA attention, and relu2
(non-gated) MLP blocks in the order given by hybrid_override_pattern. There
is no rotary embedding: attention is NoPE, and position information flows
through the SSM. FP8 (modelopt per-tensor static) is applied per module to the
Mamba in/out projections and the MLP up/down projections, honoring the
checkpoint's exclude_modules list; attention, conv1d, norms, and
lm_head stay bf16.
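As a rough illustration of how a pattern string could expand into per-layer kinds, here is a minimal sketch. The character codes (M for Mamba-2, * for attention, - for MLP) follow the Hugging Face Nemotron-H config convention and are an assumption here. The helper names and the kind strings are hypothetical and are not taken from the MAX API.

```python
# Hypothetical mapping: the codes follow the Hugging Face Nemotron-H
# convention (assumed); the kind strings are made up for this sketch.
PATTERN_TO_KIND = {"M": "mamba", "*": "attention", "-": "mlp"}


def parse_layer_kinds(pattern: str) -> list[str]:
    """Expand a hybrid pattern string into one kind name per layer."""
    kinds = []
    for position, code in enumerate(pattern):
        if code not in PATTERN_TO_KIND:
            raise ValueError(f"unknown layer code {code!r} at position {position}")
        kinds.append(PATTERN_TO_KIND[code])
    return kinds


def indices_of(kinds: list[str], kind: str) -> list[int]:
    """Return the absolute layer indices that have the given kind."""
    return [i for i, k in enumerate(kinds) if k == kind]


kinds = parse_layer_kinds("M-M*-")
attention_layers = indices_of(kinds, "attention")
mamba_layers = indices_of(kinds, "mamba")
```

Index lists of this shape are presumably what attention_layer_indices and mamba_layer_indices expose, although the page does not document their exact types.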
Parameters:
- hidden_size (int)
- vocab_size (int)
- num_hidden_layers (int)
- layer_norm_epsilon (float)
- max_seq_len (int)
- dtype (DType)
- devices (list[DeviceRef])
- tie_word_embeddings (bool)
- layer_kinds (list[str])
- num_attention_heads (int)
- num_key_value_heads (int)
- attention_head_dim (int)
- attention_bias (bool)
- intermediate_size (int)
- mlp_hidden_act (str)
- mlp_bias (bool)
- mamba_num_heads (int)
- mamba_head_dim (int)
- n_groups (int)
- ssm_state_size (int)
- conv_kernel (int)
- chunk_size (int)
- use_conv_bias (bool)
- mamba_proj_bias (bool)
- time_step_limit (tuple[float, float])
- kv_params (KVCacheParams)
- fp8_mamba_layers (set[int])
- fp8_mlp_layers (set[int])
- is_fp8 (bool)
attention_bias
attention_bias: bool = False
attention_head_dim
attention_head_dim: int
attention_kv_dim
property attention_kv_dim: int
attention_layer_indices
attention_q_dim
property attention_q_dim: int
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)
Bounds max_length by max_position_embeddings.
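The bound described above can be sketched as a simple clamp. This is only an assumption about the behavior: the page does not say whether the real method clamps or raises when the requested length exceeds the limit, and the function name and the None handling below are hypothetical.

```python
from typing import Optional


def bounded_max_seq_len(
    requested_max_length: Optional[int], max_position_embeddings: int
) -> int:
    """Clamp a requested length to the model's position limit (assumed behavior)."""
    if requested_max_length is None:
        # Assumption: with no explicit request, fall back to the model limit.
        return max_position_embeddings
    return min(requested_max_length, max_position_embeddings)
```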
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
chunk_size
chunk_size: int
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Allocates KV cache only for the (4) full-attention layers.
The forward pass maps each attention layer to a sequential KV cache index (0, 1, 2, ...), independent of the absolute layer index.
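The sequential mapping can be pictured with a small standalone sketch. The function name and the kind strings are hypothetical; only the idea (attention layers get cache indices 0, 1, 2, ... in layer order, and other layers get none) comes from the description above.

```python
def attention_cache_indices(layer_kinds: list[str]) -> dict[int, int]:
    """Map absolute layer index -> sequential KV cache index for attention layers."""
    cache_index: dict[int, int] = {}
    for layer_idx, kind in enumerate(layer_kinds):
        if kind == "attention":
            # The next free cache slot equals the number of attention layers seen so far.
            cache_index[layer_idx] = len(cache_index)
    return cache_index


layer_kinds = ["mamba", "mlp", "attention", "mamba", "mlp", "attention"]
mapping = attention_cache_indices(layer_kinds)
num_cached_layers = len(mapping)
```

With this layout the KV cache only needs num_cached_layers entries, rather than one per decoder layer.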
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
conv_dim
property conv_dim: int
conv_kernel
conv_kernel: int
devices
dtype
dtype: DType
fp8_mamba_layers
fp8_mlp_layers
from_hf()
classmethod from_hf(pipeline_config, huggingface_config, dtype, kv_params, devices)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- dtype (DType)
- kv_params (KVCacheParams)
- devices (list[DeviceRef])
Return type:
get_num_layers()
static get_num_layers(huggingface_config)
Layer count for the decoder stack (override when HF uses a different field).
Parameters:
- huggingface_config (AutoConfig)
Return type:
hidden_size
hidden_size: int
initialize()
classmethod initialize(pipeline_config, model_config=None)
ArchConfig protocol entry point.
Derives dtype / devices / KV params from model_config and delegates
to from_hf(). model.py calls from_hf() directly during
graph build (it already has the resolved KV params); this method exists
so the pipeline config-resolution / memory-estimation path can build the
arch config from a PipelineConfig alone.
Parameters:
- pipeline_config (PipelineConfig)
- model_config (MAXModelConfig | None)
Return type:
intermediate_size
intermediate_size: int
is_fp8
is_fp8: bool = False
kv_params
kv_params: KVCacheParams
layer_kinds
layer_norm_epsilon
layer_norm_epsilon: float
mamba_head_dim
mamba_head_dim: int
mamba_in_proj_out
property mamba_in_proj_out: int
mamba_intermediate_size
property mamba_intermediate_size: int
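This page does not give formulas for the derived properties (attention_q_dim, attention_kv_dim, mamba_intermediate_size, conv_dim, mamba_in_proj_out). The sketch below shows how they would typically be computed under the standard GQA and Mamba-2 layouts. Treat every formula as an assumption, not as the MAX implementation; the HybridDims class is hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HybridDims:
    """Hypothetical stand-in holding only the fields the formulas need."""

    num_attention_heads: int
    num_key_value_heads: int
    attention_head_dim: int
    mamba_num_heads: int
    mamba_head_dim: int
    n_groups: int
    ssm_state_size: int

    @property
    def attention_q_dim(self) -> int:
        # Assumed: one head_dim-wide query slice per attention head.
        return self.num_attention_heads * self.attention_head_dim

    @property
    def attention_kv_dim(self) -> int:
        # Assumed: GQA shares K/V across groups, so fewer KV heads.
        return self.num_key_value_heads * self.attention_head_dim

    @property
    def mamba_intermediate_size(self) -> int:
        return self.mamba_num_heads * self.mamba_head_dim

    @property
    def conv_dim(self) -> int:
        # Assumed Mamba-2 layout: the conv1d runs over x plus grouped B and C.
        return self.mamba_intermediate_size + 2 * self.n_groups * self.ssm_state_size

    @property
    def mamba_in_proj_out(self) -> int:
        # Assumed Mamba-2 layout: gate z, conv input (x, B, C), and per-head dt.
        return self.mamba_intermediate_size + self.conv_dim + self.mamba_num_heads


dims = HybridDims(
    num_attention_heads=4,
    num_key_value_heads=2,
    attention_head_dim=8,
    mamba_num_heads=8,
    mamba_head_dim=4,
    n_groups=2,
    ssm_state_size=16,
)
```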
mamba_layer_indices
mamba_num_heads
mamba_num_heads: int
mamba_proj_bias
mamba_proj_bias: bool = False
max_seq_len
max_seq_len: int
mlp_bias
mlp_bias: bool = False
mlp_hidden_act
mlp_hidden_act: str = 'relu2'
n_groups
n_groups: int
num_attention_heads
num_attention_heads: int
num_hidden_layers
num_hidden_layers: int
num_key_value_heads
num_key_value_heads: int
populate_fp8_layers()
populate_fp8_layers(state_dict)
Records which Mamba/MLP layers are FP8, based on the checkpoint.
A Linear layer is FP8 if and only if its weight_scale is present in the
checkpoint (the exact inverse of the modelopt exclude_modules list). Names are
the post-adapter MAX names: blocks.{i}.mixer.{in_proj,out_proj,up_proj,down_proj}.weight_scale.
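The detection rule can be sketched with a regular expression over checkpoint key names. This is a minimal illustration, not the MAX implementation: the function name is hypothetical, and treating a layer as FP8 as soon as any one of its projections has a weight_scale is an assumption. The split of in_proj/out_proj (Mamba) versus up_proj/down_proj (MLP) follows the class description.

```python
import re
from collections.abc import Iterable

# Matches the post-adapter key pattern quoted above.
_SCALE_KEY = re.compile(
    r"^blocks\.(\d+)\.mixer\.(in_proj|out_proj|up_proj|down_proj)\.weight_scale$"
)
_MAMBA_PROJECTIONS = {"in_proj", "out_proj"}


def find_fp8_layers(keys: Iterable[str]) -> tuple[set[int], set[int]]:
    """Return (fp8_mamba_layers, fp8_mlp_layers) found among checkpoint keys."""
    fp8_mamba: set[int] = set()
    fp8_mlp: set[int] = set()
    for key in keys:
        match = _SCALE_KEY.match(key)
        if match is None:
            continue  # Not a projection weight_scale, so it says nothing about FP8.
        layer_idx = int(match.group(1))
        if match.group(2) in _MAMBA_PROJECTIONS:
            fp8_mamba.add(layer_idx)
        else:
            fp8_mlp.add(layer_idx)
    return fp8_mamba, fp8_mlp


example_keys = [
    "blocks.0.mixer.in_proj.weight",
    "blocks.0.mixer.in_proj.weight_scale",
    "blocks.0.mixer.out_proj.weight_scale",
    "blocks.1.mixer.up_proj.weight_scale",
    "blocks.2.mixer.in_proj.weight",  # no scale: layer 2 stays in bf16
]
fp8_mamba_layers, fp8_mlp_layers = find_fp8_layers(example_keys)
```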
Parameters:
- state_dict (Mapping[str, WeightData])
Return type:
None
ssm_state_size
ssm_state_size: int
tie_word_embeddings
tie_word_embeddings: bool = False
time_step_limit
use_conv_bias
use_conv_bias: bool = True
vocab_size
vocab_size: int
NemotronHInputs
class max.pipelines.architectures.nemotron_h.NemotronHInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, data_parallel_splits=None, slot_idx=None, conv_pools=None, ssm_pools=None, has_initial_state=None, request_ids=None, *, kv_cache_inputs=None, lora=None, hidden_states=None)
Bases: Llama3Inputs
Inputs for Nemotron-H: ragged tokens + hybrid SSM/conv state.
Beyond the standard Llama3 ragged inputs, this carries a uint32 slot_idx into
the per-request pools, the per-Mamba-layer conv pools (bf16, mutated in
place by causal_conv1d_varlen_fwd), the per-Mamba-layer SSM pools (fp32,
mutated in place by mamba2_ssd_chunk_scan_varlen_fwd_inplace), and a
has_initial_state buffer of shape [batch] (the slots are zeroed on claim, so a
fresh request's zeroed initial state matches a from-scratch prefill).
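To make the zero-on-claim behavior concrete, here is a toy slot pool in plain Python. It is a sketch under stated assumptions: the SlotPool class, its methods, and the free-list policy are all hypothetical, and the real pools are device buffers rather than Python lists.

```python
class SlotPool:
    """Toy per-request state pool: one fixed-size state row per slot."""

    def __init__(self, num_slots: int, state_size: int) -> None:
        self.state_size = state_size
        self.state = [[0.0] * state_size for _ in range(num_slots)]
        self.free = list(range(num_slots))
        self.owner: dict[str, int] = {}

    def claim(self, request_id: str) -> int:
        """Give the request a slot, zeroing it so no stale state leaks in."""
        if request_id in self.owner:
            return self.owner[request_id]
        slot = self.free.pop()
        self.state[slot] = [0.0] * self.state_size  # zero on claim
        self.owner[request_id] = slot
        return slot

    def release(self, request_id: str) -> None:
        """Return the request's slot to the free list without clearing it."""
        self.free.append(self.owner.pop(request_id))


pool = SlotPool(num_slots=2, state_size=3)
first = pool.claim("req-a")
pool.state[first][0] = 5.0  # stands in for a kernel mutating state in place
pool.release("req-a")
second = pool.claim("req-b")  # reuses the slot, which is zeroed again
```

Because a reused slot always starts from zeros, a new request behaves the same whether or not an earlier request previously occupied that slot.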
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- signal_buffers (list[Buffer])
- return_n_logits (Buffer)
- data_parallel_splits (Buffer | Sequence[Sequence[int]] | None)
- slot_idx (Buffer | None)
- conv_pools (list[Buffer] | None)
- ssm_pools (list[Buffer] | None)
- has_initial_state (Buffer | None)
- request_ids (list[RequestID] | None)
- kv_cache_inputs (KVCacheInputsInterface[Buffer, Buffer] | None)
- lora (LoRAInputs | None)
- hidden_states (Buffer | list[Buffer] | None)
buffers
Returns positional Buffer inputs for model ABI calls.
conv_pools
has_initial_state
request_ids
slot_idx
ssm_pools
NemotronHModel
class max.pipelines.architectures.nemotron_h.NemotronHModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: LlamaModelBase
Nemotron-H pipeline model (hybrid Mamba-2 + attention).
Parameters:
- pipeline_config (PipelineConfig): The configuration for this pipeline.
- session (InferenceSession): The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
batch_processor_cls
batch_processor_cls: ClassVar[Any] = None
Optional batch processor class for input/output handling.
execute()
execute(model_inputs)
Executes the graph with the given inputs.
Parameters:
- model_inputs (ModelInputs): The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline's output tensors.
Return type:
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
load_model()
load_model(session)
Parameters:
- session (InferenceSession)
Return type:
model_config_cls
model_config_cls
alias of NemotronHConfig
norm_method
norm_method: Literal['rms_norm', 'layer_norm'] = 'rms_norm'
Normalization layer.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Delegates to the batch processor and narrows to Llama3Inputs.
release()
release(request_id)
Parameters:
- request_id (RequestID)
Return type:
None