Python module
max.pipelines.architectures.llama3
Llama 3 transformer architecture for text generation.
Llama3Config
class max.pipelines.architectures.llama3.Llama3Config(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, interleaved_rope_weights, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, return_logits=ReturnLogits.LAST_TOKEN, norm_method='rms_norm', norm_dtype=None, attention_bias=False, rms_norm_eps=None, tie_word_embeddings=False, stacked_mlp=False, stacked_qkv=False, attention_multiplier, embedding_multiplier, residual_multiplier, devices, clip_qkv, quant_config=None, lora_config=None, longrope_scaling_params=None, logits_scaling=1.0, return_hidden_states=ReturnHiddenStates.NONE, use_subgraphs=True, data_parallel_degree=1)
Bases: ArchConfigWithKVCache
Model configuration for Llama3 graph construction/execution.
Parameters:
- hidden_size (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- num_hidden_layers (int)
- rope_theta (float)
- rope_scaling_params (Llama3RopeScalingParams | None)
- max_seq_len (int)
- intermediate_size (int)
- interleaved_rope_weights (bool)
- vocab_size (int)
- dtype (DType)
- model_quantization_encoding (QuantizationEncoding | None)
- quantization_config (QuantizationConfig | None)
- kv_params (KVCacheParams)
- return_logits (ReturnLogits)
- norm_method (Literal['rms_norm', 'layer_norm'])
- norm_dtype (DType | None)
- attention_bias (bool)
- rms_norm_eps (float | None)
- tie_word_embeddings (bool)
- stacked_mlp (bool)
- stacked_qkv (bool)
- attention_multiplier (float)
- embedding_multiplier (float)
- residual_multiplier (float)
- devices (list[DeviceRef])
- clip_qkv (float | None)
- quant_config (QuantConfig | None)
- lora_config (LoRAConfig | None)
- longrope_scaling_params (LongRoPEScalingParams | None)
- logits_scaling (float)
- return_hidden_states (ReturnHiddenStates)
- use_subgraphs (bool)
- data_parallel_degree (int)
attention_bias
attention_bias: bool = False
attention_multiplier
attention_multiplier: float
calculate_attention_multiplier()
static calculate_attention_multiplier(huggingface_config)
Returns the attention multiplier, a scalar that scales the attention scores to control their variance.
The value is read from the Hugging Face config. If the config does not set it, it defaults to the square root of 1.0 divided by the head dimension, that is, sqrt(1.0 / head_dim).
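The fallback can be sketched as follows. This is a minimal illustration, not the library's code. It assumes the Hugging Face config exposes attention_multiplier, hidden_size, num_attention_heads, and optionally head_dim, and it uses a SimpleNamespace in place of a real AutoConfig. Both helper names are hypothetical.

```python
import math
from types import SimpleNamespace


def head_dim_from_config(hf_config) -> int:
    # Assumption: prefer an explicit head_dim if the config provides one,
    # otherwise derive it as hidden_size // num_attention_heads.
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is None:
        head_dim = hf_config.hidden_size // hf_config.num_attention_heads
    return head_dim


def attention_multiplier_from_config(hf_config) -> float:
    # Assumption: the config field is named "attention_multiplier".
    multiplier = getattr(hf_config, "attention_multiplier", None)
    if multiplier is not None:
        return float(multiplier)
    # Documented fallback: sqrt(1.0 / head_dim).
    return math.sqrt(1.0 / head_dim_from_config(hf_config))


# A stand-in config with 4096 / 32 = 128-dimensional heads.
cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
print(attention_multiplier_from_config(cfg))  # ≈ 0.0884 (= 1 / sqrt(128))
```

With 128-dimensional heads this fallback equals the familiar 1 / sqrt(d) scaling used in scaled dot-product attention.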
Parameters:
huggingface_config (AutoConfig)
Return type:
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
clip_qkv
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
data_parallel_degree
data_parallel_degree: int = 1
devices
dtype
dtype: DType
embedding_multiplier
embedding_multiplier: float
finalize()
finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE, norm_method='rms_norm', attention_bias=False)
Defines parameters that cannot be determined from the pipeline config alone.
Parameters:
- huggingface_config (AutoConfig)
- state_dict (dict[str, WeightData])
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
- norm_method (Literal['rms_norm', 'layer_norm'])
- attention_bias (bool)
Return type:
None
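As a hypothetical illustration of values that depend on the checkpoint rather than the pipeline config, fields such as tie_word_embeddings, stacked_mlp, and stacked_qkv could be inferred from the weight names in state_dict. The key names and rules below are assumptions for illustration, not this module's actual logic, and infer_weight_layout is a hypothetical helper.

```python
def infer_weight_layout(state_dict: dict, hf_tie_word_embeddings: bool = False) -> dict:
    """Hypothetical: infer layout flags from checkpoint weight names."""
    return {
        # Assumption: with no separate LM head weight, the embedding matrix
        # is reused (tied) as the output projection.
        "tie_word_embeddings": hf_tie_word_embeddings
        or "lm_head.weight" not in state_dict,
        # Assumption: a fused gate/up projection indicates a stacked MLP.
        "stacked_mlp": "model.layers.0.mlp.gate_up_proj.weight" in state_dict,
        # Assumption: a fused Q/K/V projection indicates stacked attention weights.
        "stacked_qkv": "model.layers.0.self_attn.qkv_proj.weight" in state_dict,
    }


# The dict values stand in for WeightData; only the key names matter here.
print(infer_weight_layout({"lm_head.weight": None, "model.layers.0.mlp.gate_up_proj.weight": None}))
```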
get_head_dim()
static get_head_dim(huggingface_config)
Parameters:
huggingface_config (AutoConfig)
Return type:
get_kv_params()
get_kv_params()
Returns the KV cache parameters to use when running the model.
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
get_num_layers()
static get_num_layers(huggingface_config)
Parameters:
huggingface_config (AutoConfig)
Return type:
hidden_size
hidden_size: int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initializes the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
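The model_config fallback described above can be sketched with stand-in dataclasses. ModelConfigStub, PipelineConfigStub, and select_model_config are hypothetical; a real call would pass a PipelineConfig (and optionally a MAXModelConfig) to Llama3Config.initialize directly.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfigStub:
    # Hypothetical stand-in for MAXModelConfig.
    model_path: str


@dataclass
class PipelineConfigStub:
    # Hypothetical stand-in for PipelineConfig with a target and a draft model.
    model: ModelConfigStub
    draft_model: Optional[ModelConfigStub] = None


def select_model_config(pipeline_config, model_config=None):
    # Mirrors the documented default: fall back to pipeline_config.model.
    return model_config if model_config is not None else pipeline_config.model


pc = PipelineConfigStub(
    model=ModelConfigStub("target-model"),
    draft_model=ModelConfigStub("draft-model"),
)
print(select_model_config(pc).model_path)                  # target-model
print(select_model_config(pc, pc.draft_model).model_path)  # draft-model
```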
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
interleaved_rope_weights
interleaved_rope_weights: bool
intermediate_size
intermediate_size: int
kv_params
kv_params: KVCacheParams
logits_scaling
logits_scaling: float = 1.0
longrope_scaling_params
longrope_scaling_params: LongRoPEScalingParams | None = None
lora_config
lora_config: LoRAConfig | None = None
max_seq_len
max_seq_len: int
model_quantization_encoding
model_quantization_encoding: QuantizationEncoding | None
norm_dtype
norm_method
norm_method: Literal['rms_norm', 'layer_norm'] = 'rms_norm'
num_attention_heads
num_attention_heads: int
num_hidden_layers
num_hidden_layers: int
num_key_value_heads
num_key_value_heads: int
quant_config
quant_config: QuantConfig | None = None
quantization_config
quantization_config: QuantizationConfig | None
residual_multiplier
residual_multiplier: float
return_hidden_states
return_hidden_states: ReturnHiddenStates = 'none'
return_logits
return_logits: ReturnLogits = 'last_token'
rms_norm_eps
rope_scaling_params
rope_scaling_params: Llama3RopeScalingParams | None
rope_theta
rope_theta: float
stacked_mlp
stacked_mlp: bool = False
stacked_qkv
stacked_qkv: bool = False
tie_word_embeddings
tie_word_embeddings: bool = False
use_subgraphs
use_subgraphs: bool = True
vocab_size
vocab_size: int
Llama3Inputs
class max.pipelines.architectures.llama3.Llama3Inputs(tokens, input_row_offsets, signal_buffers, return_n_logits, lora_grouped_offsets=None, num_active_loras=None, lora_end_idx=None, batch_seq_len=None, lora_ids_kv=None, lora_grouped_offsets_kv=None, data_parallel_splits=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
Inputs for the Llama3 model.
Encapsulates the input tensors required to execute the Llama3 model.
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- signal_buffers (list[Buffer])
- return_n_logits (Buffer)
- lora_grouped_offsets (Buffer | None)
- num_active_loras (Buffer | None)
- lora_end_idx (Buffer | None)
- batch_seq_len (Buffer | None)
- lora_ids_kv (Buffer | None)
- lora_grouped_offsets_kv (Buffer | None)
- data_parallel_splits (Buffer | Sequence[Sequence[int]] | None)
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
batch_seq_len
buffers
Returns positional Buffer inputs for model ABI calls.
data_parallel_splits
data_parallel_splits: Buffer | Sequence[Sequence[int]] | None = None
Data parallel splits, given either as a tensor or as a nested sequence of integers.
input_row_offsets
input_row_offsets: Buffer
Tensor containing the offsets for each row in the ragged input sequence.
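The ragged layout behind input_row_offsets can be sketched in plain Python. This assumes the common convention that the offsets have batch_size + 1 entries, start at 0, and end at the total token count; pack_ragged is a hypothetical helper, not part of this module.

```python
from itertools import accumulate


def pack_ragged(sequences):
    """Hypothetical: concatenate token lists and compute row offsets."""
    tokens = [tok for seq in sequences for tok in seq]
    # offsets[i] is where row i starts; offsets[-1] == len(tokens).
    offsets = [0, *accumulate(len(seq) for seq in sequences)]
    return tokens, offsets


tokens, offsets = pack_ragged([[101, 102, 103], [201], [301, 302]])
print(tokens)   # [101, 102, 103, 201, 301, 302]
print(offsets)  # [0, 3, 4, 6]

# Row i spans tokens[offsets[i]:offsets[i + 1]], so its last token sits at
# offsets[i + 1] - 1 (under this convention, the position whose logits
# ReturnLogits.LAST_TOKEN would keep).
last_token_idx = [end - 1 for end in offsets[1:]]
print(last_token_idx)  # [2, 3, 5]
```

Packing this way avoids padding: no compute is spent on filler tokens, at the cost of passing offsets alongside the flat token buffer.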
lora_end_idx
lora_grouped_offsets
lora_grouped_offsets_kv
lora_ids_kv
num_active_loras
return_n_logits
return_n_logits: Buffer
signal_buffers
Device buffers used for synchronization in communication collectives.
tokens
tokens: Buffer
Tensor containing the input token IDs.
Llama3Model
class max.pipelines.architectures.llama3.Llama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: LlamaModelBase
Llama 3 pipeline model implementation.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
config_class
config_class
alias of Llama3Config
norm_method
norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'
Type of normalization layer to use ('rms_norm' or 'layer_norm').
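The two norm_method options compute different normalizations. A minimal pure-Python sketch of each for a single vector follows; the eps defaults are illustrative assumptions (for RMSNorm the config's rms_norm_eps presumably supplies the real value), and the function names are hypothetical.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: divide by the root mean square of x; no mean subtraction, no bias.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def layer_norm(x, weight, bias, eps=1e-5):
    # LayerNorm: subtract the mean, divide by the standard deviation,
    # then apply a learned scale and shift.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    denom = math.sqrt(var + eps)
    return [w * (v - mean) / denom + b for v, w, b in zip(x, weight, bias)]


x = [1.0, 2.0, 3.0, 4.0]
print(rms_norm(x, weight=[1.0] * 4))
print(layer_norm(x, weight=[1.0] * 4, bias=[0.0] * 4))
```

RMSNorm skips the mean computation and the bias, which makes it slightly cheaper; Llama models use it by default, matching the 'rms_norm' default shown here.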
LlamaModelBase
class max.pipelines.architectures.llama3.LlamaModelBase(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: LogProbabilitiesMixin, PipelineModelWithKVCache[TextContext]
Base Llama pipeline model implementation.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
attention_bias
attention_bias: bool = False
Whether to use attention bias.
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:
class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        # Use the requested max_length if it fits within the model's max_seq_len.
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            # Re-raise with a message that names both conflicting values.
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e
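The example relies on upper_bounded_default, whose implementation is not shown on this page. The following is a hypothetical sketch of the behavior the example implies: a default above the bound raises ValueError. The None branch, which falls back to the upper bound when no max_length is requested, is an assumption, and the function name is illustrative only.

```python
from typing import Optional


def upper_bounded_default_sketch(upper_bound: int, default: Optional[int]) -> int:
    # Hypothetical re-implementation, inferred from the Mistral example.
    if default is None:
        # Assumption: no user override means "use the model's own limit".
        return upper_bound
    if default > upper_bound:
        # This is the ValueError the example catches and re-raises.
        raise ValueError(f"default ({default}) exceeds upper_bound ({upper_bound})")
    return default


print(upper_bounded_default_sketch(upper_bound=32768, default=None))  # 32768
print(upper_bounded_default_sketch(upper_bound=32768, default=8192))  # 8192
```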
Parameters:
- pipeline_config (PipelineConfig) – Configuration for the pipeline.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
execute()
execute(model_inputs)
Executes the graph with the given inputs.
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
Parameters:
model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline’s output tensors.
Return type:
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Returns the KV cache params for the pipeline model.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
load_model()
load_model(session)
Parameters:
session (InferenceSession)
Return type:
model
model: Model
Compiled and initialized model ready for inference.
norm_method
norm_method: Literal['rms_norm'] | Literal['layer_norm']
Type of normalization layer to use ('rms_norm' or 'layer_norm').
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the inputs for the first pass in multistep execution.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the inputs for the next token in multistep execution. This method should avoid any device synchronization or copy operations.
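A hypothetical sketch of why this step can stay cheap, assuming plain autoregressive decoding where each sequence receives exactly one sampled token per step: the ragged row offsets become 0, 1, ..., batch_size regardless of the previous step's lengths, so they can be prepared once per batch size and reused. The helper names are illustrative only, and a Python tuple stands in for a device buffer.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def decode_row_offsets(batch_size: int) -> tuple:
    # One token per sequence means offsets are simply 0..batch_size.
    # Caching returns the same tuple for a given batch size, loosely
    # mirroring reuse of a preallocated buffer instead of rebuilding it.
    return tuple(range(batch_size + 1))


def next_step_inputs(next_tokens, prev_row_offsets):
    """Hypothetical: build decode-step inputs from sampled tokens."""
    batch_size = len(prev_row_offsets) - 1
    if len(next_tokens) != batch_size:
        raise ValueError("expected one new token per sequence")
    return list(next_tokens), decode_row_offsets(batch_size)


# Previous step: a prefill over three prompts of lengths 3, 1, and 2.
tokens, offsets = next_step_inputs([104, 202, 303], [0, 3, 4, 6])
print(tokens)   # [104, 202, 303]
print(offsets)  # (0, 1, 2, 3)
```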
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
state_dict
Weights to load into the model.