Python module

max.pipelines.architectures.mamba

Mamba state-space architecture for text generation.

MambaConfig

class max.pipelines.architectures.mamba.MambaConfig(*, hidden_size, intermediate_size, num_hidden_layers, vocab_size, max_seq_len, dtype, devices, d_state, dt_rank=None, conv_kernel=4, x_proj_dim=None, rms_norm_eps=None, use_bias=False, use_conv_bias=True, residual_in_fp32=True, tie_word_embeddings=True, return_logits=ReturnLogits.LAST_TOKEN, quant_config=None, use_subgraphs=True, data_parallel_degree=1, expand=2)

Bases: ArchConfigWithKVCache

Model configuration for Mamba graph construction/execution.

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

Return type:

int

conv_kernel

conv_kernel: int = 4

d_state

d_state: int

data_parallel_degree

data_parallel_degree: int = 1

devices

devices: list[DeviceRef]

dt_rank

dt_rank: int | str | None = None

dtype

dtype: DType

expand

expand: int = 2

finalize()

finalize(huggingface_config, state_dict, return_logits)

Set parameters that require introspecting the state dict.

Return type:

None

get_kv_params()

get_kv_params()

Return minimal dummy KV cache params.

Mamba uses SSM state caching, not attention KV cache, but the pipeline interfaces require KVCacheParams for memory estimation and PagedKVCacheManager initialization. The tiny dimensions (1 head, 1 dim, 1 layer) make the allocation negligible.

Return type:

KVCacheParams
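To see why the dummy dimensions make the allocation negligible, here is a rough back-of-envelope sketch. The sizing formula is an assumption about how a paged attention KV cache is typically laid out (K and V tensors per layer), not the exact MAX allocator:

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, max_seq_len, dtype_bytes=2):
    # Rough KV sizing: 2 tensors (K and V) per layer, each of
    # shape [num_heads, max_seq_len, head_dim], at dtype_bytes per element.
    return 2 * num_layers * num_heads * max_seq_len * head_dim * dtype_bytes

# Dummy params that Mamba reports: 1 head, 1 dim, 1 layer.
dummy = kv_cache_bytes(num_layers=1, num_heads=1, head_dim=1, max_seq_len=4096)
# A hypothetical attention model for comparison: 32 layers, 32 heads, head dim 128.
real = kv_cache_bytes(num_layers=32, num_heads=32, head_dim=128, max_seq_len=4096)
print(dummy)         # 16384 bytes, ~16 KiB for a 4096-token budget
print(real // dummy)
```

The dummy cache is tens of kilobytes where a real attention cache would be gigabytes, so memory estimation still works without materially affecting the budget.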

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

get_num_layers()

static get_num_layers(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

int

get_ssm_cache_params()

get_ssm_cache_params()

Return type:

SSMStateCacheParams

help()

static help()

Return type:

dict[str, str]

hidden_size

hidden_size: int

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)

Return type:

Self

intermediate_size

intermediate_size: int

max_seq_len

max_seq_len: int

num_hidden_layers

num_hidden_layers: int

quant_config

quant_config: QuantConfig | None = None

residual_in_fp32

residual_in_fp32: bool = True

return_logits

return_logits: ReturnLogits = 'last_token'

rms_norm_eps

rms_norm_eps: float | None = None

tie_word_embeddings

tie_word_embeddings: bool = True

use_bias

use_bias: bool = False

use_conv_bias

use_conv_bias: bool = True

use_subgraphs

use_subgraphs: bool = True

vocab_size

vocab_size: int

x_proj_dim

x_proj_dim: int | None = None

MambaModel

class max.pipelines.architectures.mamba.MambaModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: PipelineModelWithKVCache[TextContext]

Mamba pipeline model with incremental SSM state caching.

Uses separate compiled prefill and step models. The prefill model processes the full prompt and extracts per-layer conv/ssm states. The step model processes single new tokens using cached states.
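The prefill/step split can be sketched as a greedy decode loop. The names prefill_model, step_model, conv_states, and ssm_states are illustrative stand-ins for the compiled models and per-layer states that live inside MambaModel:

```python
def generate(prefill_model, step_model, prompt_tokens, n_new_tokens):
    # Prefill: run the whole prompt once, capturing per-layer
    # conv/ssm states alongside the last-token logits.
    logits, conv_states, ssm_states = prefill_model(prompt_tokens)
    token = int(logits.argmax())
    out = [token]
    # Step: feed one token at a time, threading the cached states
    # through so each step's cost is independent of sequence length.
    for _ in range(n_new_tokens - 1):
        logits, conv_states, ssm_states = step_model(token, conv_states, ssm_states)
        token = int(logits.argmax())
        out.append(token)
    return out
```

This is the Mamba analogue of KV-cached attention decoding: instead of appending keys/values per token, the fixed-size conv/ssm states are updated in place.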

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the optimal max sequence length for the model.

Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:

  • pipeline_config (PipelineConfig) – Configuration for the pipeline.
  • huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

The maximum sequence length to use.

Return type:

int

compute_log_probabilities()

compute_log_probabilities(session, model_inputs, model_outputs, next_tokens, batch_top_n, batch_echo)

Optional method that can be overridden to compute log probabilities.

Parameters:

  • session (InferenceSession) – Inference session to compute log probabilities within.
  • model_inputs (ModelInputs) – Inputs to the model returned by prepare_*_token_inputs().
  • model_outputs (ModelOutputs) – Outputs returned by execute().
  • next_tokens (Buffer) – Sampled tokens; should have shape [batch_size].
  • batch_top_n (list[int]) – Number of top log probabilities to return per input in the batch. For any element where top_n == 0, the LogProbabilities is skipped.
  • batch_echo (list[bool]) – Whether to include input tokens in the returned log probabilities.

Returns:

List of log probabilities.

Return type:

list[LogProbabilities | None]
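As a sketch of what such an override computes per batch element, here is a plain-Python stand-in for the top-n log-probability step; the real method operates on device buffers and the ModelInputs/ModelOutputs types above:

```python
import math

def top_n_log_probs(logits: list[float], top_n: int) -> dict[int, float]:
    # Softmax in log space: log p_i = x_i - logsumexp(x),
    # with the max subtracted first for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - lse for x in logits]
    # Keep the top_n most likely token ids, as batch_top_n requests;
    # top_n == 0 yields an empty result (the "skipped" case above).
    ranked = sorted(range(len(logits)), key=lambda i: log_probs[i], reverse=True)
    return {i: log_probs[i] for i in ranked[:top_n]}

print(top_n_log_probs([2.0, 1.0, 0.0], top_n=2))
```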

execute()

execute(model_inputs)

Executes the graph with the given inputs.

Parameters:

model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.

Returns:

ModelOutputs containing the pipeline’s output tensors.

Return type:

ModelOutputs

This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.

get_kv_params()

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Return minimal dummy KV cache params.

Mamba uses SSM state caching internally, not attention KV cache. These dummy params satisfy the PipelineModelWithKVCache interface with negligible memory overhead.

Return type:

KVCacheParamInterface

get_num_layers()

classmethod get_num_layers(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

int

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs to be passed to execute().

The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.

Return type:

MambaModelInputs

prepare_next_token_inputs()

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepares the secondary inputs to be passed to execute().

While prepare_initial_token_inputs manages the initial inputs, this function updates those inputs for each step in a multi-step execution pattern.

Return type:

MambaModelInputs
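The division of labor between the two prepare methods follows the usual multi-step pattern: build inputs once, then cheaply update them per step. The method names come from this page; the driver loop and the injected sample callable are simplified assumptions about how a serving loop uses them:

```python
def run_steps(model, replica_batches, num_steps, sample):
    # First step: build inputs from the raw batches
    # (claiming cache slots if needed).
    inputs = model.prepare_initial_token_inputs(replica_batches)
    for _ in range(num_steps):
        outputs = model.execute(inputs)
        next_tokens = sample(outputs)  # hypothetical sampler
        # Subsequent steps: update the previous inputs rather than
        # rebuilding them from scratch.
        inputs = model.prepare_next_token_inputs(next_tokens, inputs)
    return inputs
```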

release()

release(request_id)

Release SSM cache slot when a request completes.

Parameters:

request_id (RequestID)

Return type:

None