Python module
max.pipelines.architectures.mistral
Mistral transformer architecture for text generation.
MistralConfig
class max.pipelines.architectures.mistral.MistralConfig(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, head_dim, vocab_size, rope_theta, max_seq_len, rms_norm_eps, feed_forward_length, dtype, kv_params, attention_multiplier, devices, return_logits=ReturnLogits.LAST_TOKEN)
Bases: ArchConfigWithKVCache
Configuration for Mistral models.
Parameters:
- hidden_size (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- num_hidden_layers (int)
- head_dim (int)
- vocab_size (int)
- rope_theta (float)
- max_seq_len (int)
- rms_norm_eps (float)
- feed_forward_length (int)
- dtype (DType)
- kv_params (KVCacheParams)
- attention_multiplier (float)
- devices (list[DeviceRef])
- return_logits (ReturnLogits)
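For orientation, the fields above can be pictured with a plain dataclass stand-in populated with Mistral-7B-style values (a sketch, not the real MistralConfig; the concrete numbers are assumptions based on the public Mistral-7B Hugging Face config):

```python
from dataclasses import dataclass

# Illustrative stand-in for MistralConfig, not the MAX class itself.
@dataclass
class MistralConfigSketch:
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 8   # grouped-query attention: fewer KV heads
    num_hidden_layers: int = 32
    head_dim: int = 128
    vocab_size: int = 32000
    rope_theta: float = 10000.0
    max_seq_len: int = 32768
    rms_norm_eps: float = 1e-5
    feed_forward_length: int = 14336

cfg = MistralConfigSketch()
# hidden_size factors into per-head dimensions:
assert cfg.hidden_size == cfg.num_attention_heads * cfg.head_dim
# GQA: each KV head serves a group of query heads.
print(cfg.num_attention_heads // cfg.num_key_value_heads)  # 4 query heads per KV head
```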
attention_multiplier
attention_multiplier: float
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the maximum sequence length for the model.
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
Return type:
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
devices
devices: list[DeviceRef]
dtype
dtype: DType
feed_forward_length
feed_forward_length: int
get_head_dim()
static get_head_dim(huggingface_config)
Parameters:
huggingface_config (AutoConfig)
Return type:
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
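The KV cache parameters returned here ultimately drive cache sizing. As a back-of-envelope sketch (plain Python, not a MAX API; Mistral-7B-style values assumed), each layer stores one key and one value vector per KV head for every cached token:

```python
# Rough KV cache footprint per token (illustrative arithmetic only).
def kv_bytes_per_token(num_layers: int, num_kv_heads: int,
                       head_dim: int, dtype_bytes: int = 2) -> int:
    # factor of 2: one key vector and one value vector per layer and KV head
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes

# Assumed Mistral-7B-style values with a 2-byte dtype (e.g. bfloat16):
per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
print(per_token)  # 131072 bytes, i.e. 128 KiB per cached token
```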
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
get_num_layers()
static get_num_layers(huggingface_config)
Parameters:
huggingface_config (AutoConfig)
Return type:
head_dim
head_dim: int
hidden_size
hidden_size: int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes a MistralConfig instance from pipeline configuration.
This method creates a config instance with all fields that can be determined from the pipeline configuration.
Parameters:
- pipeline_config (PipelineConfig) β The MAX Engine pipeline configuration.
- model_config (MAXModelConfig | None)
Returns:
An initialized MistralConfig instance.
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
Return type:
kv_params
kv_params: KVCacheParams
max_seq_len
max_seq_len: int
num_attention_heads
num_attention_heads: int
num_hidden_layers
num_hidden_layers: int
num_key_value_heads
num_key_value_heads: int
return_logits
return_logits: ReturnLogits = 'last_token'
Whether to return logits for the last token only, for all tokens, or for a variable number of tokens.
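As an illustrative sketch of the three modes in plain Python (not the MAX ReturnLogits API; the mode names and selection logic here are assumptions):

```python
# Given per-token logits for one sequence, select what to return.
logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # 3 positions x 2-entry vocab

def select_logits(logits, mode, n=1):
    if mode == "last_token":
        return logits[-1:]    # only the final position (cheapest)
    if mode == "all":
        return logits         # every position
    if mode == "variable":
        return logits[-n:]    # the last n positions
    raise ValueError(mode)

print(select_logits(logits, "last_token"))  # [[0.3, 0.7]]
print(len(select_logits(logits, "variable", n=2)))  # 2
```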
rms_norm_eps
rms_norm_eps: float
rope_theta
rope_theta: float
vocab_size
vocab_size: int
MistralInputs
class max.pipelines.architectures.mistral.MistralInputs(tokens, input_row_offsets, signal_buffers, return_n_logits, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
A class representing inputs for the Mistral model.
This class encapsulates the input tensors required for the Mistral model execution:
- tokens: A tensor containing the input token IDs
- input_row_offsets: A tensor containing the offsets for each row in the ragged input sequence
- return_n_logits: A tensor containing the number of expected token logits.
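The ragged layout that tokens and input_row_offsets describe can be sketched in plain Python (illustrative only, using lists in place of device buffers): variable-length sequences are concatenated into one flat token buffer, and the offsets mark where each sequence begins and ends.

```python
# Three variable-length sequences in one batch:
seqs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]

tokens = [t for seq in seqs for t in seq]  # flat, ragged token buffer
input_row_offsets = [0]
for seq in seqs:
    input_row_offsets.append(input_row_offsets[-1] + len(seq))

print(tokens)             # [5, 6, 7, 8, 9, 10, 11, 12, 13]
print(input_row_offsets)  # [0, 3, 5, 9]

# Sequence i is recovered as tokens[offsets[i]:offsets[i+1]]:
assert tokens[input_row_offsets[1]:input_row_offsets[2]] == [8, 9]
```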
Parameters:
input_row_offsets
input_row_offsets: Buffer
return_n_logits
return_n_logits: Buffer
signal_buffers
Device buffers used for synchronization in communication collectives.
tokens
tokens: Buffer
MistralModel
class max.pipelines.architectures.mistral.MistralModel(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN)
Bases: PipelineModelWithKVCache[TextContext]
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:
```python
class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e
```

Parameters:
- pipeline_config (PipelineConfig) β Configuration for the pipeline.
- huggingface_config (AutoConfig) β Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
int
execute()
execute(model_inputs)
Runs the graph.
Parameters:
model_inputs (ModelInputs)
Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Returns the KV cache params for the pipeline model.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
graph_inputs()
graph_inputs()
Return type:
tuple[TensorType | BufferType, …]
load_model()
load_model(session)
Parameters:
session (InferenceSession)
Return type:
model
model: Model
Compiled and initialized model ready for inference.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs to be passed to execute().
The inputs and functionality can vary per model. For example, model
inputs could include encoded tensors, unique IDs per tensor when using
a KV cache manager, and kv_cache_inputs (or None if the model does
not use KV cache). This method typically batches encoded tensors,
claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to execute().
While prepare_initial_token_inputs manages the initial inputs, this method updates the inputs for each subsequent step in a multi-step execution pattern.
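The division of labor between the two prepare methods can be sketched with stub functions (illustrative only; the stub "model" just returns the current sequence length rather than real logits, and dicts stand in for ModelInputs):

```python
# Stub for prepare_initial_token_inputs: build inputs once from the prompt.
def prepare_initial_token_inputs(prompt_tokens):
    return {"tokens": list(prompt_tokens), "step": 0}

# Stub "model": the next token is just the current sequence length here.
def execute(model_inputs):
    return len(model_inputs["tokens"])

# Stub for prepare_next_token_inputs: cheaply extend the previous inputs
# with the newly sampled token instead of rebuilding them from scratch.
def prepare_next_token_inputs(next_token, prev_model_inputs):
    prev_model_inputs["tokens"].append(next_token)
    prev_model_inputs["step"] += 1
    return prev_model_inputs

inputs = prepare_initial_token_inputs([101, 102])
generated = []
for _ in range(3):  # three decode steps
    next_token = execute(inputs)
    generated.append(next_token)
    inputs = prepare_next_token_inputs(next_token, inputs)

print(generated)  # [2, 3, 4]
```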
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type: