Python module
max.pipelines.architectures.llama3
Llama 3 transformer architecture for text generation.
Llama3Config
class max.pipelines.architectures.llama3.Llama3Config(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, interleaved_rope_weights, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, return_logits=ReturnLogits.LAST_TOKEN, norm_method='rms_norm', norm_dtype=None, attention_bias=False, rms_norm_eps=None, tie_word_embeddings=False, stacked_mlp=False, stacked_qkv=False, attention_multiplier, embedding_multiplier, residual_multiplier, devices, clip_qkv, quant_config=None, lora_config=None, longrope_scaling_params=None, logits_scaling=1.0, return_hidden_states=ReturnHiddenStates.NONE, use_subgraphs=True, data_parallel_degree=1)
Bases: ArchConfigWithKVCache
Model configuration for Llama3 graph construction/execution.
Parameters:
- hidden_size (int)
- num_attention_heads (int)
- num_key_value_heads (int)
- num_hidden_layers (int)
- rope_theta (float)
- rope_scaling_params (Llama3RopeScalingParams | None)
- max_seq_len (int)
- intermediate_size (int)
- interleaved_rope_weights (bool)
- vocab_size (int)
- dtype (DType)
- model_quantization_encoding (QuantizationEncoding | None)
- quantization_config (QuantizationConfig | None)
- kv_params (KVCacheParams)
- return_logits (ReturnLogits)
- norm_method (Literal['rms_norm', 'layer_norm'])
- norm_dtype (DType | None)
- attention_bias (bool)
- rms_norm_eps (float | None)
- tie_word_embeddings (bool)
- stacked_mlp (bool)
- stacked_qkv (bool)
- attention_multiplier (float)
- embedding_multiplier (float)
- residual_multiplier (float)
- devices (list[DeviceRef])
- clip_qkv (float | None)
- quant_config (QuantConfig | None)
- lora_config (LoRAConfig | None)
- longrope_scaling_params (LongRoPEScalingParams | None)
- logits_scaling (float)
- return_hidden_states (ReturnHiddenStates)
- use_subgraphs (bool)
- data_parallel_degree (int)
attention_bias
attention_bias: bool = False
attention_multiplier
attention_multiplier: float
calculate_attention_multiplier()
static calculate_attention_multiplier(huggingface_config)
The attention multiplier is a scalar that scales the attention scores, controlling their variance.
This function reads the attention multiplier from the Hugging Face config. If it is not set, it is calculated as the square root of 1.0 divided by the head dimension.
Parameters:
huggingface_config (AutoConfig)
Return type:
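As a hedged illustration of that fallback (assuming the usual Llama convention that the head dimension is hidden_size // num_attention_heads; the helper below is illustrative, not part of this API):

```python
import math

def fallback_attention_multiplier(hidden_size: int, num_attention_heads: int) -> float:
    # Standard Llama convention (an assumption here): the hidden size is
    # split evenly across attention heads, e.g. 4096 / 32 = 128 for Llama 3 8B.
    head_dim = hidden_size // num_attention_heads
    # Fallback described above: sqrt(1.0 / head_dim), i.e. 1 / sqrt(head_dim).
    return math.sqrt(1.0 / head_dim)

print(fallback_attention_multiplier(4096, 32))  # 1/sqrt(128) ≈ 0.0883883476
```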
calculate_max_seq_len()
static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
clip_qkv
construct_kv_params()
static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
data_parallel_degree
data_parallel_degree: int = 1
devices
dtype
dtype: DType
embedding_multiplier
embedding_multiplier: float
finalize()
finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE, norm_method='rms_norm', attention_bias=False)
Define parameters that can't be determined just from the pipeline config.
Parameters:
- huggingface_config (AutoConfig)
- state_dict (dict[str, WeightData])
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
- norm_method (Literal['rms_norm', 'layer_norm'])
- attention_bias (bool)
Return type:
None
get_head_dim()
static get_head_dim(huggingface_config)
Parameters:
huggingface_config (AutoConfig)
Return type:
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
get_num_layers()
static get_num_layers(huggingface_config)
Parameters:
huggingface_config (AutoConfig)
Return type:
hidden_size
hidden_size: int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
initialize_from_config()
classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)
Parameters:
- pipeline_config (PipelineConfig)
- huggingface_config (AutoConfig)
- model_config (MAXModelConfig | None)
Return type:
interleaved_rope_weights
interleaved_rope_weights: bool
intermediate_size
intermediate_size: int
kv_params
kv_params: KVCacheParams
logits_scaling
logits_scaling: float = 1.0
longrope_scaling_params
longrope_scaling_params: LongRoPEScalingParams | None = None
lora_config
lora_config: LoRAConfig | None = None
max_seq_len
max_seq_len: int
model_quantization_encoding
model_quantization_encoding: QuantizationEncoding | None
norm_dtype
norm_method
norm_method: Literal['rms_norm', 'layer_norm'] = 'rms_norm'
num_attention_heads
num_attention_heads: int
num_hidden_layers
num_hidden_layers: int
num_key_value_heads
num_key_value_heads: int
quant_config
quant_config: QuantConfig | None = None
quantization_config
quantization_config: QuantizationConfig | None
residual_multiplier
residual_multiplier: float
return_hidden_states
return_hidden_states: ReturnHiddenStates = 'none'
return_logits
return_logits: ReturnLogits = 'last_token'
rms_norm_eps
rope_scaling_params
rope_scaling_params: Llama3RopeScalingParams | None
rope_theta
rope_theta: float
stacked_mlp
stacked_mlp: bool = False
stacked_qkv
stacked_qkv: bool = False
tie_word_embeddings
tie_word_embeddings: bool = False
use_subgraphs
use_subgraphs: bool = True
vocab_size
vocab_size: int
Llama3Inputs
class max.pipelines.architectures.llama3.Llama3Inputs(tokens, input_row_offsets, signal_buffers, return_n_logits, lora_grouped_offsets=None, num_active_loras=None, lora_end_idx=None, batch_seq_len=None, lora_ids_kv=None, lora_grouped_offsets_kv=None, data_parallel_splits=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: ModelInputs
A class representing inputs for the Llama3 model.
This class encapsulates the input tensors required for Llama3 model execution.
Parameters:
- tokens (Buffer)
- input_row_offsets (Buffer)
- signal_buffers (list[Buffer])
- return_n_logits (Buffer)
- lora_grouped_offsets (Buffer | None)
- num_active_loras (Buffer | None)
- lora_end_idx (Buffer | None)
- batch_seq_len (Buffer | None)
- lora_ids_kv (Buffer | None)
- lora_grouped_offsets_kv (Buffer | None)
- data_parallel_splits (Buffer | Sequence[Sequence[int]] | None)
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- lora_ids (Buffer | None)
- lora_ranks (Buffer | None)
- hidden_states (Buffer | list[Buffer] | None)
batch_seq_len
buffers
Returns positional Buffer inputs for model ABI calls.
data_parallel_splits
data_parallel_splits: Buffer | Sequence[Sequence[int]] | None = None
Tensor containing the data parallel splits.
input_row_offsets
input_row_offsets: Buffer
Tensor containing the offsets for each row in the ragged input sequence.
lora_end_idx
lora_grouped_offsets
lora_grouped_offsets_kv
lora_ids_kv
num_active_loras
return_n_logits
return_n_logits: Buffer
signal_buffers
Device buffers used for synchronization in communication collectives.
tokens
tokens: Buffer
Tensor containing the input token IDs.
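The ragged layout implied by tokens and input_row_offsets can be sketched with plain lists (an illustration of the offset convention only; the actual fields are device Buffers, and the final-entry detail is an assumption from common ragged-tensor layouts): variable-length sequences are concatenated into one flat token array, and the offsets mark where each sequence begins, with a final entry for the total length.

```python
from itertools import accumulate

# Three variable-length sequences, flattened into one ragged batch
# (illustrative token IDs).
seqs = [[101, 7592, 2088], [101, 2129], [101, 2003, 2023, 2204]]

tokens = [t for seq in seqs for t in seq]
# One offset per sequence start, plus a final entry for the total length.
input_row_offsets = [0, *accumulate(len(seq) for seq in seqs)]

print(tokens)             # [101, 7592, 2088, 101, 2129, 101, 2003, 2023, 2204]
print(input_row_offsets)  # [0, 3, 5, 9]

# Sequence i is recovered as tokens[offsets[i]:offsets[i + 1]].
assert tokens[input_row_offsets[1]:input_row_offsets[2]] == [101, 2129]
```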
Llama3Model
class max.pipelines.architectures.llama3.Llama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: LlamaModelBase
Llama 3 pipeline model implementation.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
config_class
config_class
alias of Llama3Config
norm_method
norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'
Normalization layer.
LlamaModelBase
class max.pipelines.architectures.llama3.LlamaModelBase(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)
Bases: LogProbabilitiesMixin, PipelineModelWithKVCache[TextContext]
Base Llama pipeline model implementation.
Parameters:
- pipeline_config (PipelineConfig) – The configuration for this pipeline.
- session (InferenceSession) – The container for the runtime for this model.
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
attention_bias
attention_bias: bool = False
Whether to use attention bias.
calculate_max_seq_len()
classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:
- pipeline_config (PipelineConfig) – Configuration for the pipeline.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
execute()
execute(model_inputs)
Executes the graph with the given inputs.
Parameters:
model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline's output tensors.
Return type:
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
get_kv_params()
classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Returns the KV cache params for the pipeline model.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
load_model()
load_model(session)
Parameters:
session (InferenceSession)
Return type:
model
model: Model
Compiled and initialized model ready for inference.
norm_method
norm_method: Literal['rms_norm'] | Literal['layer_norm']
Normalization layer.
prepare_initial_token_inputs()
prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepare the inputs for the first pass in multistep execution.
Parameters:
- replica_batches (Sequence[Sequence[TextContext]])
- kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
- return_n_logits (int)
Return type:
prepare_next_token_inputs()
prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepare the inputs for the next token in multistep execution. This should avoid any device synchronization or copy operations.
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
state_dict
Weights to load into the model.