Python module
max.pipelines.architectures.mamba
Mamba state-space architecture for text generation.
MambaConfig
class max.pipelines.architectures.mamba.MambaConfig(*, hidden_size, intermediate_size, num_hidden_layers, vocab_size, max_seq_len, dtype, devices, d_state, dt_rank=None, conv_kernel=4, x_proj_dim=None, rms_norm_eps=None, use_bias=False, use_conv_bias=True, residual_in_fp32=True, tie_word_embeddings=True, return_logits=ReturnLogits.LAST_TOKEN, quant_config=None, use_subgraphs=True, data_parallel_degree=1, expand=2)
Bases: ArchConfigWithKVCache
Model configuration for Mamba graph construction/execution.
Parameters:
- hidden_size (int)
- intermediate_size (int)
- num_hidden_layers (int)
- vocab_size (int)
- max_seq_len (int)
- dtype (DType)
- devices (list[DeviceRef])
- d_state (int)
- dt_rank (int | str | None)
- conv_kernel (int)
- x_proj_dim (int | None)
- rms_norm_eps (float | None)
- use_bias (bool)
- use_conv_bias (bool)
- residual_in_fp32 (bool)
- tie_word_embeddings (bool)
- return_logits (ReturnLogits)
- quant_config (QuantConfig | None)
- use_subgraphs (bool)
- data_parallel_degree (int)
- expand (int)
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
int
conv_kernel
conv_kernel: int = 4
d_state
d_state: int
data_parallel_degree
data_parallel_degree: int = 1
devices
devices: list[DeviceRef]
dt_rank
dt_rank: int | str | None = None
dtype
dtype: DType
expand
expand: int = 2
finalize()
finalize(huggingface_config, state_dict, return_logits)
Set parameters that require introspecting the state dict.
Parameters:
- huggingface_config (AutoConfig)
- state_dict (dict[str, WeightData])
- return_logits (ReturnLogits)
Return type:
None
get_kv_params()
get_kv_params()
Return minimal dummy KV cache params.
Mamba uses SSM state caching, not an attention KV cache, but the pipeline interfaces require KVCacheParams for memory estimation and PagedKVCacheManager initialization. The tiny dimensions (1 head, 1 dim, 1 layer) make the allocation negligible.
Return type:
KVCacheParams
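The "negligible allocation" claim above is easy to verify with back-of-the-envelope arithmetic. The helper below is a hypothetical sketch (not part of the MAX API) of the standard per-token KV cache size formula, comparing the dummy dimensions against a typical 7B-class transformer:

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int,
                             head_dim: int, dtype_bytes: int) -> int:
    """Approximate attention KV cache bytes per token:
    2 (K and V) * layers * heads * head_dim * bytes per element."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

# Dummy params used to satisfy the interface: 1 layer, 1 head, 1 dim, bf16.
dummy = kv_cache_bytes_per_token(1, 1, 1, 2)      # 4 bytes per token
# A typical 7B-class transformer for comparison (assumed dims):
real = kv_cache_bytes_per_token(32, 32, 128, 2)   # 524288 bytes per token
```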
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
int
get_num_layers()
static get_num_layers(huggingface_config)
Parameters:
- huggingface_config (AutoConfig)
Return type:
int
get_ssm_cache_params()
get_ssm_cache_params()
Return type:
SSMStateCacheParams
help()
static help()
hidden_size
hidden_size: int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
MambaConfig
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
MambaConfig
intermediate_size
intermediate_size: int
max_seq_len
max_seq_len: int
num_hidden_layers
num_hidden_layers: int
quant_config
quant_config: QuantConfig | None = None
residual_in_fp32
residual_in_fp32: bool = True
return_logits
return_logits: ReturnLogits = 'last_token'
rms_norm_eps
rms_norm_eps: float | None = None
tie_word_embeddings
tie_word_embeddings: bool = True
use_bias
use_bias: bool = False
use_conv_bias
use_conv_bias: bool = True
use_subgraphs
use_subgraphs: bool = True
vocab_size
vocab_size: int
x_proj_dim
x_proj_dim: int | None = None
MambaModel
class max.pipelines.architectures.mamba.MambaModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: PipelineModelWithKVCache[TextContext]
Mamba pipeline model with incremental SSM state caching.
Uses separate compiled prefill and step models. The prefill model processes the full prompt and extracts per-layer conv/ssm states. The step model processes single new tokens using cached states.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:
- pipeline_config (PipelineConfig) – Configuration for the pipeline.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
int
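The behavior of the upper_bounded_default helper used in the example above can be inferred from how the example calls it and from the error it surfaces. The sketch below is an assumption, not the real implementation (which lives in the MAX pipelines package): fall back to the bound when no default is configured, and reject a default that exceeds it.

```python
def upper_bounded_default(upper_bound: int, default: int | None) -> int:
    """Plausible sketch of upper_bounded_default (an assumption; the real
    implementation lives in the MAX pipelines package)."""
    if default is None:
        # No user-provided value: fall back to the model's upper bound.
        return upper_bound
    if default > upper_bound:
        raise ValueError(
            f"default ({default}) exceeds upper bound ({upper_bound})"
        )
    return default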
compute_log_probabilities()
compute_log_probabilities(session, model_inputs, model_outputs, next_tokens, batch_top_n, batch_echo)
Optional method that can be overridden to compute log probabilities.
Parameters:
- session (InferenceSession) – Inference session to compute log probabilities within.
- model_inputs (ModelInputs) – Inputs to the model returned by prepare_*_token_inputs().
- model_outputs (ModelOutputs) – Outputs returned by execute().
- next_tokens (Buffer) – Sampled tokens. Should have shape=[batch size].
- batch_top_n (list[int]) – Number of top log probabilities to return per input in the batch. For any element where top_n == 0, the LogProbabilities is skipped.
- batch_echo (list[bool]) – Whether to include input tokens in the returned log probabilities.
Returns:
List of log probabilities.
Return type:
list[LogProbabilities | None]
execute()
execute(model_inputs)
Executes the graph with the given inputs.
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
Parameters:
- model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline's output tensors.
Return type:
ModelOutputs
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Return minimal dummy KV cache params.
Mamba uses SSM state caching internally, not an attention KV cache. These dummy params satisfy the PipelineModelWithKVCache interface with negligible memory overhead.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
KVCacheParams
get_num_layers()
classmethod get_num_layers(huggingface_config)
Parameters:
- huggingface_config (AutoConfig)
Return type:
int
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs to be passed to execute().
The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type:
MambaModelInputs
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs is responsible for managing the initial inputs, this function is responsible for updating the inputs for each step in a multi-step execution pattern.
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
MambaModelInputs
release()
release(request_id)
Release the SSM cache slot when a request completes.
Parameters:
- request_id (RequestID)
Return type:
None
MambaModelInputs
class max.pipelines.architectures.mamba.MambaModelInputs(tokens, input_row_offsets, return_n_logits, is_prefill=True, layer_states=None, request_ids=None)
Bases: ModelInputs
Inputs for the Mamba pipeline model.
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- return_n_logits (Buffer)
- is_prefill (bool)
- layer_states
- request_ids
input_row_offsets
input_row_offsets: Buffer
is_prefill
is_prefill: bool
layer_states
request_ids
return_n_logits
return_n_logits: Buffer
tokens
tokens: Buffer
SSMStateCacheParams
class max.pipelines.architectures.mamba.SSMStateCacheParams(num_layers, intermediate_size, d_state, conv_kernel, dtype)
Bases: object
Parameters for SSM state cache memory estimation.
The SSM cache is fixed-size per batch element (no sequence-length scaling):
- conv_state: num_layers * intermediate_size * conv_kernel * dtype_bytes
- ssm_state: num_layers * intermediate_size * d_state * dtype_bytes
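The two formulas above can be applied directly. The sketch below is a hypothetical helper (not part of the MAX API) that sums them, evaluated with assumed mamba-130m-like dimensions for illustration:

```python
def ssm_cache_bytes_per_element(num_layers: int, intermediate_size: int,
                                d_state: int, conv_kernel: int,
                                dtype_bytes: int) -> int:
    """Per-batch-element SSM cache size, using the two formulas above."""
    conv_state = num_layers * intermediate_size * conv_kernel * dtype_bytes
    ssm_state = num_layers * intermediate_size * d_state * dtype_bytes
    return conv_state + ssm_state

# Assumed mamba-130m-like dims: 24 layers, intermediate_size 1536,
# d_state 16, conv_kernel 4, bfloat16 (2 bytes) -> about 1.4 MiB per element,
# regardless of how long the sequence grows.
```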
compute_max_seq_len_fitting_in_cache()
compute_max_seq_len_fitting_in_cache(cache_memory)
The SSM cache doesn't scale with sequence length, so it imposes no constraint.
conv_kernel
conv_kernel: int
d_state
d_state: int
dtype
dtype: DType
estimated_memory_size()
estimated_memory_size(available_cache_memory, max_batch_size, max_seq_len)
Total SSM cache bytes. Independent of sequence length.
intermediate_size
intermediate_size: int
num_layers
num_layers: int
per_element_bytes
property per_element_bytes: int