Python module
max.pipelines.architectures.flux2
FLUX.2 diffusion architecture for image generation.
Flux2ArchConfig
class max.pipelines.architectures.flux2.Flux2ArchConfig(*, pipeline_config)
Bases: ArchConfig
Pipeline-level config for Flux2 (implements ArchConfig; no KV cache).
Parameters:
- pipeline_config (PipelineConfig)
get_max_seq_len()
get_max_seq_len()
Returns the default maximum sequence length for the model.
Subclasses should determine whether this value can be overridden by
setting the --max-length (pipeline_config.model.max_length) flag.
Return type:
int
initialize()
classmethod initialize(pipeline_config, model_config=None)
Initialize the config from a PipelineConfig.
Parameters:
- pipeline_config (PipelineConfig) – The pipeline configuration.
- model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.
Return type:
Flux2ArchConfig
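A minimal usage sketch, assuming a PipelineConfig has already been built elsewhere (e.g. from CLI flags); build_arch_config is a hypothetical helper, not part of this module.

from max.pipelines.architectures.flux2 import Flux2ArchConfig

def build_arch_config(pipeline_config):
    # Hypothetical helper. By default, settings are read from
    # pipeline_config.model; pass model_config=pipeline_config.draft_model
    # to initialize the arch config for a different model instead.
    arch_config = Flux2ArchConfig.initialize(pipeline_config)
    # Default maximum sequence length for the model (subclasses decide
    # whether --max-length may override it).
    max_seq_len = arch_config.get_max_seq_len()
    return arch_config, max_seq_len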
pipeline_config
pipeline_config: PipelineConfig
Flux2BlockQuant
class max.pipelines.architectures.flux2.Flux2BlockQuant(attn_qkv=None, attn_out=None, added_attn_qkv=None, added_attn_out=None, ff=None, ff_context=None)
Bases: object
Per-Linear NVFP4 quant plan for one Flux2TransformerBlock.
Each field is the QuantConfig to apply to the corresponding
Linear, or None to leave it in BF16. Construct via resolve()
from the block index + the checkpoint's nvfp4_layers_bfl metadata.
Parameters:
- attn_qkv (QuantConfig | None)
- attn_out (QuantConfig | None)
- added_attn_qkv (QuantConfig | None)
- added_attn_out (QuantConfig | None)
- ff (QuantConfig | None)
- ff_context (QuantConfig | None)
added_attn_out
added_attn_out: QuantConfig | None = None
Added-attn to_add_out (BFL txt_attn.proj).
added_attn_qkv
added_attn_qkv: QuantConfig | None = None
Added-attn add_{q,k,v}_proj (BFL txt_attn.qkv).
attn_out
attn_out: QuantConfig | None = None
Self-attn to_out[0] (BFL img_attn.proj).
attn_qkv
attn_qkv: QuantConfig | None = None
Self-attn to_q/to_k/to_v (BFL img_attn.qkv).
ff
ff: QuantConfig | None = None
Image FF ff.linear_{in,out} (BFL img_mlp.{0,2}).
ff_context
ff_context: QuantConfig | None = None
Text FF ff_context.linear_{in,out} (BFL txt_mlp.{0,2}).
resolve()
classmethod resolve(block_idx, base, nvfp4_layers_bfl)
Resolve the per-Linear plan for double_blocks.{block_idx}.
BFL's NVFP4 exports embed _quantization_metadata listing each
Linear that was quantized; layers absent from the list stay BF16.
When that metadata isn't available (non-NVFP4 runs, or legacy
checkpoints without metadata), fall back to the dev-NVFP4 uniform
pattern: img-side attn + both MLPs quantized, txt-side attn BF16.
Parameters:
- block_idx (int)
- base (QuantConfig | None)
- nvfp4_layers_bfl (frozenset[str])
Return type:
Flux2BlockQuant
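A sketch of resolving one block's plan, assuming a populated Flux2Config; plan_for_block is a hypothetical helper, not part of this module.

from max.pipelines.architectures.flux2 import Flux2BlockQuant, Flux2Config

def plan_for_block(block_idx: int, config: Flux2Config) -> Flux2BlockQuant:
    # Hypothetical helper. base is the checkpoint-wide NVFP4 QuantConfig
    # (None for pure-BF16 runs).
    plan = Flux2BlockQuant.resolve(
        block_idx=block_idx,
        base=config.quant_config,
        nvfp4_layers_bfl=config.nvfp4_layers_bfl,
    )
    # A None field leaves that Linear in BF16: with metadata present,
    # plan.attn_qkv is set only if the checkpoint tagged
    # double_blocks.{block_idx}.img_attn.qkv as nvfp4.
    return plan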
Flux2Config
class max.pipelines.architectures.flux2.Flux2Config(*, config_file=None, section_name=None, patch_size=1, in_channels=128, out_channels=None, num_layers=8, num_single_layers=48, attention_head_dim=128, num_attention_heads=48, joint_attention_dim=15360, timestep_guidance_channels=256, mlp_ratio=3.0, axes_dims_rope=(32, 32, 32, 32), rope_theta=2000, eps=1e-06, guidance_embeds=True, dtype=bfloat16, device=<factory>, quant_config=None, nvfp4_layers_bfl=<factory>)
Bases: MAXModelConfigBase
Parameters:
- config_file (str | None)
- section_name (str | None)
- patch_size (int)
- in_channels (int)
- out_channels (int | None)
- num_layers (int)
- num_single_layers (int)
- attention_head_dim (int)
- num_attention_heads (int)
- joint_attention_dim (int)
- timestep_guidance_channels (int)
- mlp_ratio (float)
- axes_dims_rope (tuple[int, ...])
- rope_theta (int)
- eps (float)
- guidance_embeds (bool)
- dtype (DType)
- device (DeviceRef)
- quant_config (QuantConfig | None)
- nvfp4_layers_bfl (frozenset[str])
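A construction sketch restating the documented defaults; the import paths for DType and DeviceRef are assumptions from the wider MAX API, and device is shown explicitly even though the signature gives it a default factory.

from max.dtype import DType
from max.graph import DeviceRef
from max.pipelines.architectures.flux2 import Flux2Config

config = Flux2Config(
    device=DeviceRef.GPU(),  # assumed DeviceRef constructor
    dtype=DType.bfloat16,
    num_layers=8,            # double (joint) transformer blocks
    num_single_layers=48,    # single-stream transformer blocks
    attention_head_dim=128,
    num_attention_heads=48,
)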
attention_head_dim
attention_head_dim: int
axes_dims_rope
axes_dims_rope: tuple[int, ...]
device
device: DeviceRef
dtype
dtype: DType
eps
eps: float
guidance_embeds
guidance_embeds: bool
If False (Klein/distilled), no guidance embedder weights are expected.
in_channels
in_channels: int
initialize_from_config()
classmethod initialize_from_config(config_dict, encoding, devices)
joint_attention_dim
joint_attention_dim: int
mlp_ratio
mlp_ratio: float
model_config
model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}
Configuration for the model; should be a dictionary conforming to Pydantic's ConfigDict.
num_attention_heads
num_attention_heads: int
num_layers
num_layers: int
num_single_layers
num_single_layers: int
nvfp4_layers_bfl
nvfp4_layers_bfl: frozenset[str]
BFL-named layers that the checkpoint tagged nvfp4 in its
_quantization_metadata, e.g. double_blocks.0.img_attn.qkv.
Empty for non-NVFP4 runs OR for legacy checkpoints without metadata,
in which case the model falls back to the dev-NVFP4 uniform pattern
(img-side attn + all MLPs quantized, txt-side attn BF16).
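A hypothetical sketch of that uniform fallback, with keys mirroring the Flux2BlockQuant fields and base standing in for the checkpoint-wide QuantConfig.

def uniform_fallback(base):
    # Dev-NVFP4 pattern, used when nvfp4_layers_bfl is empty: img-side
    # attention and all MLPs quantized, txt-side attention left in BF16.
    return dict(
        attn_qkv=base,        # img_attn.qkv  -> NVFP4
        attn_out=base,        # img_attn.proj -> NVFP4
        added_attn_qkv=None,  # txt_attn.qkv  -> BF16
        added_attn_out=None,  # txt_attn.proj -> BF16
        ff=base,              # img_mlp       -> NVFP4
        ff_context=base,      # txt_mlp       -> NVFP4
    )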
out_channels
out_channels: int | None
patch_size
patch_size: int
quant_config
quant_config: QuantConfig | None
NVFP4 quantization config, populated when encoding is float4_e2m1fnx2.
rope_theta
rope_theta: int
timestep_guidance_channels
timestep_guidance_channels: int