

DiffusionPipeline


class max.pipelines.lib.interfaces.DiffusionPipeline(pipeline_config, session, devices, weight_paths, cache_config=None, **kwargs)


Bases: ABC

Base class for diffusion pipelines.

Subclasses must define components, a class attribute that maps component names to ComponentModel types.

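Parameters:

  • pipeline_config
  • session
  • devices
  • weight_paths
  • cache_config (optional, defaults to None)
  • **kwargs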

components

components: dict[str, type[ComponentModel]] | None = None
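As a minimal sketch, a subclass might declare its components mapping like this. ComponentModel is replaced by a local stand-in so the snippet runs on its own, and the component names ("vae", "transformer", "text_encoder") and model classes are hypothetical; a real pipeline defines its own.

```python
from __future__ import annotations

from abc import ABC


# Hypothetical stand-in for the real ComponentModel base class.
class ComponentModel(ABC):
    pass


# Hypothetical component model classes; real pipelines define their own.
class VAEModel(ComponentModel):
    pass


class TransformerModel(ComponentModel):
    pass


class TextEncoderModel(ComponentModel):
    pass


# Stands in for a DiffusionPipeline subclass (the real base is not imported here).
class MyDiffusionPipeline:
    # Maps each component name to a ComponentModel *type* (a class, not an instance).
    components: dict[str, type[ComponentModel]] | None = {
        "vae": VAEModel,
        "transformer": TransformerModel,
        "text_encoder": TextEncoderModel,
    }


for name, model_type in (MyDiffusionPipeline.components or {}).items():
    print(f"{name}: {model_type.__name__}")
```

Mapping names to classes rather than instances leaves construction to the pipeline, which can then instantiate each component with its own weights and devices.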


default_num_inference_steps

default_num_inference_steps: int = 50


Default number of denoising steps when the user does not specify one.

Subclasses may override this to provide a model-appropriate default.
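A minimal sketch of overriding the default in a subclass, together with one plausible way a caller could fall back to it. The resolve_num_steps helper and FastDistilledPipeline class are hypothetical; the fallback rule (use the caller's value if given, otherwise the class default) is an assumption, not the library's code.

```python
from __future__ import annotations


class BasePipeline:
    # Mirrors the documented base-class default.
    default_num_inference_steps: int = 50


class FastDistilledPipeline(BasePipeline):
    # Hypothetical subclass for a few-step (distilled) model.
    default_num_inference_steps: int = 4


def resolve_num_steps(pipeline_cls: type[BasePipeline], requested: int | None) -> int:
    """Assumed fallback: prefer the user's value, else the pipeline's default."""
    if requested is not None:
        return requested
    return pipeline_cls.default_num_inference_steps


print(resolve_num_steps(BasePipeline, None))           # 50
print(resolve_num_steps(FastDistilledPipeline, None))  # 4
print(resolve_num_steps(FastDistilledPipeline, 8))     # 8
```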

execute()

abstract execute(model_inputs, **kwargs)


Execute the pipeline with the given model inputs.

Parameters:

  • model_inputs (PixelModelInputs) – Prepared model inputs from prepare_inputs.
  • **kwargs (Any) – Additional pipeline-specific execution parameters.

Returns:

Pipeline-specific output (e.g., generated images).

Return type:

Any

init_remaining_components()

abstract init_remaining_components()


Initialize non-ComponentModel components (e.g., image processors).

Return type:

None
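A minimal sketch of implementing init_remaining_components in a subclass. PipelineBase and SimpleImageProcessor are hypothetical stand-ins; in particular, having the constructor call the hook is an assumption made so the example runs, and the real call site may differ.

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class PipelineBase(ABC):
    """Hypothetical stand-in for the abstract part of DiffusionPipeline."""

    def __init__(self) -> None:
        # Assumption: the base constructor invokes the hook after loading ComponentModels.
        self.init_remaining_components()

    @abstractmethod
    def init_remaining_components(self) -> None: ...


class SimpleImageProcessor:
    """Hypothetical helper mapping model output in [-1, 1] to 8-bit pixel values."""

    def postprocess(self, values: list[float]) -> list[int]:
        # Clamp to [-1, 1], then rescale to [0, 255].
        return [round((min(max(v, -1.0), 1.0) + 1.0) * 127.5) for v in values]


class MyPipeline(PipelineBase):
    def init_remaining_components(self) -> None:
        # Non-ComponentModel pieces, such as an image processor, are built here.
        self.image_processor = SimpleImageProcessor()


pipe = MyPipeline()
print(pipe.image_processor.postprocess([-1.0, 0.0, 1.0]))
```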

prepare_inputs()

abstract prepare_inputs(context)


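Prepare model inputs for the pipeline from a generation context. The returned PixelModelInputs are passed to execute().

Parameters:

context (PixelGenerationContext) – The generation context to prepare inputs from.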

Return type:

PixelModelInputs
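The two abstract methods form a pipeline: prepare_inputs turns a context into model inputs, and execute consumes them. The sketch below shows that flow with hypothetical stand-in dataclasses for PixelGenerationContext and PixelModelInputs (their real fields are not documented here), and an EchoPipeline that returns a summary instead of running a model.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any


@dataclass
class PixelGenerationContext:  # hypothetical stand-in; real fields differ
    prompt: str
    num_inference_steps: int | None = None


@dataclass
class PixelModelInputs:  # hypothetical stand-in; real fields differ
    prompt: str
    num_inference_steps: int


class PipelineBase(ABC):
    default_num_inference_steps: int = 50

    @abstractmethod
    def prepare_inputs(self, context: PixelGenerationContext) -> PixelModelInputs: ...

    @abstractmethod
    def execute(self, model_inputs: PixelModelInputs, **kwargs: Any) -> Any: ...


class EchoPipeline(PipelineBase):
    def prepare_inputs(self, context: PixelGenerationContext) -> PixelModelInputs:
        steps = context.num_inference_steps
        if steps is None:
            steps = self.default_num_inference_steps
        return PixelModelInputs(prompt=context.prompt, num_inference_steps=steps)

    def execute(self, model_inputs: PixelModelInputs, **kwargs: Any) -> Any:
        # A real pipeline would run its denoising loop here; this returns a summary.
        return {"prompt": model_inputs.prompt, "steps": model_inputs.num_inference_steps}


pipe = EchoPipeline()
inputs = pipe.prepare_inputs(PixelGenerationContext(prompt="a red fox"))
print(pipe.execute(inputs))
```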

run_denoising_step()

run_denoising_step(step, cache_state, device, **kwargs)


Execute one denoising step, applying the pipeline's caching logic.

Delegates the actual transformer call to self.run_transformer(), which subclasses override with model-specific arguments.

Parameters:

  • step (int) – Current step index.
  • cache_state (DenoisingCacheState) – Per-request mutable cache state for this stream.
  • device (Device) – Target device.
  • **kwargs (Any) – Model-specific arguments forwarded to run_transformer.

Returns:

noise_pred tensor for this step.

Return type:

Tensor
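The caching logic itself is not described on this page. As an illustration only, the sketch below shows a common first-block caching scheme: compute a cheap first-block residual, and if it has changed little (relative L1 below a threshold) since the last full run, reuse the cached prediction instead of running the remaining blocks. ToyCacheState, toy_denoising_step, the threshold of 0.05, and the toy math are all hypothetical and not MAX's implementation.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyCacheState:
    """Hypothetical per-request cache state; the real DenoisingCacheState differs."""
    ref_residual: list[float] | None = None  # first-block residual at the last full run
    cached_noise_pred: list[float] | None = None
    full_runs: int = 0
    hits: int = 0


def rel_l1(a: list[float], b: list[float]) -> float:
    # Relative L1 change of a with respect to the reference b.
    return sum(abs(x - y) for x, y in zip(a, b)) / max(sum(abs(y) for y in b), 1e-12)


def first_block(latents: list[float]) -> list[float]:
    return [0.5 * x for x in latents]  # cheap toy stand-in for the first block


def remaining_blocks(latents: list[float]) -> list[float]:
    return [0.1 * x for x in latents]  # "expensive" toy stand-in for the other blocks


def toy_denoising_step(latents: list[float], state: ToyCacheState, threshold: float = 0.05) -> list[float]:
    residual = first_block(latents)
    if (
        state.ref_residual is not None
        and state.cached_noise_pred is not None
        and rel_l1(residual, state.ref_residual) < threshold
    ):
        state.hits += 1
        return state.cached_noise_pred  # small change: skip the remaining blocks
    state.full_runs += 1
    state.ref_residual = residual
    state.cached_noise_pred = remaining_blocks(latents)
    return state.cached_noise_pred


state = ToyCacheState()
latents = [1.0, 2.0, 3.0]
for step in range(5):
    noise_pred = toy_denoising_step(latents, state)
    latents = [x - 0.2 * n for x, n in zip(latents, noise_pred)]  # toy update rule
print(state.full_runs, state.hits)
```

For simplicity this reuses the cached prediction outright; published first-block-cache variants usually reuse a cached residual of the remaining blocks instead. The reference residual is refreshed only on full runs, so small changes cannot accumulate unnoticed across consecutive cache hits.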

run_transformer()

run_transformer(cache_state, **kwargs)


Run the transformer for one denoising step.

Subclasses must override this to call their transformer with the appropriate model-specific arguments. The method should return (noise_pred,) when first_block_caching is disabled, or (new_residual, noise_pred) when first_block_caching is enabled.

Parameters:

  • cache_state (DenoisingCacheState) – Per-request mutable cache state for this stream.
  • **kwargs (Any) – Model-specific arguments forwarded from run_denoising_step.

Return type:

tuple[Tensor, ...]
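A minimal sketch of the documented return convention: one element (noise_pred,) without first-block caching, two elements (new_residual, noise_pred) with it. ToyPipeline, its first_block_caching attribute, the latents keyword argument, and the toy math are hypothetical; where the flag actually lives is not stated on this page.

```python
from __future__ import annotations

from typing import Any


class ToyPipeline:
    def __init__(self, first_block_caching: bool) -> None:
        # Hypothetical location for the flag; the real pipeline may store it elsewhere.
        self.first_block_caching = first_block_caching

    def run_transformer(self, cache_state: Any, **kwargs: Any) -> tuple[list[float], ...]:
        latents: list[float] = kwargs["latents"]  # hypothetical model-specific kwarg
        noise_pred = [0.1 * x for x in latents]  # toy stand-in for the transformer
        if self.first_block_caching:
            new_residual = [0.5 * x for x in latents]  # toy first-block residual
            return (new_residual, noise_pred)
        return (noise_pred,)


def split_outputs(outputs: tuple[list[float], ...]) -> tuple[list[float] | None, list[float]]:
    # noise_pred is the last element under both documented shapes.
    *head, noise_pred = outputs
    return (head[0] if head else None), noise_pred


for flag in (False, True):
    out = ToyPipeline(flag).run_transformer(cache_state=None, latents=[1.0, 2.0])
    residual, noise_pred = split_outputs(out)
    print(flag, len(out), residual, noise_pred)
```

Because noise_pred is always last, a caller can unpack it the same way regardless of whether caching is enabled.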

unprefixed_weight_component

unprefixed_weight_component: str | None = None


When set, weight files without a <component>/ prefix are assigned to this component. This supports multi-repo layouts where quantized weights for one component (e.g., the transformer) are shipped as flat files in a separate repo while the remaining components use the base model repo.
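A minimal sketch of the assignment rule described above: files under a known <component>/ directory go to that component, and flat files go to the unprefixed_weight_component when it is set. The assign_weights helper and the file names are hypothetical; treating unmatched files as unassigned when no fallback is set is an assumption of this sketch.

```python
from __future__ import annotations

from collections import defaultdict
from pathlib import PurePosixPath


def assign_weights(
    weight_paths: list[str],
    component_names: set[str],
    unprefixed_weight_component: str | None = None,
) -> dict[str, list[str]]:
    """Group weight files by their leading '<component>/' directory (sketch)."""
    assigned: dict[str, list[str]] = defaultdict(list)
    for path in weight_paths:
        parts = PurePosixPath(path).parts
        if len(parts) > 1 and parts[0] in component_names:
            assigned[parts[0]].append(path)
        elif unprefixed_weight_component is not None:
            # No component prefix: route the file to the fallback component.
            assigned[unprefixed_weight_component].append(path)
        # Otherwise the file stays unassigned in this sketch.
    return dict(assigned)


paths = [
    "vae/diffusion_pytorch_model.safetensors",
    "text_encoder/model.safetensors",
    "model-fp8.safetensors",  # hypothetical flat file from a separate quantized repo
]
print(assign_weights(paths, {"vae", "text_encoder", "transformer"}, "transformer"))
```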