Python module

max.pipelines.architectures.wan

Wan diffusion architecture for video generation.

BlockLevelModel​

class max.pipelines.architectures.wan.BlockLevelModel(pre, blocks, post, *, combined_blocks=None)

Bases: object

Executes the transformer forward pass as pre -> N blocks -> post.

Supports two modes:

  • Combined (combined_blocks is set): All transformer blocks are compiled into a single Model graph, so the runtime allocates one shared workspace.
  • Per-block (blocks list): Each block is a separate Model. Kept for backwards compatibility and MoE weight swapping.

Parameters:

  • pre – Model run before the transformer blocks.
  • blocks – List of per-block Models (per-block mode).
  • post – Model run after the transformer blocks.
  • combined_blocks – Optional Model containing all blocks compiled together (combined mode); when None, the per-block blocks list is used.
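
A minimal sketch of how the two modes dispatch (pre, blocks, post, and combined_blocks are the documented attributes; everything else is illustrative, not the actual implementation):

def run_block_level(model, *inputs):
    hidden = model.pre(*inputs)
    if model.combined_blocks is not None:
        # Combined mode: one compiled graph covers every transformer block,
        # so the runtime allocates a single shared workspace.
        hidden = model.combined_blocks(hidden)
    else:
        # Per-block mode: each block is a separate Model; kept for
        # backwards compatibility and MoE weight swapping.
        for block in model.blocks:
            hidden = block(hidden)
    return model.post(hidden)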

WanArchConfig​

class max.pipelines.architectures.wan.WanArchConfig(*, pipeline_config)

Bases: ArchConfig

Pipeline-level config for Wan (implements ArchConfig; no KV cache).

Parameters:

pipeline_config (PipelineConfig)

get_max_seq_len()​

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden via the --max-length flag (pipeline_config.model.max_length).

Return type:

int

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self
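
A hedged usage sketch (pipeline_config is assumed to be an existing PipelineConfig):

arch_config = WanArchConfig.initialize(pipeline_config)
max_len = arch_config.get_max_seq_len()

# Initialize the arch config for a different model, e.g. a draft model:
draft_arch_config = WanArchConfig.initialize(
    pipeline_config, model_config=pipeline_config.draft_model
)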

pipeline_config​

pipeline_config: PipelineConfig

WanConfig​

class max.pipelines.architectures.wan.WanConfig(*, config_file=None, section_name=None, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, num_layers=40, cross_attn_norm=True, qk_norm='rms_norm_across_heads', eps=1e-06, image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, dtype=bfloat16, device=<factory>)

Bases: WanConfigBase

Parameters:

  • config_file (str | None)
  • section_name (str | None)
  • patch_size (tuple[int, int, int])
  • num_attention_heads (int)
  • attention_head_dim (int)
  • in_channels (int)
  • out_channels (int)
  • text_dim (int)
  • freq_dim (int)
  • ffn_dim (int)
  • num_layers (int)
  • cross_attn_norm (bool)
  • qk_norm (str | None)
  • eps (float)
  • image_dim (int | None)
  • added_kv_proj_dim (int | None)
  • rope_max_seq_len (int)
  • pos_embed_seq_len (int | None)
  • dtype (DType)
  • device (DeviceRef)
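
Every field has a default, so a config can be constructed directly and selectively overridden. A hedged sketch (the values shown are just the documented defaults):

from max.pipelines.architectures.wan import WanConfig

config = WanConfig(num_layers=40, ffn_dim=13824)
# Hidden size follows from the attention geometry: 40 heads x 128 dims = 5120.
assert config.num_attention_heads * config.attention_head_dim == 5120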

generate()​

static generate(config_dict, encoding, devices)

Parameters:

  • config_dict (dict[str, Any])
  • encoding (Literal['float32', 'bfloat16', 'q4_k', 'q4_0', 'q6_k', 'float8_e4m3fn', 'float4_e2m1fnx2', 'gptq'])
  • devices (list[Device])

Return type:

WanConfig
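
A hedged usage sketch (config_dict and devices are assumed to come from the surrounding pipeline setup, e.g. a parsed checkpoint config and the target device list):

config = WanConfig.generate(
    config_dict=config_dict,  # dict[str, Any], e.g. parsed from config.json
    encoding="bfloat16",
    devices=devices,          # list[Device]
)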

model_config​

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}

Configuration for the model; should be a dictionary conforming to Pydantic's ConfigDict.

WanConfigBase​

class max.pipelines.architectures.wan.WanConfigBase(*, config_file=None, section_name=None, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, num_layers=40, cross_attn_norm=True, qk_norm='rms_norm_across_heads', eps=1e-06, image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, dtype=bfloat16, device=<factory>)

Bases: MAXModelConfigBase

Parameters:

  • config_file (str | None)
  • section_name (str | None)
  • patch_size (tuple[int, int, int])
  • num_attention_heads (int)
  • attention_head_dim (int)
  • in_channels (int)
  • out_channels (int)
  • text_dim (int)
  • freq_dim (int)
  • ffn_dim (int)
  • num_layers (int)
  • cross_attn_norm (bool)
  • qk_norm (str | None)
  • eps (float)
  • image_dim (int | None)
  • added_kv_proj_dim (int | None)
  • rope_max_seq_len (int)
  • pos_embed_seq_len (int | None)
  • dtype (DType)
  • device (DeviceRef)

added_kv_proj_dim​

added_kv_proj_dim: int | None

attention_head_dim​

attention_head_dim: int

cross_attn_norm​

cross_attn_norm: bool

device​

device: DeviceRef

dtype​

dtype: DType

eps​

eps: float

ffn_dim​

ffn_dim: int

freq_dim​

freq_dim: int

image_dim​

image_dim: int | None

in_channels​

in_channels: int

model_config​

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}

Configuration for the model; should be a dictionary conforming to Pydantic's ConfigDict.

num_attention_heads​

num_attention_heads: int

num_layers​

num_layers: int

out_channels​

out_channels: int

patch_size​

patch_size: tuple[int, int, int]

pos_embed_seq_len​

pos_embed_seq_len: int | None

qk_norm​

qk_norm: str | None

rope_max_seq_len​

rope_max_seq_len: int

text_dim​

text_dim: int

WanTransformerModel​

class max.pipelines.architectures.wan.WanTransformerModel(config, encoding, devices, weights, session=None, eager_load=True)

Bases: ComponentModel

MAX-native Wan DiT interface with block-level compilation.

Each block is compiled independently so only one block’s workspace is live at any time, keeping peak VRAM low.

Parameters:

  • config – Wan transformer configuration.
  • encoding – Weight encoding to compile for.
  • devices – Devices the model runs on.
  • weights – Model weights.
  • session – Optional inference session; defaults to None.
  • eager_load – Whether to load the model eagerly; defaults to True.
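
A hedged construction sketch (argument values are illustrative; config, devices, and weights come from earlier pipeline setup):

model = WanTransformerModel(
    config=config,
    encoding="bfloat16",
    devices=devices,
    weights=weights,
    eager_load=True,
)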

compute_rope()​

compute_rope(num_frames, height, width)

Compute 3D RoPE cos/sin tensors and transfer them to the device.

Parameters:

  • num_frames (int)
  • height (int)
  • width (int)

Return type:

tuple[Buffer, Buffer]
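
A hedged usage sketch (the latent grid sizes are illustrative):

rope_cos, rope_sin = model.compute_rope(num_frames=21, height=60, width=104)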

load_model()​

load_model(*, seq_text_len, seq_len, batch_size=1)

Compile the transformer as separate pre/block/post graphs.

Block graphs are compiled with symbolic seq_len and concrete batch_size / seq_text_len. Pre/post graphs use symbolic spatial dims.

Parameters:

  • seq_text_len (int)
  • seq_len (int)
  • batch_size (int)

Return type:

Callable[..., Any]
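
A hedged usage sketch (the sequence lengths and batch size are illustrative; per the note above, seq_len stays symbolic in the compiled block graphs):

forward = model.load_model(seq_text_len=512, seq_len=32760, batch_size=1)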

prepare_state_dict()​

prepare_state_dict()

Materialize the remapped state dict without compiling graphs.

Return type:

dict[str, Any]

reload_model_weights()​

reload_model_weights(state_dict=None)

Reload weights into already-compiled models for MoE weight switching.

Parameters:

state_dict (dict[str, Any] | None)

Return type:

None
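
One plausible pairing of the two methods for a MoE weight swap (whether reload_model_weights consumes prepare_state_dict's output directly is an assumption):

state_dict = model.prepare_state_dict()  # remapped weights, no compilation
model.reload_model_weights(state_dict)   # hot-swap into compiled models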