Python module

max.pipelines.architectures.mistral

Mistral transformer architecture for text generation.

MistralConfig

class max.pipelines.architectures.mistral.MistralConfig(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, head_dim, vocab_size, rope_theta, max_seq_len, rms_norm_eps, feed_forward_length, dtype, kv_params, attention_multiplier, devices, return_logits=ReturnLogits.LAST_TOKEN)

Bases: ArchConfigWithPermissiveMaxSeqLen, ArchConfigWithStoredKVParams, ArchConfigWithKVCache

Configuration for Mistral models.
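
Several of the fields listed below determine the attention layout and the KV cache footprint. As a rough illustration, the following sketch uses a hypothetical stand-in dataclass (not the real MistralConfig, which also requires dtype, kv_params, devices, and the remaining fields) to compute the grouped-query ratio and an approximate KV cache size per token. The example values follow the published Mistral-7B configuration, and the 2-byte element size assumes bfloat16; both are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionShape:
    """Hypothetical stand-in holding a subset of the MistralConfig fields."""

    num_attention_heads: int
    num_key_value_heads: int
    num_hidden_layers: int
    head_dim: int


def queries_per_kv_head(shape: AttentionShape) -> int:
    """Number of query heads that share each key/value head."""
    if shape.num_attention_heads % shape.num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    return shape.num_attention_heads // shape.num_key_value_heads


def kv_cache_bytes_per_token(shape: AttentionShape, bytes_per_element: int) -> int:
    """Approximate KV cache bytes per token: keys plus values across all layers."""
    return (
        2
        * shape.num_hidden_layers
        * shape.num_key_value_heads
        * shape.head_dim
        * bytes_per_element
    )


# Values assumed from the published Mistral-7B configuration.
shape = AttentionShape(
    num_attention_heads=32,
    num_key_value_heads=8,
    num_hidden_layers=32,
    head_dim=128,
)
group_size = queries_per_kv_head(shape)
cache_bytes = kv_cache_bytes_per_token(shape, bytes_per_element=2)  # assumes bfloat16
```

With fewer key/value heads than query heads, the cache scales with num_key_value_heads rather than num_attention_heads, which is why both values are carried on the config.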

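Parameters:

  • hidden_size (int)
  • num_attention_heads (int)
  • num_key_value_heads (int)
  • num_hidden_layers (int)
  • head_dim (int)
  • vocab_size (int)
  • rope_theta (float)
  • max_seq_len (int)
  • rms_norm_eps (float)
  • feed_forward_length (int)
  • dtype (DType)
  • kv_params (KVCacheParams)
  • attention_multiplier (float)
  • devices (list[DeviceRef])
  • return_logits (ReturnLogits)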

attention_multiplier

attention_multiplier: float

devices

devices: list[DeviceRef]

dtype

dtype: DType

feed_forward_length

feed_forward_length: int

get_max_seq_len()

get_max_seq_len()

Returns the resolved maximum sequence length stored on the config.

Return type:

int

head_dim

head_dim: int

hidden_size

hidden_size: int

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initializes a MistralConfig instance from pipeline configuration.

This method creates a config instance with all fields that can be determined from the pipeline configuration.

Parameters:

  • pipeline_config
  • model_config (optional; defaults to None)

Returns:

An initialized MistralConfig instance.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config)

Parameters:

  • pipeline_config
  • huggingface_config

Return type:

Self

kv_params

kv_params: KVCacheParams

max_seq_len

max_seq_len: int

num_attention_heads

num_attention_heads: int

num_hidden_layers

num_hidden_layers: int

num_key_value_heads

num_key_value_heads: int

return_logits

return_logits: ReturnLogits = 'last_token'

Whether to return logits for the last token only, for all tokens, or for a variable number of tokens.
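
To make the three modes concrete, the sketch below picks which token positions of a flattened batch would have logits returned. It is an illustration only: the LogitsMode enum and the select_logit_positions helper are hypothetical stand-ins, the mode names other than last_token are assumed rather than taken from the ReturnLogits definition, and the "variable" mode is assumed to mean the last n tokens of each sequence. The row_offsets list marks where each sequence starts and ends in the flattened token buffer.

```python
from enum import Enum


class LogitsMode(Enum):
    """Hypothetical stand-in for ReturnLogits; names are assumptions."""

    LAST_TOKEN = "last_token"
    VARIABLE = "variable"
    ALL = "all"


def select_logit_positions(row_offsets, mode, n=1):
    """Return, per sequence, the flat token indices whose logits are kept."""
    selected = []
    for start, end in zip(row_offsets[:-1], row_offsets[1:]):
        if mode is LogitsMode.LAST_TOKEN:
            selected.append([end - 1])
        elif mode is LogitsMode.VARIABLE:
            # Keep at most the last n positions of this sequence.
            selected.append(list(range(max(start, end - n), end)))
        else:
            selected.append(list(range(start, end)))
    return selected


offsets = [0, 3, 5]  # two sequences, of lengths 3 and 2
last = select_logit_positions(offsets, LogitsMode.LAST_TOKEN)
variable = select_logit_positions(offsets, LogitsMode.VARIABLE, n=2)
everything = select_logit_positions(offsets, LogitsMode.ALL)
```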

rms_norm_eps

rms_norm_eps: float

rope_theta

rope_theta: float

vocab_size

vocab_size: int

MistralInputs

class max.pipelines.architectures.mistral.MistralInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

A class representing inputs for the Mistral model.

This class encapsulates the input tensors required to execute the Mistral model:

  • tokens: A tensor containing the input token IDs.
  • input_row_offsets: A tensor containing the offsets for each row in the ragged input sequence.
  • return_n_logits: A tensor containing the number of expected token logits.
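
As a minimal sketch of the ragged layout described above, the following plain-Python helper flattens a batch of variable-length sequences into one token list plus row offsets. It assumes the common convention that the offsets have one more entry than the batch has sequences, so that sequence i occupies tokens[offsets[i]:offsets[i + 1]]. The helper name is hypothetical, and the real inputs are Buffer objects rather than Python lists.

```python
from itertools import accumulate


def flatten_ragged(sequences):
    """Flatten variable-length token sequences into (tokens, row_offsets)."""
    tokens = [tok for seq in sequences for tok in seq]
    # Prefix sums of the sequence lengths, with a leading 0.
    row_offsets = [0, *accumulate(len(seq) for seq in sequences)]
    return tokens, row_offsets


batch = [[1, 15, 27], [1, 42], [1, 7, 8, 9]]
tokens, row_offsets = flatten_ragged(batch)

# Recover the second sequence from the flat layout.
second = tokens[row_offsets[1]:row_offsets[2]]
```

Storing offsets instead of padding each row to a common length avoids spending compute on padding tokens when sequence lengths differ.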

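Parameters:

  • tokens (Buffer)
  • input_row_offsets (Buffer)
  • signal_buffers (list[Buffer])
  • return_n_logits (Buffer)
  • kv_cache_inputs (keyword-only; defaults to None)
  • lora_ids (keyword-only; defaults to None)
  • lora_ranks (keyword-only; defaults to None)
  • hidden_states (keyword-only; defaults to None)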

input_row_offsets

input_row_offsets: Buffer

return_n_logits

return_n_logits: Buffer

signal_buffers

signal_buffers: list[Buffer]

Device buffers used for synchronization in communication collectives.

tokens

tokens: Buffer

MistralModel

class max.pipelines.architectures.mistral.MistralModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)

Bases: PipelineModelWithKVCache[TextContext]

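Parameters:

  • pipeline_config
  • session
  • devices
  • kv_cache_config
  • weights
  • adapter (optional; defaults to None)
  • return_logits (defaults to ReturnLogits.LAST_TOKEN)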

calculate_max_seq_len()

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

Returns the maximum sequence length, with max_length bounded by max_position_embeddings (the config is permissive).
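
One plausible reading of this bound, sketched in plain Python: when no max_length is requested the model limit applies, and otherwise the smaller of the two values is used. The function name and the clamping behavior are assumptions made for illustration; the actual method may validate or resolve the values differently.

```python
def bounded_max_seq_len(max_length, max_position_embeddings):
    """Hypothetical helper: cap a requested length at the model's position limit."""
    if max_length is None:
        # No explicit request: fall back to the model's own limit.
        return max_position_embeddings
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    return min(max_length, max_position_embeddings)


default_len = bounded_max_seq_len(None, 32768)
short_len = bounded_max_seq_len(4096, 32768)
capped_len = bounded_max_seq_len(65536, 32768)
```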

Parameters:

  • pipeline_config
  • huggingface_config

Return type:

int

execute()

execute(model_inputs)

Runs the graph.

Parameters:

model_inputs (ModelInputs)

Return type:

ModelOutputs

graph_inputs()

graph_inputs()

Return type:

tuple[TensorType | BufferType, ...]

load_model()

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model

model: Model

Compiled and initialized model ready for inference.

model_config_cls

model_config_cls

alias of MistralConfig

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepares the initial inputs to be passed to execute().

The inputs and functionality can vary per model. For example, model inputs could include encoded tensors, unique IDs per tensor when using a KV cache manager, and kv_cache_inputs (or None if the model does not use KV cache). This method typically batches encoded tensors, claims a KV cache slot if needed, and returns the inputs and caches.
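
The call order implied above can be sketched as a small helper that works with any object exposing the two methods documented on this page. The helper name run_initial_step is hypothetical, and the stub class exists only so the snippet runs without a compiled model; real calls need the batches and KV cache inputs produced by the pipeline.

```python
def run_initial_step(model, replica_batches, kv_cache_inputs=None, return_n_logits=1):
    """Prepare the first-step inputs, then run the graph on them."""
    inputs = model.prepare_initial_token_inputs(
        replica_batches,
        kv_cache_inputs=kv_cache_inputs,
        return_n_logits=return_n_logits,
    )
    return model.execute(inputs)


class _StubModel:
    """Stand-in used only to exercise the helper; not part of MAX."""

    def prepare_initial_token_inputs(
        self, replica_batches, kv_cache_inputs=None, return_n_logits=1
    ):
        return {"batches": replica_batches, "kv": kv_cache_inputs, "n": return_n_logits}

    def execute(self, model_inputs):
        return ("executed", model_inputs["n"])


result = run_initial_step(_StubModel(), replica_batches=[["ctx"]], return_n_logits=2)
```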

Parameters:

  • replica_batches
  • kv_cache_inputs (optional; defaults to None)
  • return_n_logits (defaults to 1)

Return type:

MistralInputs