Python module

max.pipelines.architectures.flux2

FLUX.2 diffusion architecture for image generation.

Flux2ArchConfig​

class max.pipelines.architectures.flux2.Flux2ArchConfig(*, pipeline_config)

Bases: ArchConfig

Pipeline-level config for Flux2 (implements ArchConfig; no KV cache).

Parameters:

pipeline_config (PipelineConfig)

get_max_seq_len()​

get_max_seq_len()

Returns the default maximum sequence length for the model.

Subclasses should determine whether this value can be overridden by the --max-length flag (pipeline_config.model.max_length).

Return type:

int

initialize()​

classmethod initialize(pipeline_config, model_config=None)

Initialize the config from a PipelineConfig.

Parameters:

  • pipeline_config (PipelineConfig) – The pipeline configuration.
  • model_config (MAXModelConfig | None) – The model configuration to read from. When None (the default), pipeline_config.model is used. Pass an explicit config (e.g. pipeline_config.draft_model) to initialize the arch config for a different model.

Return type:

Self
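
For example, a minimal sketch of initializing the arch config and querying its sequence length. The PipelineConfig import path and the model_path value are assumptions for illustration, not guarantees of this module's surface:

```python
from max.pipelines.lib import PipelineConfig  # assumed import path
from max.pipelines.architectures.flux2 import Flux2ArchConfig

# Hypothetical FLUX.2 checkpoint id; substitute a real model path.
pipeline_config = PipelineConfig(model_path="black-forest-labs/FLUX.2-dev")

# Reads pipeline_config.model by default; pass model_config explicitly
# (e.g. pipeline_config.draft_model) to target a different model.
arch_config = Flux2ArchConfig.initialize(pipeline_config)
print(arch_config.get_max_seq_len())  # default max sequence length (int)
```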

pipeline_config​

pipeline_config: PipelineConfig

Flux2BlockQuant​

class max.pipelines.architectures.flux2.Flux2BlockQuant(attn_qkv=None, attn_out=None, added_attn_qkv=None, added_attn_out=None, ff=None, ff_context=None)

Bases: object

Per-Linear NVFP4 quant plan for one Flux2TransformerBlock.

Each field is the QuantConfig to apply to the corresponding Linear, or None to leave it in BF16. Construct via resolve() from the block index and the checkpoint’s nvfp4_layers_bfl metadata.

Parameters:

  • attn_qkv (QuantConfig | None)
  • attn_out (QuantConfig | None)
  • added_attn_qkv (QuantConfig | None)
  • added_attn_out (QuantConfig | None)
  • ff (QuantConfig | None)
  • ff_context (QuantConfig | None)

added_attn_out​

added_attn_out: QuantConfig | None = None

Added-attn to_add_out (BFL txt_attn.proj).

added_attn_qkv​

added_attn_qkv: QuantConfig | None = None

Added-attn add_{q,k,v}_proj (BFL txt_attn.qkv).

attn_out​

attn_out: QuantConfig | None = None

Self-attn to_out[0] (BFL img_attn.proj).

attn_qkv​

attn_qkv: QuantConfig | None = None

Self-attn to_q/to_k/to_v (BFL img_attn.qkv).

ff​

ff: QuantConfig | None = None

Image FF ff.linear_{in,out} (BFL img_mlp.{0,2}).

ff_context​

ff_context: QuantConfig | None = None

Text FF ff_context.linear_{in,out} (BFL txt_mlp.{0,2}).

resolve()​

classmethod resolve(block_idx, base, nvfp4_layers_bfl)

Resolve the per-Linear plan for double_blocks.{block_idx}.

BFL’s NVFP4 exports embed _quantization_metadata listing each Linear that was quantized; layers absent from the list stay BF16. When that metadata isn’t available (non-NVFP4 runs, or legacy checkpoints without metadata), fall back to the dev-NVFP4 uniform pattern: img-side attn + both MLPs quantized, txt-side attn BF16.

Parameters:

  • block_idx (int) – Index of the double block (double_blocks.{block_idx}).
  • base (QuantConfig | None) – The QuantConfig to apply to each quantized Linear.
  • nvfp4_layers_bfl (frozenset[str]) – BFL-named layers tagged nvfp4 in the checkpoint’s _quantization_metadata.

Return type:

Flux2BlockQuant
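
A sketch of resolving the plan for block 0. Here quant_cfg is a stand-in for a QuantConfig obtained elsewhere (e.g. Flux2Config.quant_config), and the layer names beyond the documented img_attn.qkv example are assumed to follow the same BFL pattern:

```python
from max.pipelines.architectures.flux2 import Flux2BlockQuant

# Names mirroring a checkpoint's _quantization_metadata (illustrative).
nvfp4_layers_bfl = frozenset({
    "double_blocks.0.img_attn.qkv",
    "double_blocks.0.img_attn.proj",
    "double_blocks.0.img_mlp.0",
    "double_blocks.0.img_mlp.2",
})

quant_cfg = ...  # stand-in: a real QuantConfig goes here

plan = Flux2BlockQuant.resolve(
    block_idx=0,
    base=quant_cfg,
    nvfp4_layers_bfl=nvfp4_layers_bfl,
)

# Tagged layers get a QuantConfig; layers absent from the metadata stay BF16.
print(plan.attn_qkv)        # img_attn.qkv was tagged -> QuantConfig
print(plan.added_attn_qkv)  # txt_attn.qkv untagged -> None
```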

Flux2Config​

class max.pipelines.architectures.flux2.Flux2Config(*, config_file=None, section_name=None, patch_size=1, in_channels=128, out_channels=None, num_layers=8, num_single_layers=48, attention_head_dim=128, num_attention_heads=48, joint_attention_dim=15360, timestep_guidance_channels=256, mlp_ratio=3.0, axes_dims_rope=(32, 32, 32, 32), rope_theta=2000, eps=1e-06, guidance_embeds=True, dtype=bfloat16, device=<factory>, quant_config=None, nvfp4_layers_bfl=<factory>)

Bases: MAXModelConfigBase

Parameters:

  • config_file
  • section_name
  • patch_size (int)
  • in_channels (int)
  • out_channels (int | None)
  • num_layers (int)
  • num_single_layers (int)
  • attention_head_dim (int)
  • num_attention_heads (int)
  • joint_attention_dim (int)
  • timestep_guidance_channels (int)
  • mlp_ratio (float)
  • axes_dims_rope (tuple[int, ...])
  • rope_theta (int)
  • eps (float)
  • guidance_embeds (bool)
  • dtype (DType)
  • device (DeviceRef)
  • quant_config (QuantConfig | None)
  • nvfp4_layers_bfl (frozenset[str])

attention_head_dim​

attention_head_dim: int

axes_dims_rope​

axes_dims_rope: tuple[int, ...]

device​

device: DeviceRef

dtype​

dtype: DType

eps​

eps: float

guidance_embeds​

guidance_embeds: bool

If False (Klein/distilled), no guidance embedder weights are expected.

in_channels​

in_channels: int

initialize_from_config()​

classmethod initialize_from_config(config_dict, encoding, devices)

Parameters:

  • config_dict (dict[str, Any])
  • encoding (max.pipelines.lib.config.SupportedEncoding)
  • devices (list[Device])

Return type:

Self
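
A hedged sketch of building the config. The config_dict keys are assumed to mirror the field names documented on this class, and SupportedEncoding.bfloat16 / max.driver.CPU are used for illustration:

```python
from max.driver import CPU
from max.pipelines.lib.config import SupportedEncoding
from max.pipelines.architectures.flux2 import Flux2Config

# Illustrative transformer config; real values come from the checkpoint.
config_dict = {
    "patch_size": 1,
    "in_channels": 128,
    "num_layers": 8,
    "num_single_layers": 48,
    "attention_head_dim": 128,
    "num_attention_heads": 48,
    "joint_attention_dim": 15360,
    "guidance_embeds": True,
}

config = Flux2Config.initialize_from_config(
    config_dict,
    encoding=SupportedEncoding.bfloat16,
    devices=[CPU()],
)
print(config.dtype, config.device)
```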

joint_attention_dim​

joint_attention_dim: int

mlp_ratio​

mlp_ratio: float

model_config​

model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'strict': False}

Configuration for the model; should be a dictionary conforming to Pydantic’s ConfigDict.

num_attention_heads​

num_attention_heads: int

num_layers​

num_layers: int

num_single_layers​

num_single_layers: int

nvfp4_layers_bfl​

nvfp4_layers_bfl: frozenset[str]

BFL-named layers that the checkpoint tagged nvfp4 in its _quantization_metadata, e.g. double_blocks.0.img_attn.qkv. Empty for non-NVFP4 runs or for legacy checkpoints without metadata; in either case the model falls back to the dev-NVFP4 uniform pattern (img-side attn and both MLPs quantized, txt-side attn BF16).
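
For illustration (only the first name below is quoted in this module; the second is assumed to follow the same BFL pattern):

```python
# NVFP4 checkpoint with metadata: only the tagged layers are quantized.
tagged = frozenset({
    "double_blocks.0.img_attn.qkv",   # name quoted in the docstring
    "double_blocks.0.img_attn.proj",  # assumed sibling entry
})

# Non-NVFP4 run or legacy checkpoint: empty set, so Flux2BlockQuant.resolve()
# falls back to the dev-NVFP4 uniform pattern.
legacy = frozenset()
```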

out_channels​

out_channels: int | None

patch_size​

patch_size: int

quant_config​

quant_config: QuantConfig | None

NVFP4 quantization config, populated when encoding is float4_e2m1fnx2.

rope_theta​

rope_theta: int

timestep_guidance_channels​

timestep_guidance_channels: int
