Python module
max.pipelines.architectures.wan
Wan diffusion architecture for video generation.
BlockLevelModel
class max.pipelines.architectures.wan.BlockLevelModel(pre, blocks, post, *, combined_blocks=None)
Bases: object
Executes transformer forward pass as pre -> N blocks -> post.
Supports two modes:
- Combined (combined_blocks is set): All transformer blocks are compiled into a single Model graph, so the runtime allocates one shared workspace.
- Per-block (blocks list): Each block is a separate Model. Kept for backwards compatibility and MoE weight-swap.
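The two modes above can be illustrated with a plain-Python sketch. This is a conceptual illustration only, not the BlockLevelModel implementation: ordinary callables stand in for compiled Model graphs.

```python
# Conceptual sketch of the pre -> N blocks -> post execution order.
# Plain callables stand in for compiled Model graphs.
from typing import Callable, Optional, Sequence


def run_block_level(
    pre: Callable[[float], float],
    blocks: Sequence[Callable[[float], float]],
    post: Callable[[float], float],
    x: float,
    combined_blocks: Optional[Callable[[float], float]] = None,
) -> float:
    h = pre(x)
    if combined_blocks is not None:
        # Combined mode: one graph covers all blocks, so the runtime
        # allocates a single shared workspace.
        h = combined_blocks(h)
    else:
        # Per-block mode: each block is its own graph (kept for
        # backwards compatibility and MoE weight-swap).
        for block in blocks:
            h = block(h)
    return post(h)


print(run_block_level(lambda x: x + 1.0, [lambda h: h * 2.0] * 3, lambda h: h - 1.0, 1.0))
```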
WanArchConfig
class max.pipelines.architectures.wan.WanArchConfig(*, pipeline_config)
Bases: ArchConfig
Pipeline-level config for Wan (implements ArchConfig; no KV cache).
Parameters:
- pipeline_config (PipelineConfig)
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
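A minimal usage sketch, assuming a PipelineConfig for a Wan checkpoint has already been constructed elsewhere (the pipeline_config variable below is illustrative):

```python
# Sketch only: assumes `pipeline_config` is an existing PipelineConfig
# for a Wan checkpoint; constructing one is out of scope here.
from max.pipelines.architectures.wan import WanArchConfig

arch_config = WanArchConfig.initialize(pipeline_config)

# Wan has no KV cache; the arch config still reports a default max
# sequence length, which --max-length may override.
print(arch_config.get_max_seq_len())
```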
pipeline_config
pipeline_config: PipelineConfig
WanConfig
class max.pipelines.architectures.wan.WanConfig(*, config_file=None, section_name=None, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, num_layers=40, cross_attn_norm=True, qk_norm='rms_norm_across_heads', eps=1e-06, image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, dtype=bfloat16, device=<factory>)
Bases: WanConfigBase
Parameters:
- config_file (str | None)
- section_name (str | None)
- patch_size (tuple[int, int, int])
- num_attention_heads (int)
- attention_head_dim (int)
- in_channels (int)
- out_channels (int)
- text_dim (int)
- freq_dim (int)
- ffn_dim (int)
- num_layers (int)
- cross_attn_norm (bool)
- qk_norm (str | None)
- eps (float)
- image_dim (int | None)
- added_kv_proj_dim (int | None)
- rope_max_seq_len (int)
- pos_embed_seq_len (int | None)
- dtype (DType)
- device (DeviceRef)
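The defaults above describe a 40-layer, 40-head transformer with 128-dim heads (inner dimension 40 × 128 = 5120). A minimal construction sketch using only the field names from the signature above; whether the defaults suit a given checkpoint depends on that checkpoint's config:

```python
# Sketch only: constructs a WanConfig with its documented defaults and one
# field passed explicitly. Assumes the default device factory is valid in
# your environment.
from max.pipelines.architectures.wan import WanConfig

config = WanConfig(num_layers=40)
inner_dim = config.num_attention_heads * config.attention_head_dim
print(inner_dim)  # 40 * 128 = 5120 with the documented defaults
```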
generate()
static generate(config_dict, encoding, devices)
model_config
model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
WanConfigBase
class max.pipelines.architectures.wan.WanConfigBase(*, config_file=None, section_name=None, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, num_layers=40, cross_attn_norm=True, qk_norm='rms_norm_across_heads', eps=1e-06, image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, dtype=bfloat16, device=<factory>)
Bases: MAXModelConfigBase
Parameters:
- config_file (str | None)
- section_name (str | None)
- patch_size (tuple[int, int, int])
- num_attention_heads (int)
- attention_head_dim (int)
- in_channels (int)
- out_channels (int)
- text_dim (int)
- freq_dim (int)
- ffn_dim (int)
- num_layers (int)
- cross_attn_norm (bool)
- qk_norm (str | None)
- eps (float)
- image_dim (int | None)
- added_kv_proj_dim (int | None)
- rope_max_seq_len (int)
- pos_embed_seq_len (int | None)
- dtype (DType)
- device (DeviceRef)
added_kv_proj_dim
attention_head_dim
attention_head_dim: int
cross_attn_norm
cross_attn_norm: bool
device
device: DeviceRef
dtype
dtype: DType
eps
eps: float
ffn_dim
ffn_dim: int
freq_dim
freq_dim: int
image_dim
in_channels
in_channels: int
model_config
model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
num_attention_heads
num_attention_heads: int
num_layers
num_layers: int
out_channels
out_channels: int
patch_size
pos_embed_seq_len
qk_norm
rope_max_seq_len
rope_max_seq_len: int
text_dim
text_dim: int
WanTransformerModel
class max.pipelines.architectures.wan.WanTransformerModel(config, encoding, devices, weights, session=None, eager_load=True)
Bases: ComponentModel
MAX-native Wan DiT interface with block-level compilation.
Each block is compiled independently so only one block's workspace is live at any time, keeping peak VRAM low.
Parameters:
- config (dict[str, Any])
- encoding (SupportedEncoding)
- devices (list[Device])
- weights (Weights)
- session (InferenceSession | None)
- eager_load (bool)
compute_rope()
compute_rope(num_frames, height, width)
Compute 3D RoPE cos/sin tensors and transfer to device.
load_model()
load_model(*, seq_text_len, seq_len, batch_size=1)
Compile the transformer as separate pre/block/post graphs.
Block graphs are compiled with symbolic seq_len and concrete
batch_size / seq_text_len. Pre/post graphs use symbolic
spatial dims.
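A typical load flow, sketched from the signatures above. The dimensions and the source of config, encoding, devices, and weights are illustrative assumptions, not prescribed values:

```python
# Sketch only: `config`, `encoding`, `devices`, and `weights` are assumed
# to come from the surrounding pipeline; the shapes below are illustrative.
from max.pipelines.architectures.wan import WanTransformerModel

model = WanTransformerModel(
    config,    # dict[str, Any] with the transformer hyperparameters
    encoding,  # SupportedEncoding
    devices,   # list[Device]
    weights,   # Weights for the Wan checkpoint
    eager_load=False,  # assumed to defer compilation until load_model()
)

# Compile pre/block/post graphs: seq_len stays symbolic, while
# batch_size and seq_text_len are baked in as concrete values.
model.load_model(seq_text_len=512, seq_len=32760, batch_size=1)

# 3D RoPE cos/sin tensors for the requested video shape, on device.
rope = model.compute_rope(num_frames=81, height=480, width=832)
```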
prepare_state_dict()
prepare_state_dict()
Materialize the remapped state dict without compiling graphs.
reload_model_weights()
reload_model_weights(state_dict=None)
Reload weights into already-compiled models for MoE weight switching.
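A sketch of the weight-switch step, assuming new_state_dict is a hypothetical state dict for the expert being switched in (for example, materialized via prepare_state_dict() against that expert's weights) and model is an already-loaded WanTransformerModel:

```python
# Sketch only: rebinds already-compiled graphs to different weights,
# avoiding recompilation when switching MoE experts.
model.reload_model_weights(state_dict=new_state_dict)
```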