Python module
max.pipelines.architectures.wan
Wan diffusion architecture for video generation.
BlockLevelModel
class max.pipelines.architectures.wan.BlockLevelModel(pre, post, *, combined_blocks)
Bases: object
Executes the transformer forward pass as pre -> combined blocks -> post.
All transformer blocks are compiled into a single Model graph,
so the runtime allocates one shared workspace.
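The pre -> combined blocks -> post ordering can be illustrated with a minimal, framework-free sketch. This is not the MAX implementation: the `Stage` alias and the `combine` and `run_block_level` helpers are hypothetical names, and plain Python callables stand in for compiled graphs.

```python
from typing import Callable, Sequence

# Hypothetical stand-in: a stage is any callable from hidden state to hidden state.
Stage = Callable[[list[float]], list[float]]


def combine(blocks: Sequence[Stage]) -> Stage:
    """Fuse several blocks into one callable, loosely mirroring 'one graph for all blocks'."""

    def fused(hidden: list[float]) -> list[float]:
        for block in blocks:
            hidden = block(hidden)
        return hidden

    return fused


def run_block_level(pre: Stage, combined_blocks: Stage, post: Stage, x: list[float]) -> list[float]:
    """Run the three stages in the documented order: pre, combined blocks, post."""
    hidden = pre(x)
    hidden = combined_blocks(hidden)
    return post(hidden)


# Toy stages, purely illustrative.
pre = lambda x: [v * 2.0 for v in x]                      # e.g. an embedding step
blocks = combine([lambda h: [v + 1.0 for v in h]] * 3)    # three identical toy blocks
post = lambda h: [sum(h)]                                 # e.g. an output projection

result = run_block_level(pre, blocks, post, [1.0, 2.0])
```

Fusing the blocks behind a single callable is what lets a caller treat them as one unit, which is the property the class description attributes to the single compiled graph.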
WanArchConfig
class max.pipelines.architectures.wan.WanArchConfig(*, pipeline_config)
Bases: ArchConfig
Pipeline-level config for Wan (implements ArchConfig; no KV cache).
Parameters:
- pipeline_config (PipelineConfig)
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig): The pipeline configuration.
- model_config (MAXModelConfig | None): The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
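The fallback described for model_config can be sketched as follows. The `FakeModelConfig` and `FakePipelineConfig` classes and the `resolve_model_config` helper are hypothetical stand-ins written for this illustration; they are not part of MAX and only mimic the `model` and `draft_model` attributes named above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeModelConfig:
    """Hypothetical stand-in for MAXModelConfig."""
    name: str


@dataclass
class FakePipelineConfig:
    """Hypothetical stand-in for PipelineConfig with the two attributes used here."""
    model: FakeModelConfig
    draft_model: Optional[FakeModelConfig] = None


def resolve_model_config(
    pipeline_config: FakePipelineConfig,
    model_config: Optional[FakeModelConfig] = None,
) -> FakeModelConfig:
    """Use the explicit config when given, otherwise fall back to pipeline_config.model."""
    return pipeline_config.model if model_config is None else model_config


cfg = FakePipelineConfig(
    model=FakeModelConfig("main"),
    draft_model=FakeModelConfig("draft"),
)

default_choice = resolve_model_config(cfg)                    # falls back to cfg.model
explicit_choice = resolve_model_config(cfg, cfg.draft_model)  # uses the draft model
```

Comparing against `None` explicitly, rather than relying on truthiness, keeps the fallback from triggering on a config object that happens to evaluate as false.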
pipeline_config
pipeline_config: PipelineConfig
WanConfig
class max.pipelines.architectures.wan.WanConfig(*, config_file=None, section_name=None, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, num_layers=40, cross_attn_norm=True, qk_norm='rms_norm_across_heads', eps=1e-06, image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, dtype=bfloat16, device=<factory>)
Bases: WanConfigBase
Parameters:
- config_file (str | None)
- section_name (str | None)
- patch_size (tuple[int, int, int])
- num_attention_heads (int)
- attention_head_dim (int)
- in_channels (int)
- out_channels (int)
- text_dim (int)
- freq_dim (int)
- ffn_dim (int)
- num_layers (int)
- cross_attn_norm (bool)
- qk_norm (str | None)
- eps (float)
- image_dim (int | None)
- added_kv_proj_dim (int | None)
- rope_max_seq_len (int)
- pos_embed_seq_len (int | None)
- dtype (DType)
- device (DeviceRef)
generate()
static generate(config_dict, encoding, devices)
model_config
model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
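To give a feel for the default values in the WanConfig signature, the sketch below derives two quantities from them. Both derivations are assumptions, not documented behavior: that the transformer width equals num_attention_heads * attention_head_dim, and that patch_size splits a (frames, height, width) latent into non-overlapping patches. The `num_patch_tokens` helper and the example latent shape are hypothetical.

```python
from math import prod

# Defaults copied from the WanConfig signature.
patch_size = (1, 2, 2)
num_attention_heads = 40
attention_head_dim = 128

# Assumption: hidden width is heads * head dim, as in most multi-head attention layouts.
hidden_dim = num_attention_heads * attention_head_dim


def num_patch_tokens(latent_shape: tuple[int, int, int], patch: tuple[int, int, int]) -> int:
    """Count tokens when a (frames, height, width) latent is cut into non-overlapping patches."""
    if any(size % p for size, p in zip(latent_shape, patch)):
        raise ValueError(f"latent shape {latent_shape} is not divisible by patch {patch}")
    return prod(size // p for size, p in zip(latent_shape, patch))


# Hypothetical latent shape chosen only for illustration.
tokens = num_patch_tokens((21, 60, 104), patch_size)

print(hidden_dim, tokens)
```

Under these assumptions, the default patch size leaves the frame axis untouched and halves each spatial axis, so the token count is a quarter of the latent volume.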
WanConfigBase
class max.pipelines.architectures.wan.WanConfigBase(*, config_file=None, section_name=None, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, num_layers=40, cross_attn_norm=True, qk_norm='rms_norm_across_heads', eps=1e-06, image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, dtype=bfloat16, device=<factory>)
Bases: MAXModelConfigBase
Parameters:
- config_file (str | None)
- section_name (str | None)
- patch_size (tuple[int, int, int])
- num_attention_heads (int)
- attention_head_dim (int)
- in_channels (int)
- out_channels (int)
- text_dim (int)
- freq_dim (int)
- ffn_dim (int)
- num_layers (int)
- cross_attn_norm (bool)
- qk_norm (str | None)
- eps (float)
- image_dim (int | None)
- added_kv_proj_dim (int | None)
- rope_max_seq_len (int)
- pos_embed_seq_len (int | None)
- dtype (DType)
- device (DeviceRef)
added_kv_proj_dim
attention_head_dim
attention_head_dim: int
cross_attn_norm
cross_attn_norm: bool
device
device: DeviceRef
dtype
dtype: DType
eps
eps: float
ffn_dim
ffn_dim: int
freq_dim
freq_dim: int
image_dim
in_channels
in_channels: int
model_config
model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
num_attention_heads
num_attention_heads: int
num_layers
num_layers: int
out_channels
out_channels: int
patch_size
pos_embed_seq_len
qk_norm
rope_max_seq_len
rope_max_seq_len: int
text_dim
text_dim: int