Python module
interfaces
Interfaces for MAX pipelines.
AlwaysSignalBuffersMixin
class max.pipelines.lib.interfaces.AlwaysSignalBuffersMixin
Bases: object
Mixin for models that always require signal buffers.
Use this for models that use VocabParallelEmbedding or other distributed components that always perform allreduce, even on single-device setups.
Models using this mixin build graphs that always include signal buffer inputs, regardless of device count. This is typically because they use distributed embedding layers or other components that call allreduce operations unconditionally.
devices
Device list that must be provided by the model class.
signal_buffers
Override to always create signal buffers.
Models using this mixin have distributed components that always perform allreduce, even for single-device setups. Therefore, signal buffers are always required to match the graph inputs.
In compile-only mode (virtual device mode), returns an empty list to avoid GPU memory allocation which is not supported.
Returns:
List of signal buffer tensors, one per device, or an empty list in compile-only mode.
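The following is a minimal sketch of how a model class might opt into this mixin; the class name, constructor, and graph-building method are illustrative assumptions rather than part of the MAX API:
from max.pipelines.lib.interfaces import AlwaysSignalBuffersMixin


class MyDistributedModel(AlwaysSignalBuffersMixin):
    """Hypothetical model whose embedding layer always performs allreduce."""

    def __init__(self, devices):
        # The mixin expects the model class to expose a `devices` list.
        self.devices = devices

    def build_graph(self):
        # The graph unconditionally includes signal buffer inputs, so
        # `self.signal_buffers` from the mixin always matches them, even
        # on a single device.
        ...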
ArchConfig
class max.pipelines.lib.interfaces.ArchConfig(*args, **kwargs)
Bases: Protocol
Config for a model architecture.
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.max_length) flag.
Return type:
initialize()
classmethod initialize(pipeline_config)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig)
Return type:
Self
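Because ArchConfig is a Protocol, any class with matching methods satisfies it. The following dataclass is a hedged sketch of one possible implementation; the field names are assumptions:
from dataclasses import dataclass


@dataclass
class MyArchConfig:
    """Hypothetical config satisfying the ArchConfig protocol."""

    max_length: int | None = None
    model_max_seq_len: int = 4096

    def get_max_seq_len(self) -> int:
        # Allow --max-length to override the model default when provided.
        return self.max_length or self.model_max_seq_len

    @classmethod
    def initialize(cls, pipeline_config) -> "MyArchConfig":
        # pipeline_config.max_length mirrors the --max-length flag.
        return cls(max_length=pipeline_config.max_length)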
ArchConfigWithAttentionKVCache
class max.pipelines.lib.interfaces.ArchConfigWithAttentionKVCache(dtype, devices=<factory>, cache_dtype=None, kv_cache=<factory>, data_parallel_degree=1, user_provided_max_length=None, huggingface_config=None, _kv_params=None)
Bases: ArchConfigWithKVCache, ABC
Predefined configuration for architectures that use attention KV cache blocks.
Subclasses must define the following attributes:
- num_key_value_heads: int
- head_dim: int
- num_layers: int
- model_max_seq_len: int
Parameters:
cache_dtype
The data type to use for the KV cache.
data_parallel_degree
data_parallel_degree: int = 1
The data parallel degree to use when running the model.
devices
The physical devices to use when running the model.
dtype
dtype: DType
The data type to use for the model.
get_kv_params()
get_kv_params()
Returns the KV cache parameters for this architecture.
Return type:
get_max_seq_len()
get_max_seq_len()
Returns the maximum sequence length the model can process.
Returns max_length if set, otherwise model_max_seq_len.
Raises ValueError if max_length exceeds model_max_seq_len.
Return type:
head_dim
abstract property head_dim: int
Dimensionality of each attention head.
huggingface_config
huggingface_config: AutoConfig | None = None
initialize()
classmethod initialize(pipeline_config)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig)
Return type:
Self
kv_cache
kv_cache: KVCacheConfig
The KV cache configuration to use when running the model.
model_max_seq_len
abstract property model_max_seq_len: int
The maximum sequence length that can be processed by the model.
num_key_value_heads
abstract property num_key_value_heads: int
Number of key-value heads to use for the KV cache.
num_layers
abstract property num_layers: int
Number of hidden layers in the model.
user_provided_max_length
Override for the maximum sequence length.
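A hedged sketch of a concrete subclass that sources the required abstract properties from the Hugging Face config; the attribute names read from huggingface_config follow common transformer conventions and may differ for a given model:
from max.pipelines.lib.interfaces import ArchConfigWithAttentionKVCache


class MyAttentionArchConfig(ArchConfigWithAttentionKVCache):
    """Hypothetical config for a standard multi-head-attention model."""

    @property
    def num_key_value_heads(self) -> int:
        return self.huggingface_config.num_key_value_heads

    @property
    def head_dim(self) -> int:
        # Common convention: hidden size split evenly across attention heads.
        return (
            self.huggingface_config.hidden_size
            // self.huggingface_config.num_attention_heads
        )

    @property
    def num_layers(self) -> int:
        return self.huggingface_config.num_hidden_layers

    @property
    def model_max_seq_len(self) -> int:
        return self.huggingface_config.max_position_embeddings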
ArchConfigWithKVCache
class max.pipelines.lib.interfaces.ArchConfigWithKVCache(*args, **kwargs)
Bases: ArchConfig, Protocol
Config for a model architecture that uses a KV cache.
get_kv_params()
get_kv_params()
KV cache parameters to use when running the model.
Return type:
ComponentModel
class max.pipelines.lib.interfaces.ComponentModel(config, encoding, devices, weights)
Bases: ABC
Base interface for component models with weight-backed execution.
load_model()
abstract load_model()
Load and return a runtime model instance.
DiffusionPipeline
class max.pipelines.lib.interfaces.DiffusionPipeline(pipeline_config, session, devices, weight_paths, **kwargs)
Bases: ABC
Base class for diffusion pipelines.
Subclasses must define components mapping component names to ComponentModel types.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- devices (list[Device])
- weight_paths (list[Path])
- kwargs (Any)
components
components: dict[str, type[ComponentModel]] | None = None
execute()
abstract execute(model_inputs, **kwargs)
Execute the pipeline with the given model inputs.
Parameters:
- model_inputs (PixelModelInputs) – Prepared model inputs from prepare_inputs.
- **kwargs (Any) – Additional pipeline-specific execution parameters.
Returns:
Pipeline-specific output (e.g., generated images).
Return type:
finalize_pipeline_config()
classmethod finalize_pipeline_config(pipeline_config)
Hook for finalizing pipeline configuration. Override if needed.
Parameters:
- pipeline_config (PipelineConfig)
Return type:
None
init_remaining_components()
abstract init_remaining_components()
Initialize non-ComponentModel components (e.g., image processors).
Return type:
None
prepare_inputs()
abstract prepare_inputs(context)
Prepare inputs for the pipeline.
Parameters:
- context (PixelGenerationContext)
Return type:
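A minimal sketch of a DiffusionPipeline subclass defining the required components mapping; the component names, component classes, and elided method bodies are all illustrative assumptions:
from max.pipelines.lib.interfaces import ComponentModel, DiffusionPipeline


class MyTextEncoder(ComponentModel):
    """Hypothetical weight-backed text encoder component."""

    def load_model(self):
        ...


class MyTransformer(ComponentModel):
    """Hypothetical weight-backed denoising transformer component."""

    def load_model(self):
        ...


class MyImagePipeline(DiffusionPipeline):
    """Hypothetical diffusion pipeline wiring two components together."""

    # Required by DiffusionPipeline subclasses: component names mapped to
    # ComponentModel types.
    components = {
        "text_encoder": MyTextEncoder,
        "transformer": MyTransformer,
    }

    def init_remaining_components(self) -> None:
        # Non-ComponentModel pieces, e.g. an image post-processor.
        ...

    def prepare_inputs(self, context):
        ...

    def execute(self, model_inputs, **kwargs):
        ...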
GenerateMixin
class max.pipelines.lib.interfaces.GenerateMixin(*args, **kwargs)
Bases: Protocol[TextGenerationContextType, RequestType]
Protocol for pipelines that support text generation.
execute()
execute(inputs)
Executes the pipeline for the given inputs.
Parameters:
- inputs (TextGenerationInputs[TextGenerationContextType])
Return type:
generate()
generate(prompts)
Generates outputs for the given prompts.
Parameters:
- prompts (RequestType | list[RequestType])
Return type:
generate_async()
async generate_async(prompts)
Generates outputs asynchronously for the given prompts.
kv_managers
property kv_managers: list[PagedKVCacheManager]
Returns the KV cache managers for this pipeline.
pipeline_config
property pipeline_config: PipelineConfig
Returns the pipeline configuration.
release()
release(request_id)
Releases resources for the given request.
Parameters:
- request_id (RequestID)
Return type:
None
tokenizer
property tokenizer: PipelineTokenizer[TextGenerationContextType, ndarray[tuple[Any, ...], dtype[integer[Any]]], RequestType]
Returns the tokenizer for this pipeline.
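As a rough usage sketch, a pipeline implementing this protocol can be driven as below; `pipeline` is a placeholder instance and the example assumes plain strings are a valid RequestType for it:
# Generate outputs for one or more prompts.
outputs = pipeline.generate(["Write a haiku about the ocean."])

# When a request is finished, its resources can be released by ID, e.g.:
# pipeline.release(request_id)  # request_id is a hypothetical RequestID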
KVCacheMixin
class max.pipelines.lib.interfaces.KVCacheMixin(*args, **kwargs)
Bases: Protocol
get_kv_params()
abstract classmethod get_kv_params(huggingface_config, pipeline_config, devices, kv_cache_config, cache_dtype)
Returns the KV cache params for the pipeline model.
Parameters:
- huggingface_config (AutoConfig)
- pipeline_config (PipelineConfig)
- devices (list[DeviceRef])
- kv_cache_config (KVCacheConfig)
- cache_dtype (DType)
Return type:
load_kv_managers()
load_kv_managers(kv_params, max_batch_size, max_seq_len, session, available_cache_memory)
Given KV cache parameters and an InferenceSession, loads the KV cache manager.
Parameters:
- kv_params (KVCacheParamInterface) – KV cache parameters.
- max_batch_size (int) – Maximum batch size of the model.
- max_seq_len (int) – Maximum sequence length of the model.
- session (InferenceSession) – Inference session to compile and initialize the KV cache.
- available_cache_memory (int) – Amount of memory available to the KV cache, in bytes.
Returns:
A single KV cache manager.
Return type:
ModelInputs
class max.pipelines.lib.interfaces.ModelInputs(*, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: object
Base class for model inputs.
Use this class to encapsulate inputs for your model; you may store any number of dataclass fields.
The following example demonstrates how to create a custom inputs class:
@dataclass
class ReplitInputs(ModelInputs):
    tokens: Buffer
    input_row_offsets: Buffer

tokens = Buffer.zeros((1, 2, 3), DType.int64)
input_row_offsets = Buffer.zeros((1, 1, 1), DType.int64)

# Initialize inputs
inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets)

# Access tensors
list(inputs) == [tokens, input_row_offsets]  # Output: True
Parameters:
buffers
Returns positional Buffer inputs for model ABI calls.
hidden_states
Hidden states for a variable number of tokens per sequence.
For data parallel models, this can be a list of Buffers where each Buffer contains hidden states for the sequences assigned to that device.
kv_cache_inputs
kv_cache_inputs: KVCacheInputs | None = None
lora_ids
Buffer containing the LoRA ids.
lora_ranks
Buffer containing the LoRA ranks.
update()
update(**kwargs)
Updates attributes from keyword arguments (only existing, non-None).
Return type:
None
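A small illustrative sketch of the update behavior, continuing the ReplitInputs example above and assuming that "only existing, non-None" means attributes already present on the instance whose new value is not None:
# Hypothetical replacement buffer for the existing `tokens` attribute.
new_tokens = Buffer.zeros((1, 2, 3), DType.int64)

# `tokens` already exists and the value is not None, so it is updated.
inputs.update(tokens=new_tokens)

# A None value is treated as "no update"; the stored buffer is left unchanged.
inputs.update(input_row_offsets=None)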
ModelOutputs
class max.pipelines.lib.interfaces.ModelOutputs(logits: 'Buffer', next_token_logits: 'Buffer | None' = None, logit_offsets: 'Buffer | None' = None, hidden_states: 'Buffer | list[Buffer] | None' = None)
Bases: object
Parameters:
hidden_states
Hidden states for a variable number of tokens per sequence.
For data parallel models, this can be a list of Buffers where each Buffer contains hidden states for the sequences assigned to that device.
logit_offsets
Offsets to access variable length logits for each sequence.
logits
logits: Buffer
Logits for a variable number of tokens per sequence.
next_token_logits
Logits for just the next token.
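For orientation, a minimal sketch of constructing ModelOutputs inside a model's execute() implementation; the variables holding the graph outputs are placeholders, not part of the API:
# Placeholders: `all_logits` holds logits for every token position, while
# `last_logits` holds logits for only the sampled next-token positions.
outputs = ModelOutputs(
    logits=all_logits,
    next_token_logits=last_logits,
)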
PipelineModel
class max.pipelines.lib.interfaces.PipelineModel(pipeline_config, session, huggingface_config, encoding, devices, kv_cache_config, weights, adapter, return_logits, return_hidden_states=ReturnHiddenStates.NONE)
Bases: ABC, Generic[BaseContextType]
A pipeline model with setup, input preparation and execution methods.
Parameters:
- pipeline_config (PipelineConfig)
- session (InferenceSession)
- huggingface_config (AutoConfig)
- encoding (SupportedEncoding)
- devices (list[Device])
- kv_cache_config (KVCacheConfig)
- weights (Weights)
- adapter (WeightsAdapter | None)
- return_logits (ReturnLogits)
- return_hidden_states (ReturnHiddenStates)
calculate_max_seq_len()
abstract classmethod calculate_max_seq_len(pipeline_config, huggingface_config)
Calculates the optimal max sequence length for the model.
Models are expected to implement this method. The following example shows how to implement it for a Mistral model:
class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.max_length,
            )
        except ValueError as e:
            raise ValueError(
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            ) from e
Parameters:
- pipeline_config (PipelineConfig) – Configuration for the pipeline.
- huggingface_config (AutoConfig) – Hugging Face model configuration.
Returns:
The maximum sequence length to use.
Return type:
compute_log_probabilities()
compute_log_probabilities(session, model_inputs, model_outputs, next_tokens, batch_top_n, batch_echo)
Optional method that can be overridden to compute log probabilities.
Parameters:
- session (InferenceSession) – Inference session to compute log probabilities within.
- model_inputs (ModelInputs) – Inputs to the model returned by prepare_*_token_inputs().
- model_outputs (ModelOutputs) – Outputs returned by execute().
- next_tokens (Buffer) – Sampled tokens. Should have shape=[batch size].
- batch_top_n (list[int]) – Number of top log probabilities to return per input in the batch. For any element where top_n == 0, the LogProbabilities is skipped.
- batch_echo (list[bool]) – Whether to include input tokens in the returned log probabilities.
Returns:
List of log probabilities.
Return type:
list[LogProbabilities | None]
dtype
property dtype: DType
Returns the model data type (from encoding or pipeline config).
estimate_activation_memory()
classmethod estimate_activation_memory(pipeline_config, huggingface_config)
Estimates the activation memory required for model execution.
This accounts for temporary memory buffers used during model execution, such as intermediate activations and working buffers.
The default implementation returns 0 for backward compatibility. Models with significant activation memory requirements should override this method to provide accurate estimates.
Parameters:
- pipeline_config (PipelineConfig) – Pipeline configuration
- huggingface_config (AutoConfig) – Hugging Face model configuration
Returns:
Estimated activation memory in bytes
Return type:
estimate_weights_size()
classmethod estimate_weights_size(pipeline_config)
Calculates the estimated memory consumption of the model's weights.
Parameters:
- pipeline_config (PipelineConfig)
Return type:
execute()
abstract execute(model_inputs)
Executes the graph with the given inputs.
This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
Parameters:
- model_inputs (ModelInputs) – The model inputs to execute, containing tensors and any other required data for model execution.
Returns:
ModelOutputs containing the pipeline's output tensors.
Return type:
execute_with_capture()
execute_with_capture(model_inputs, batch_size)
Executes the model with optional capture handling.
Subclasses can override this to integrate device graph capture/replay.
Parameters:
- model_inputs (ModelInputs)
- batch_size (int)
Return type:
finalize_pipeline_config()
classmethod finalize_pipeline_config(pipeline_config)
Finalizes the pipeline configuration.
This method is called after the pipeline configuration is resolved. It can be overridden to perform any finalization steps that are needed.
Parameters:
- pipeline_config (PipelineConfig)
Return type:
None
lora_manager
property lora_manager: LoRAManager | None
Returns the LoRA manager if LoRA is enabled, otherwise None.
pre_capture_execution_trace()
pre_capture_execution_trace(model_inputs, batch_size)
Pre-captures device graphs for the given model inputs.
Parameters:
- model_inputs (list[ModelInputs]) – List of model inputs to capture graphs for.
- batch_size (int) – The batch size for execution.
Return type:
None
prepare_initial_token_inputs()
abstract prepare_initial_token_inputs(replica_batches, kv_cache_inputs=None, return_n_logits=1)
Prepares the initial inputs to be passed to .execute().
The inputs and functionality can vary per model. For example, model
inputs could include encoded tensors, unique IDs per tensor when using
a KV cache manager, and kv_cache_inputs (or None if the model does
not use KV cache). This method typically batches encoded tensors,
claims a KV cache slot if needed, and returns the inputs and caches.
Parameters:
Return type:
prepare_next_token_inputs()
abstract prepare_next_token_inputs(next_tokens, prev_model_inputs)
Prepares the secondary inputs to be passed to .execute().
While prepare_initial_token_inputs is responsible for managing the initial inputs, this function is responsible for updating the inputs for each step in a multi-step execution pattern.
Parameters:
- next_tokens (Buffer)
- prev_model_inputs (ModelInputs)
Return type:
signal_buffers
Lazily initialize signal buffers for multi-GPU communication collectives.
Signal buffers are only needed during model execution, not during compilation. By deferring their allocation, we avoid memory allocation in compile-only mode.
Returns:
List of signal buffer tensors, one per device for multi-device setups, or an empty list for single-device setups or compile-only mode.
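Putting the abstract surface together, a concrete model implements roughly the following skeleton; this is an illustrative outline with elided bodies, not a complete implementation:
from max.pipelines.lib.interfaces import ModelInputs, ModelOutputs, PipelineModel


class MyPipelineModel(PipelineModel):
    """Hypothetical text-generation model."""

    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        ...

    def prepare_initial_token_inputs(
        self, replica_batches, kv_cache_inputs=None, return_n_logits=1
    ):
        # Batch encoded tensors, claim KV cache slots if needed, and return
        # a ModelInputs instance for execute().
        ...

    def prepare_next_token_inputs(self, next_tokens, prev_model_inputs):
        # Update the previous inputs for the next step of multi-step execution.
        ...

    def execute(self, model_inputs: ModelInputs) -> ModelOutputs:
        ...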
PixelModelInputs
class max.pipelines.lib.interfaces.PixelModelInputs(*, tokens, tokens_2=None, negative_tokens=None, negative_tokens_2=None, extra_params=<factory>, timesteps=<factory>, sigmas=<factory>, latents=<factory>, latent_image_ids=<factory>, height=1024, width=1024, num_inference_steps=50, guidance_scale=3.5, guidance=None, true_cfg_scale=1.0, num_warmup_steps=0, num_images_per_prompt=1, input_image=None)
Bases: object
Common input container for pixel-generation models.
Provides a consistent set of fields used across multiple pixel pipelines and models.
Parameters:
- tokens (TokenBuffer)
- tokens_2 (TokenBuffer | None)
- negative_tokens (TokenBuffer | None)
- negative_tokens_2 (TokenBuffer | None)
- extra_params (dict[str, ndarray[tuple[Any, ...], dtype[Any]]])
- timesteps (ndarray[tuple[Any, ...], dtype[float32]])
- sigmas (ndarray[tuple[Any, ...], dtype[float32]])
- latents (ndarray[tuple[Any, ...], dtype[float32]])
- latent_image_ids (ndarray[tuple[Any, ...], dtype[float32]])
- height (int)
- width (int)
- num_inference_steps (int)
- guidance_scale (float)
- guidance (ndarray[tuple[Any, ...], dtype[float32]] | None)
- true_cfg_scale (float)
- num_warmup_steps (int)
- num_images_per_prompt (int)
- input_image (Image | None)
extra_params
extra_params: dict[str, ndarray[tuple[Any, ...], dtype[Any]]]
A bag of model-specific numeric parameters not represented as explicit fields.
Typical uses:
- Architecture-specific knobs (e.g., cfg_normalization arrays, scaling vectors)
- Precomputed per-step values not worth standardizing across all models
- Small numeric tensors that are easier to carry as named extras than formal fields
Values are expected to be numpy arrays (ndarray) to keep the contract consistent, but you can relax this if your codebase needs non-array values.
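A short illustration of the intended shape of this bag; the key names below are hypothetical examples, not a defined contract:
import numpy as np

extra_params = {
    # Hypothetical per-step scaling vector used by one architecture.
    "step_scales": np.linspace(1.0, 0.0, num=50, dtype=np.float32),
    # Hypothetical flag-style array for CFG normalization.
    "cfg_normalization": np.array([1.0], dtype=np.float32),
}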
from_context()
classmethod from_context(context)
Build an instance from a context-like dict.
Policy:
- If a key is missing: the dataclass default applies automatically.
- If a key is present with value None: treat as missing and substitute the class default (including subclass overrides).
Parameters:
- context (PixelGenerationContext)
Return type:
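To make the substitution policy concrete, the sketch below assumes ctx is a PixelGenerationContext whose height was explicitly set to None and whose width was never provided; both then resolve to the class defaults:
# Placeholder: `ctx` stands in for a real PixelGenerationContext.
inputs = PixelModelInputs.from_context(ctx)

# Both the None-valued and the missing scalar fall back to the dataclass
# defaults (or subclass overrides).
assert inputs.height == 1024
assert inputs.width == 1024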
guidance
guidance: ndarray[tuple[Any, ...], dtype[float32]] | None = None
Optional guidance tensor.
- Some pipelines precompute guidance weights/tensors (e.g., per-token weights, per-step weights).
- None is meaningful here: it means “no explicit guidance tensor supplied”.
- Unlike scalar fields, None is preserved (not replaced).
guidance_scale
guidance_scale: float = 3.5
Guidance scale for classifier-free guidance (CFG).
- A higher value typically increases adherence to the prompt but can reduce diversity.
- This is expected to be a real float (not None).
- If a context provides guidance_scale=None, from_context() substitutes the default.
height
height: int = 1024
Output height in pixels.
- This is a required scalar (not None).
- If a context provides height=None, from_context() treats that as “not provided” and substitutes this default value (or a subclass override).
input_image
input_image: Image | None = None
Optional input image for image-to-image generation (PIL.Image.Image).
latent_image_ids
Optional latent image IDs / positional identifiers for latents.
- Some pipelines attach per-latent identifiers for caching, routing, or conditioning.
- Often used to avoid recomputation of image-id embeddings across steps.
- If unused, it may remain empty.
latents
Initial latent noise tensor (or initial latent state).
- For diffusion/flow models, this is typically random noise seeded per request.
- Shape depends on model: commonly [B, C, H/8, W/8] for image latents, or [B, T, C, H/8, W/8] for video latents.
- If your pipeline generates latents internally, you may leave it empty. (Model-specific subclasses can enforce non-empty via __post_init__.)
negative_tokens
negative_tokens: TokenBuffer | None = None
Negative prompt tokens for the primary encoder. Used for classifier-free guidance (CFG) or similar conditioning schemes. If your pipeline does not use negative prompts, leave as None.
negative_tokens_2
negative_tokens_2: TokenBuffer | None = None
Negative prompt tokens for the secondary encoder (for dual-encoder models). If the model is single-encoder or you do not use negative prompts, leave as None.
num_images_per_prompt
num_images_per_prompt: int = 1
Number of images/videos to generate per prompt.
- Commonly used for “same prompt, multiple samples” behavior.
- Must be > 0.
- For video generation, the naming may still be used for historical compatibility.
num_inference_steps
num_inference_steps: int = 50
Number of denoising/inference steps.
- This is a required scalar (not None).
- If a context provides num_inference_steps=None, from_context() treats that as “not provided” and substitutes this default value (or a subclass override).
num_warmup_steps
num_warmup_steps: int = 0
Number of warmup steps.
- Used in some schedulers/pipelines to handle initial steps differently (e.g., scheduler stabilization, cache warmup, etc.).
- Must be >= 0.
sigmas
Precomputed sigma schedule for denoising.
- Usually a 1D float32 numpy array of length num_inference_steps corresponding to the noise level per step.
- Some schedulers are sigma-based; others are timestep-based; some use both.
- If unused, it may remain empty unless your model subclass requires it.
timesteps
Precomputed denoising timestep schedule.
- Usually a 1D float32 numpy array of length num_inference_steps (exact semantics depend on your scheduler).
- If your pipeline precomputes the scheduler trajectory, you pass it here.
- Some models may not require explicit timesteps; in that case it may remain empty. (Model-specific subclasses can enforce non-empty via __post_init__.)
tokens
tokens: TokenBuffer
Primary encoder token buffer. This is the main prompt representation consumed by the model’s text encoder. Required for all models.
tokens_2
tokens_2: TokenBuffer | None = None
Secondary encoder token buffer (for dual-encoder models). Examples: architectures that have a second text encoder stream or pooled embeddings. If the model is single-encoder, leave as None.
true_cfg_scale
true_cfg_scale: float = 1.0
“True CFG” scale used by certain pipelines/models.
- Some architectures distinguish between the user-facing guidance_scale and an internal scale applied to a different normalization or conditioning pathway.
- Defaults to 1.0 for pipelines that do not use this feature.
width
width: int = 1024
Output width in pixels.
- This is a required scalar (not None).
- If a context provides width=None, from_context() treats that as “not provided” and substitutes this default value (or a subclass override).
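Tying the fields together, a minimal construction for a single-encoder text-to-image request might look like the following sketch; prompt_tokens is a placeholder TokenBuffer produced by the pipeline's tokenization step, and the scheduler values are illustrative:
import numpy as np

# Placeholder: `prompt_tokens` is a TokenBuffer built earlier in the pipeline.
model_inputs = PixelModelInputs(
    tokens=prompt_tokens,
    height=768,
    width=768,
    num_inference_steps=30,
    guidance_scale=5.0,
    # Illustrative sigma schedule, one value per inference step.
    sigmas=np.linspace(1.0, 0.0, num=30, dtype=np.float32),
)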