Python module
max.pipelines.architectures.wan
Wan diffusion architecture for video generation.
WanArchConfig
class max.pipelines.architectures.wan.WanArchConfig(*, pipeline_config)
Bases: ArchConfig
Pipeline-level config for Wan (implements ArchConfig; no KV cache).
-
Parameters:
-
pipeline_config (PipelineConfig)
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
-
Return type:
-
int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
-
Parameters:
-
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When
None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
-
Return type:
pipeline_config
pipeline_config: PipelineConfig
WanContext
class max.pipelines.architectures.wan.WanContext(*, tokens, request_id=<factory>, model_name='', mask=None, tokens_2=None, negative_tokens=None, negative_mask=None, negative_tokens_2=None, explicit_negative_prompt=False, timesteps=<factory>, sigmas=<factory>, latents=<factory>, latent_image_ids=<factory>, text_ids=<factory>, negative_text_ids=<factory>, height=1024, width=1024, num_inference_steps=50, guidance_scale=3.5, true_cfg_scale=1.0, strength=0.6, cfg_normalization=False, cfg_truncation=1.0, num_warmup_steps=0, num_images_per_prompt=1, input_image=None, input_images=None, prompt_images=None, vae_condition_images=None, output_format='jpeg', residual_threshold=None, status=GenerationStatus.ACTIVE, num_frames=None, guidance_scale_2=None, step_coefficients=None, boundary_timestep=None)
Bases: PixelContext
Pixel generation context with Wan-specific video/MoE fields.
-
Parameters:
-
- tokens (TokenBuffer)
- request_id (RequestID)
- model_name (str)
- mask (ndarray[tuple[Any, ...], dtype[bool]] | None)
- tokens_2 (TokenBuffer | None)
- negative_tokens (TokenBuffer | None)
- negative_mask (ndarray[tuple[Any, ...], dtype[bool]] | None)
- negative_tokens_2 (TokenBuffer | None)
- explicit_negative_prompt (bool)
- timesteps (ndarray[tuple[Any, ...], dtype[float32]])
- sigmas (ndarray[tuple[Any, ...], dtype[float32]])
- latents (ndarray[tuple[Any, ...], dtype[float32]])
- latent_image_ids (ndarray[tuple[Any, ...], dtype[float32]])
- text_ids (ndarray[tuple[Any, ...], dtype[int64]])
- negative_text_ids (ndarray[tuple[Any, ...], dtype[int64]])
- height (int)
- width (int)
- num_inference_steps (int)
- guidance_scale (float)
- true_cfg_scale (float)
- strength (float)
- cfg_normalization (bool)
- cfg_truncation (float)
- num_warmup_steps (int)
- num_images_per_prompt (int)
- input_image (ndarray[tuple[Any, ...], dtype[uint8]] | None)
- input_images (list[ndarray[tuple[Any, ...], dtype[uint8]]] | None)
- prompt_images (list[ndarray[tuple[Any, ...], dtype[uint8]]] | None)
- vae_condition_images (list[ndarray[tuple[Any, ...], dtype[uint8]]] | None)
- output_format (str)
- residual_threshold (float | None)
- status (GenerationStatus)
- num_frames (int | None)
- guidance_scale_2 (float | None)
- step_coefficients (ndarray[tuple[Any, ...], dtype[float32]] | None)
- boundary_timestep (float | None)
boundary_timestep
boundary_timestep: float | None = None
Timestep threshold for switching between high/low noise experts.
guidance_scale_2
guidance_scale_2: float | None = None
Secondary guidance scale for the low-noise expert (MoE models).
num_frames
num_frames: int | None = None
Number of frames for video generation.
step_coefficients
step_coefficients: npt.NDArray[np.float32] | None = None
Pre-computed scheduler step coefficients.
WanExecutor
class max.pipelines.architectures.wan.WanExecutor(manifest, session, runtime_config)
Bases: PipelineExecutor[WanContext, WanExecutorInputs, WanExecutorOutputs]
Wan video diffusion pipeline executor.
Implements the PipelineExecutor interface for Wan video
generation, wiring together the sub-components (text encoder,
transformer, VAE) through the tensor-in/tensor-out executor contract.
-
Parameters:
-
- manifest (ModelManifest)
- session (InferenceSession)
- runtime_config (PipelineRuntimeConfig)
default_num_inference_steps
default_num_inference_steps: int = 50
execute()
execute(inputs)
Runs the compiled model graph on the provided inputs.
The inputs should be the TensorStruct produced by
prepare_inputs(), passed through without transformation.
The caller may have transferred them to a different device between
preparation and execution.
The returned struct may contain device-resident tensors. The
caller is responsible for any host transfer needed for
post-processing, using .to(device) on the returned
TensorStruct.
-
Parameters:
-
inputs (WanExecutorInputs) – The prepared graph inputs, as returned by
prepare_inputs().
-
Returns:
-
A TensorStruct containing the model outputs.
-
Return type:
-
WanExecutorOutputs
prepare_inputs()
prepare_inputs(contexts)
Converts a batch of contexts into a structured tensor container ready for graph execution.
Each context in the batch represents a single request or work item. The implementation is responsible for collating, tokenizing, or otherwise transforming the batch into the tensor format expected by the compiled graph.
The returned struct may contain tensors on any device. The caller
is responsible for transferring them to the appropriate device
before passing them to execute(), using .to(device)
on the returned TensorStruct.
-
Parameters:
-
contexts (list[WanContext]) – A list of context objects representing the batch of requests to prepare inputs for.
-
Returns:
-
A TensorStruct containing the prepared graph inputs for the batch.
-
Return type:
-
WanExecutorInputs
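The prepare → transfer → execute → transfer-back contract described for prepare_inputs() and execute() can be sketched with plain-Python stand-ins. FakeStruct and FakeExecutor below are illustrative substitutes for TensorStruct and the real executor, not MAX types:

```python
# Sketch of the tensor-in/tensor-out executor contract: the executor
# prepares and runs batches, while the *caller* owns all device
# transfers via .to(device) on the returned structs.

from dataclasses import dataclass


@dataclass
class FakeStruct:
    data: dict
    device: str = "cpu"

    def to(self, device: str) -> "FakeStruct":
        # Caller-driven device transfer, as the contract requires.
        return FakeStruct(self.data, device)


class FakeExecutor:
    def prepare_inputs(self, contexts):
        # Collate the batch into one struct; it may live on any device.
        return FakeStruct({"batch": contexts})

    def execute(self, inputs):
        # Inputs are passed through untransformed; outputs may remain
        # device-resident until the caller moves them.
        return FakeStruct({"latents": len(inputs.data["batch"])}, inputs.device)


executor = FakeExecutor()
inputs = executor.prepare_inputs(["ctx0", "ctx1"]).to("gpu:0")  # caller moves to device
outputs = executor.execute(inputs).to("cpu")  # caller brings results back to host
```

Keeping transfers on the caller side lets a scheduler overlap host/device copies with other work instead of hiding them inside the executor.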
WanTokenizer
class max.pipelines.architectures.wan.WanTokenizer(model_path, pipeline_config, subfolder, *, subfolder_2=None, revision=None, max_length=None, secondary_max_length=None, trust_remote_code=False, default_num_inference_steps=50, **unused_kwargs)
Bases: PixelGenerationTokenizer
Wan-specific tokenizer that produces WanContext with video/MoE fields.
new_context()
async new_context(request, input_image=None)
Creates a new PixelContext object from the information in an OpenResponsesRequest.
-
Parameters:
-
- request (OpenResponsesRequest)
- input_image (Image | None)
-
Return type:
WanContext
WanTransformerModel
class max.pipelines.architectures.wan.WanTransformerModel(config, encoding, devices, weights, session=None, eager_load=True)
Bases: ComponentModel
MAX-native Wan DiT interface with block-level compilation.
Each block is compiled independently so only one block’s workspace is live at any time, keeping peak VRAM low.
-
Parameters:
-
- config (dict[str, Any])
- encoding (SupportedEncoding)
- devices (list[Device])
- weights (Weights)
- session (InferenceSession | None)
- eager_load (bool)
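The block-level compilation described above keeps peak VRAM low because blocks execute sequentially over the same hidden state. A toy sketch of that execution pattern; the lambda "blocks" stand in for independently compiled MAX graphs:

```python
# Illustrative only: sequential block execution. Each block consumes the
# previous hidden state, so at most one block's intermediate workspace
# needs to be resident at any time.

def run_blocks(hidden, blocks):
    for block in blocks:
        # The prior block's scratch memory can be released here before
        # the next block allocates its own workspace.
        hidden = block(hidden)
    return hidden


result = run_blocks(3, [lambda x: x * 2, lambda x: x + 1, lambda x: x * 2])
```

The same idea applies whether the blocks are Python callables or compiled graphs: only the running hidden state persists across block boundaries.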
compute_rope()
compute_rope(num_frames, height, width)
Compute 3D RoPE cos/sin tensors and transfer to device.
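A hedged numpy sketch of 3D rotary position frequencies over a (frames, height, width) grid. Per-axis frequency bands concatenated along the feature dimension is one common layout; the actual compute_rope() shapes, dtype, and device placement may differ:

```python
# Minimal 3D RoPE sketch: build per-axis rotation angles over the
# (frames, height, width) grid, concatenate along the feature axis,
# and return cos/sin tables. dim_per_axis and theta are assumptions.

import numpy as np


def rope_3d(num_frames, height, width, dim_per_axis=8, theta=10000.0):
    # Standard RoPE inverse-frequency bands for one axis.
    freqs = 1.0 / theta ** (np.arange(0, dim_per_axis, 2) / dim_per_axis)
    axes = [np.arange(num_frames), np.arange(height), np.arange(width)]
    grids = np.meshgrid(*axes, indexing="ij")  # each grid is (F, H, W)
    parts = []
    for grid in grids:
        # Outer product of positions along this axis with its freq bands.
        parts.append(grid[..., None] * freqs)  # (F, H, W, dim_per_axis // 2)
    angles = np.concatenate(parts, axis=-1).reshape(num_frames * height * width, -1)
    return np.cos(angles), np.sin(angles)


cos, sin = rope_3d(2, 4, 4)  # (32, 12) each: 2*4*4 positions, 3 axes * 4 bands
```

Position (0, 0, 0) has all-zero angles, so its cos row is all ones and its sin row all zeros, as expected for the origin of a rotary embedding.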
load_model()
load_model(*, seq_text_len, seq_len, batch_size=1)
Compile the transformer as separate pre/block/post graphs.
Block graphs are compiled with symbolic seq_len and concrete
batch_size / seq_text_len. Pre/post graphs use symbolic
spatial dims.
prepare_state_dict()
prepare_state_dict()
Materialize the remapped state dict without compiling graphs.
reload_model_weights()
reload_model_weights(state_dict=None)
Reload weights into already-compiled models for MoE weight switching.
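The point of reloading into already-compiled models is that switching MoE experts only swaps tensors, never triggers recompilation. A toy stand-in (CompiledStub is illustrative, not a MAX class) showing that shape of API:

```python
# Illustrative: one compiled model whose weights are swapped in place,
# as reload_model_weights() enables for MoE expert switching.

class CompiledStub:
    def __init__(self, weights):
        self.weights = dict(weights)  # weights bound at load time

    def reload_model_weights(self, state_dict=None):
        # Overwrite weights in place; the compiled graph is untouched.
        if state_dict is not None:
            self.weights.update(state_dict)


model = CompiledStub({"w": "high_noise"})
model.reload_model_weights({"w": "low_noise"})  # swap experts, no recompile
```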