Python module

max.pipelines.architectures.wan

Wan diffusion architecture for video generation.

WanArchConfig

class max.pipelines.architectures.wan.WanArchConfig(*, pipeline_config)

Bases: ArchConfig

Pipeline-level config for Wan (implements ArchConfig; no KV cache).

Parameters:

pipeline_config (PipelineConfig)

get_max_seq_len()

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses determine whether this value can be overridden via the --max-length flag (pipeline_config.model.max_length).

Return type:

int

initialize()

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self

pipeline_config

pipeline_config: PipelineConfig

WanContext

class max.pipelines.architectures.wan.WanContext(*, tokens, request_id=<factory>, model_name='', mask=None, tokens_2=None, negative_tokens=None, negative_mask=None, negative_tokens_2=None, explicit_negative_prompt=False, timesteps=<factory>, sigmas=<factory>, latents=<factory>, latent_image_ids=<factory>, text_ids=<factory>, negative_text_ids=<factory>, height=1024, width=1024, num_inference_steps=50, guidance_scale=3.5, true_cfg_scale=1.0, strength=0.6, cfg_normalization=False, cfg_truncation=1.0, num_warmup_steps=0, num_images_per_prompt=1, input_image=None, input_images=None, prompt_images=None, vae_condition_images=None, output_format='jpeg', residual_threshold=None, status=GenerationStatus.ACTIVE, num_frames=None, guidance_scale_2=None, step_coefficients=None, boundary_timestep=None)

Bases: PixelContext

Pixel generation context with Wan-specific video/MoE fields.

boundary_timestep

boundary_timestep: float | None = None

Timestep threshold for switching between high/low noise experts.

guidance_scale_2

guidance_scale_2: float | None = None

Secondary guidance scale for low-noise expert (MoE models).

num_frames

num_frames: int | None = None

Number of frames for video generation.

step_coefficients

step_coefficients: npt.NDArray[np.float32] | None = None

Pre-computed scheduler step coefficients.
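
Taken together, boundary_timestep, guidance_scale, and guidance_scale_2 drive per-step expert selection in MoE variants. A minimal sketch of that selection, assuming the common convention that large (high-noise) timesteps run the primary expert and the low-noise expert with its own guidance scale takes over once the timestep drops below the boundary (the function name is illustrative, not part of this API):

```python
def select_expert(t: float, boundary_timestep: float,
                  guidance_scale: float, guidance_scale_2: float):
    """Pick the expert and guidance scale for denoising timestep t.

    High-noise (large) timesteps go to the primary expert; once t
    falls below the boundary, the low-noise expert takes over with
    its secondary guidance scale.
    """
    if t >= boundary_timestep:
        return "high_noise", guidance_scale
    return "low_noise", guidance_scale_2
```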

WanExecutor

class max.pipelines.architectures.wan.WanExecutor(manifest, session, runtime_config)

Bases: PipelineExecutor[WanContext, WanExecutorInputs, WanExecutorOutputs]

Wan video diffusion pipeline executor.

Implements the PipelineExecutor interface for Wan video generation, wiring together the sub-components (text encoder, transformer, VAE) through the tensor-in/tensor-out executor contract.

default_num_inference_steps

default_num_inference_steps: int = 50

execute()

execute(inputs)

Runs the compiled model graph on the provided inputs.

The inputs should be the TensorStruct produced by prepare_inputs(), passed through without transformation. The caller may have transferred them to a different device between preparation and execution.

The returned struct may contain device-resident tensors. The caller is responsible for any host transfer needed for post-processing, using .to(device) on the returned TensorStruct.

Parameters:

inputs (WanExecutorInputs) – The prepared graph inputs, as returned by prepare_inputs().

Returns:

A TensorStruct containing the model outputs.

Return type:

WanExecutorOutputs

prepare_inputs()

prepare_inputs(contexts)

Converts a batch of contexts into a structured tensor container ready for graph execution.

Each context in the batch represents a single request or work item. The implementation is responsible for collating, tokenizing, or otherwise transforming the batch into the tensor format expected by the compiled graph.

The returned struct may contain tensors on any device. The caller is responsible for transferring them to the appropriate device before passing them to execute(), using .to(device) on the returned TensorStruct.

Parameters:

contexts (list[WanContext]) – A list of context objects representing the batch of requests to prepare inputs for.

Returns:

A TensorStruct containing the prepared graph inputs for the batch.

Return type:

WanExecutorInputs
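
The prepare → transfer → execute contract above can be sketched from the caller's side. The classes below are stand-ins for MAX types (only the .to(device) transfer on the returned struct is taken from the docs; everything else is illustrative, not the real API):

```python
class FakeStruct:
    """Stand-in for a TensorStruct: named tensors plus a .to() transfer."""

    def __init__(self, tensors, device="cpu"):
        self.tensors = tensors
        self.device = device

    def to(self, device):
        return FakeStruct(self.tensors, device)


class FakeExecutor:
    """Stand-in executor showing the tensor-in/tensor-out contract."""

    def prepare_inputs(self, contexts):
        # Collate the batch; the result may live on any device.
        return FakeStruct({"tokens": [c["tokens"] for c in contexts]})

    def execute(self, inputs):
        # Inputs pass through untransformed; outputs may stay device-resident.
        return FakeStruct({"num_outputs": len(inputs.tensors["tokens"])},
                          device=inputs.device)


executor = FakeExecutor()
inputs = executor.prepare_inputs([{"tokens": [1, 2]}, {"tokens": [3]}])
outputs = executor.execute(inputs.to("gpu"))  # caller handles the device hop
host_outputs = outputs.to("cpu")              # and any host transfer afterwards
```

The split keeps the executor itself device-agnostic: batching and collation happen once in prepare_inputs(), and the caller decides where the tensors live before and after execute().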

WanTokenizer

class max.pipelines.architectures.wan.WanTokenizer(model_path, pipeline_config, subfolder, *, subfolder_2=None, revision=None, max_length=None, secondary_max_length=None, trust_remote_code=False, default_num_inference_steps=50, **unused_kwargs)

Bases: PixelGenerationTokenizer

Wan-specific tokenizer that produces WanContext with video/MoE fields.

Parameters:

  • model_path (str)
  • pipeline_config (PipelineConfig)
  • subfolder (str)
  • subfolder_2 (str | None)
  • revision (str | None)
  • max_length (int | None)
  • secondary_max_length (int | None)
  • trust_remote_code (bool)
  • default_num_inference_steps (int)

new_context()

async new_context(request, input_image=None)

Creates a new WanContext, extracting the necessary information from the OpenResponsesRequest.

Return type:

WanContext

WanTransformerModel

class max.pipelines.architectures.wan.WanTransformerModel(config, encoding, devices, weights, session=None, eager_load=True)

Bases: ComponentModel

MAX-native Wan DiT interface with block-level compilation.

Each block is compiled independently so only one block’s workspace is live at any time, keeping peak VRAM low.

compute_rope()

compute_rope(num_frames, height, width)

Compute 3D RoPE cos/sin tensors and transfer to device.

Return type:

tuple[Buffer, Buffer]
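
The general shape of a 3D RoPE table can be sketched in numpy: split the rotary dim across the (frame, height, width) axes, build per-axis angle tables, and broadcast each over the full grid. The even three-way split and the base theta here are illustrative assumptions, not Wan's exact layout:

```python
import numpy as np

def rope_3d(num_frames: int, height: int, width: int,
            dim: int = 48, theta: float = 10000.0):
    """Illustrative 3D RoPE cos/sin tables over a (frame, h, w) grid."""
    assert dim % 6 == 0  # an even chunk per axis
    chunk = dim // 3
    axes = (num_frames, height, width)
    parts = []
    for axis, n in enumerate(axes):
        inv_freq = 1.0 / theta ** (np.arange(0, chunk, 2) / chunk)
        angles = np.arange(n)[:, None] * inv_freq[None, :]  # (n, chunk/2)
        shape = [1, 1, 1, chunk // 2]
        shape[axis] = n  # broadcast this axis's angles over the full grid
        parts.append(np.broadcast_to(angles.reshape(shape),
                                     (*axes, chunk // 2)))
    ang = np.concatenate(parts, axis=-1)
    ang = ang.reshape(num_frames * height * width, -1)
    return np.cos(ang), np.sin(ang)

cos, sin = rope_3d(4, 2, 3)  # 24 grid positions, dim // 2 = 24 angles each
```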

load_model()

load_model(*, seq_text_len, seq_len, batch_size=1)

Compile the transformer as separate pre/block/post graphs.

Block graphs are compiled with symbolic seq_len and concrete batch_size / seq_text_len. Pre/post graphs use symbolic spatial dims.

Parameters:

  • seq_text_len (int)
  • seq_len (int)
  • batch_size (int)

Return type:

Callable[[…], Any]
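
The memory pattern behind per-block compilation can be sketched generically: when each block is its own callable, its scratch workspace is allocated and released within one call, instead of every block's workspace coexisting in a single fused graph. A toy numpy illustration of that pattern (not MAX's compilation machinery):

```python
import numpy as np

def make_block(w: float):
    """Each 'block' is a separately built callable (here, a closure)."""
    def block(x):
        workspace = np.outer(x, x)      # scratch lives only for this call
        return x + w * workspace.sum()  # workspace is freed on return
    return block

# Blocks are built one at a time rather than fused into a single graph,
# so at most one block's workspace is allocated at any moment.
blocks = [make_block(w) for w in (0.1, 0.2, 0.3)]

x = np.ones(4, dtype=np.float32)
for block in blocks:
    x = block(x)
```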

prepare_state_dict()

prepare_state_dict()

Materialize the remapped state dict without compiling graphs.

Return type:

dict[str, Any]

reload_model_weights()

reload_model_weights(state_dict=None)

Reload weights into already-compiled models for MoE weight switching.

Parameters:

state_dict (dict[str, Any] | None)

Return type:

None
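
Reloading weights lets one compiled graph serve both experts: rather than compiling the transformer twice, the runtime can swap state dicts when the denoising timestep crosses the boundary. A minimal sketch of that pattern, with plain dicts standing in for state dicts and a stub class standing in for the compiled model (the switch-at-boundary scheme mirrors the boundary_timestep field on WanContext; the details are assumptions):

```python
class FakeTransformer:
    """Stand-in for a compiled model whose weights can be hot-swapped."""

    def __init__(self, state_dict):
        self.state_dict = state_dict

    def reload_model_weights(self, state_dict):
        self.state_dict = state_dict  # graphs stay compiled; only weights move


high_noise = {"blocks.0.weight": "high"}  # hypothetical state dicts
low_noise = {"blocks.0.weight": "low"}

model = FakeTransformer(high_noise)
boundary_timestep = 875.0
for t in (1000.0, 900.0, 800.0):  # denoising timesteps, decreasing
    if t < boundary_timestep and model.state_dict is high_noise:
        model.reload_model_weights(low_noise)  # single switch per generation
    # ... run the denoising step with the active expert ...
```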