Python module

max.pipelines.architectures.llama3

Llama 3 transformer architecture for text generation.

Llama3Config

class max.pipelines.architectures.llama3.Llama3Config(*, hidden_size, num_attention_heads, num_key_value_heads, num_hidden_layers, rope_theta, rope_scaling_params, max_seq_len, intermediate_size, interleaved_rope_weights, vocab_size, dtype, model_quantization_encoding, quantization_config, kv_params, return_logits=ReturnLogits.LAST_TOKEN, norm_method='rms_norm', norm_dtype=None, attention_bias=False, rms_norm_eps=None, tie_word_embeddings=False, stacked_mlp=False, stacked_qkv=False, attention_multiplier, embedding_multiplier, residual_multiplier, devices, clip_qkv, quant_config=None, lora_config=None, longrope_scaling_params=None, logits_scaling=1.0, return_hidden_states=ReturnHiddenStates.NONE, use_subgraphs=True, data_parallel_degree=1)

Bases: ArchConfigWithKVCache

Model configuration for Llama3 graph construction/execution.

attention_bias

attention_bias: bool = False

attention_multiplier

attention_multiplier: float

calculate_attention_multiplier()

static calculate_attention_multiplier(huggingface_config)

The attention multiplier is a scalar that scales the attention scores to control their variance.

This function reads the attention multiplier from the Hugging Face config. If the config does not set it, the multiplier is computed as sqrt(1.0 / head_dim), that is, the square root of 1.0 divided by the head dimension.

Parameters:

huggingface_config (AutoConfig)

Return type:

float
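The fallback rule described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the MAX implementation: the `attention_multiplier_sketch` name is hypothetical, `SimpleNamespace` stands in for a Hugging Face `AutoConfig`, and the attribute names `attention_multiplier` and `head_dim` (with `hidden_size // num_attention_heads` as the head-dimension fallback) are assumptions.

```python
import math
from types import SimpleNamespace


def attention_multiplier_sketch(hf_config) -> float:
    """Hypothetical restatement of the documented fallback, not MAX's code."""
    # Use an explicit value when the config provides one (attribute name assumed).
    explicit = getattr(hf_config, "attention_multiplier", None)
    if explicit is not None:
        return float(explicit)
    # Head dimension: prefer an explicit head_dim, else hidden_size / num_attention_heads.
    head_dim = getattr(hf_config, "head_dim", None) or (
        hf_config.hidden_size // hf_config.num_attention_heads
    )
    # Documented fallback: sqrt(1.0 / head_dim).
    return math.sqrt(1.0 / head_dim)


# Llama 3 8B-style shapes: hidden_size 4096 and 32 heads give head_dim 128.
cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
print(attention_multiplier_sketch(cfg))  # 1 / sqrt(128), about 0.0884
```

For a head dimension of 128 this is the familiar 1/sqrt(d) scaling applied to query-key dot products.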

calculate_max_seq_len()

static calculate_max_seq_len(pipeline_config, huggingface_config, model_config=None)

Return type:

int

clip_qkv

clip_qkv: float | None

construct_kv_params()

static construct_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Return type:

KVCacheParams

data_parallel_degree

data_parallel_degree: int = 1

devices

devices: list[DeviceRef]

dtype

dtype: DType

embedding_multiplier

embedding_multiplier: float

finalize()

finalize(huggingface_config, state_dict, return_logits, return_hidden_states=ReturnHiddenStates.NONE, norm_method='rms_norm', attention_bias=False)

Define parameters that can’t be determined just from the pipeline config.

Return type:

None

get_head_dim()

static get_head_dim(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

int

get_kv_params()

get_kv_params()

KV cache parameters to use when running the model.

Return type:

KVCacheParams

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by setting the --max-length (pipeline_config.model.max_length) flag.

Return type:

int

get_num_layers()

static get_num_layers(huggingface_config)

Parameters:

huggingface_config (AutoConfig)

Return type:

int

hidden_size

hidden_size: int

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

initialize_from_config()

classmethod initialize_from_config(pipeline_config, huggingface_config, model_config=None)

Return type:

Self

interleaved_rope_weights

interleaved_rope_weights: bool

intermediate_size

intermediate_size: int

kv_params

kv_params: KVCacheParams

logits_scaling

logits_scaling: float = 1.0

longrope_scaling_params

longrope_scaling_params: LongRoPEScalingParams | None = None

lora_config

lora_config: LoRAConfig | None = None

max_seq_len

max_seq_len: int

model_quantization_encoding

model_quantization_encoding: QuantizationEncoding | None

norm_dtype

norm_dtype: DType | None = None

norm_method

norm_method: Literal['rms_norm', 'layer_norm'] = 'rms_norm'

num_attention_heads

num_attention_heads: int

num_hidden_layers

num_hidden_layers: int

num_key_value_heads

num_key_value_heads: int

quant_config

quant_config: QuantConfig | None = None

quantization_config

quantization_config: QuantizationConfig | None

residual_multiplier

residual_multiplier: float

return_hidden_states

return_hidden_states: ReturnHiddenStates = 'none'

return_logits

return_logits: ReturnLogits = 'last_token'

rms_norm_eps

rms_norm_eps: float | None = None

rope_scaling_params

rope_scaling_params: Llama3RopeScalingParams | None

rope_theta

rope_theta: float

stacked_mlp

stacked_mlp: bool = False

stacked_qkv

stacked_qkv: bool = False

tie_word_embeddings

tie_word_embeddings: bool = False

use_subgraphs

use_subgraphs: bool = True

vocab_size

vocab_size: int

Llama3Inputs

class max.pipelines.architectures.llama3.Llama3Inputs(tokens, input_row_offsets, signal_buffers, return_n_logits, lora_grouped_offsets=None, num_active_loras=None, lora_end_idx=None, batch_seq_len=None, lora_ids_kv=None, lora_grouped_offsets_kv=None, data_parallel_splits=None, *, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: ModelInputs

Inputs for a single execution of the Llama3 model.

This class bundles the input tensors the Llama3 model consumes: token IDs, ragged row offsets, communication signal buffers, and optional LoRA, KV cache, and data-parallel inputs.

batch_seq_len

batch_seq_len: Buffer | None = None

buffers

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.

data_parallel_splits

data_parallel_splits: Buffer | Sequence[Sequence[int]] | None = None

Tensor containing the data parallel splits.

input_row_offsets

input_row_offsets: Buffer

Tensor containing the offsets for each row in the ragged input sequence.
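To illustrate the ragged layout, the sketch below flattens a batch of prompts into one token array and builds row offsets. It assumes the common ragged-tensor convention of batch_size + 1 cumulative offsets starting at 0, so row i spans tokens[offsets[i]:offsets[i + 1]]; plain Python lists stand in for Buffer, and the token IDs are arbitrary.

```python
from itertools import accumulate

# Three prompts of different lengths (token IDs are arbitrary).
prompts = [[101, 7, 42], [5, 9], [11, 12, 13, 14]]

# All tokens concatenated into one flat sequence (no padding).
tokens = [tok for prompt in prompts for tok in prompt]

# batch_size + 1 cumulative offsets, starting at 0 (convention assumed).
input_row_offsets = list(accumulate((len(p) for p in prompts), initial=0))

# Each row can be recovered by slicing between consecutive offsets.
rows = [
    tokens[input_row_offsets[i]:input_row_offsets[i + 1]]
    for i in range(len(prompts))
]

print(input_row_offsets)  # [0, 3, 5, 9]
```

Compared with padding every prompt to the longest length, this layout stores no pad tokens, which is why a separate offsets tensor is needed to mark row boundaries.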

lora_end_idx

lora_end_idx: Buffer | None = None

lora_grouped_offsets

lora_grouped_offsets: Buffer | None = None

lora_grouped_offsets_kv

lora_grouped_offsets_kv: Buffer | None = None

lora_ids_kv

lora_ids_kv: Buffer | None = None

num_active_loras

num_active_loras: Buffer | None = None

return_n_logits

return_n_logits: Buffer

signal_buffers

signal_buffers: list[Buffer]

Device buffers used for synchronization in communication collectives.

tokens

tokens: Buffer

Tensor containing the input token IDs.

Llama3Model

class max.pipelines.architectures.llama3.Llama3Model(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: LlamaModelBase

Llama 3 pipeline model implementation.

config_class

config_class

alias of Llama3Config

norm_method

norm_method: Literal['rms_norm'] | Literal['layer_norm'] = 'rms_norm'

Normalization method used by the model's normalization layers: 'rms_norm' or 'layer_norm'.
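For reference, the two options correspond to the standard RMS normalization and layer normalization formulas, sketched below in plain Python for a single vector. This is not the MAX graph implementation; the eps defaults are illustrative assumptions, and in this config the rms_norm_eps field presumably supplies the RMSNorm epsilon when set.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root mean square of x; no mean subtraction and no bias.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def layer_norm(x, weight, bias, eps=1e-5):
    # Subtract the mean, divide by the standard deviation, then scale and shift.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [w * (v - mean) / math.sqrt(var + eps) + b for v, w, b in zip(x, weight, bias)]


x = [1.0, 2.0, 3.0, 4.0]
print(rms_norm(x, [1.0] * 4))
print(layer_norm(x, [1.0] * 4, [0.0] * 4))
```

RMSNorm skips the mean computation and the bias term, which makes it slightly cheaper; layer norm output is additionally centered at zero.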

LlamaModelBase

class max.pipelines.architectures.llama3.LlamaModelBase(pipeline_config, session, devices, kv_cache_config, weights, adapter=None, return_logits=ReturnLogits.LAST_TOKEN, return_hidden_states=ReturnHiddenStates.NONE)

Bases: LogProbabilitiesMixin, PipelineModelWithKVCache[TextContext]

Base Llama pipeline model implementation.

attention_bias

attention_bias: bool = False

Whether to use attention bias.

calculate_max_seq_len()

classmethod calculate_max_seq_len(pipeline_config, huggingface_config)

Calculates the optimal max sequence length for the model.

Models are expected to implement this method. The following example shows how to implement it for a Mistral model:

class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.model.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.model.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e

Parameters:

  • pipeline_config (PipelineConfig) – Configuration for the pipeline.
  • huggingface_config (AutoConfig) – Hugging Face model configuration.

Returns:

The maximum sequence length to use.

Return type:

int
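The Mistral example delegates the bounds check to upper_bounded_default, whose exact behavior is not documented on this page. The sketch below is a hypothetical stand-in (`upper_bounded_default_sketch`) that assumes, from the ValueError handling in the example, that the helper returns the requested default when it fits within the bound, falls back to the bound when no default is given, and raises ValueError otherwise.

```python
from typing import Optional


def upper_bounded_default_sketch(upper_bound: int, default: Optional[int]) -> int:
    """Hypothetical stand-in for upper_bounded_default (behavior assumed)."""
    # No value requested: use the model's own limit.
    if default is None:
        return upper_bound
    # A requested value above the model's limit is rejected; this is the
    # ValueError the Mistral example catches and re-raises with context.
    if default > upper_bound:
        raise ValueError(f"default ({default}) exceeds upper_bound ({upper_bound})")
    return default


print(upper_bounded_default_sketch(upper_bound=8192, default=None))  # 8192
print(upper_bounded_default_sketch(upper_bound=8192, default=4096))  # 4096
```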

execute()

execute(model_inputs)

Executes the graph with the given inputs.

Parameters:

model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.

Returns:

ModelOutputs containing the pipeline’s output tensors.

Return type:

ModelOutputs

This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.

get_kv_params()

classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)

Returns the KV cache params for the pipeline model.

Return type:

KVCacheParams

load_model()

load_model(session)

Parameters:

session (InferenceSession)

Return type:

Model

model

model: Model

Compiled and initialized model ready for inference.

norm_method

norm_method: Literal['rms_norm'] | Literal['layer_norm']

Normalization method used by the model's normalization layers: 'rms_norm' or 'layer_norm'.

prepare_initial_token_inputs()

prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)

Prepare the inputs for the first pass in multistep execution.

Return type:

Llama3Inputs

prepare_next_token_inputs()

prepare_next_token_inputs(next_tokens, prev_model_inputs)

Prepare the inputs for the next token in multistep execution. This should avoid any device synchronization or copy operations.

Return type:

Llama3Inputs
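One reason this step can avoid device synchronization is that, after the first pass, every sequence contributes exactly one new token, so the ragged offsets depend only on the batch size. The sketch below illustrates that idea; `ToyInputs` and `prepare_next_token_inputs_sketch` are hypothetical, plain lists replace Buffers, and the real method may instead reuse preallocated device tensors.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyInputs:
    """Hypothetical stand-in for Llama3Inputs; plain lists replace Buffers."""
    tokens: list[int]
    input_row_offsets: list[int]


def prepare_next_token_inputs_sketch(next_tokens: list[int], prev: ToyInputs) -> ToyInputs:
    # The batch size is already known from the previous step's offsets.
    batch_size = len(prev.input_row_offsets) - 1
    if len(next_tokens) != batch_size:
        raise ValueError("expected exactly one next token per sequence")
    # Each sequence now contributes one token, so the offsets are simply
    # 0, 1, ..., batch_size and never require reading token values back.
    return replace(
        prev,
        tokens=list(next_tokens),
        input_row_offsets=list(range(batch_size + 1)),
    )


# First pass: two prompts of lengths 3 and 2.
prev = ToyInputs(tokens=[101, 7, 42, 5, 9], input_row_offsets=[0, 3, 5])
nxt = prepare_next_token_inputs_sketch([88, 99], prev)
print(nxt.input_row_offsets)  # [0, 1, 2]
```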

state_dict

state_dict: dict[str, Any]

Weights to load into the model.