Python module

float8_config

Float8 configuration parsing utilities for models.

Float8Config

class max.nn.float8_config.Float8Config(input_scale, weight_scale, mlp_in_float8, attn_qkv_in_float8, embedding_output_dtype=None, quant_method=None)

Configures float8 quantization settings for a layer or model section.

Parameters:

attn_qkv_in_float8

attn_qkv_in_float8: set[int]

Set of layer indices with attention QKV projections in float8.

QKV projections are quantized on an all-or-nothing basis per layer: either all of {q,k,v,o}_proj are in float8, or all are in bfloat16.

embedding_output_dtype

embedding_output_dtype: DType | None = None

The DType of the output from the embedding layer.

input_scale

input_scale: Float8InputScaleSpec

Float8InputScaleSpec for input activation scaling.

is_dynamic

property is_dynamic: bool

Returns True if the input scale origin is dynamic.

is_static

property is_static: bool

Returns True if the input scale origin is static.

mlp_in_float8

mlp_in_float8: set[int]

Set of layer indices with MLPs in float8.

MLPs are quantized on an all-or-nothing basis per layer: either all of gate_proj, down_proj, and up_proj are in float8, or all are in bfloat16.

quant_method

quant_method: str | None = None

The quantization method used (e.g., “fbgemm_fp8”).

weight_scale

weight_scale: Float8WeightScaleSpec

Float8WeightScaleSpec for weight scaling.
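
As a rough illustration only (the layer count, dtypes, and parameter values below are made up, and the specific DType member names are assumptions), a Float8Config for a model whose first 32 layers run both attention QKV and MLP projections in float8, with dynamic row-wise input scales and per-tensor weight scales, might be assembled like this using the spec classes documented below:

from max.dtype import DType
from max.nn.float8_config import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

num_layers = 32  # hypothetical model depth

config = Float8Config(
    # Input activations: scales computed at runtime, one per row.
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        origin=Float8ScaleOrigin.DYNAMIC,
        dtype=DType.bfloat16,  # dtype of the scale factors (illustrative choice)
    ),
    # Weights: a single pre-computed scale per tensor.
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.TENSOR,
        dtype=DType.float32,  # dtype of the scale factors (illustrative choice)
    ),
    # Quantize the MLP and attention QKV projections in every layer.
    mlp_in_float8=set(range(num_layers)),
    attn_qkv_in_float8=set(range(num_layers)),
    embedding_output_dtype=DType.bfloat16,
    quant_method="fbgemm_fp8",
)

assert config.is_dynamic  # input scales are computed at runtime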

Float8InputScaleSpec

class max.nn.float8_config.Float8InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None, block_size=None)

Specifies how input activations are scaled for float8 quantization.

Parameters:

activation_scale_ub

activation_scale_ub: float | None = None

An optional upper bound for dynamic activation scaling.

block_size

block_size: tuple[int, int] | None = None

The block size, as a tuple[int, int], used for block-wise scaling.

dtype

dtype: DType

The DType of the input scale factor(s).

granularity

granularity: Float8ScaleGranularity

The Float8ScaleGranularity at which the input scale factor is applied.

is_block

property is_block: bool

Whether the input scale granularity is block-wise.

is_colwise

property is_colwise: bool

Whether the input scale granularity is column-wise.

is_rowwise

property is_rowwise: bool

Whether the input scale granularity is row-wise.

is_tensor

property is_tensor: bool

Whether the input scale granularity is per-tensor.

origin

origin: Float8ScaleOrigin

The Float8ScaleOrigin (static or dynamic) of the input scale factor.
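
A brief sketch of constructing an input scale spec and querying its convenience predicates (the numeric upper bound is illustrative, not a recommended value):

from max.dtype import DType
from max.nn.float8_config import (
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
)

# Dynamic per-tensor activation scaling with an upper bound on the scale.
spec = Float8InputScaleSpec(
    granularity=Float8ScaleGranularity.TENSOR,
    origin=Float8ScaleOrigin.DYNAMIC,
    dtype=DType.float32,
    activation_scale_ub=1200.0,
)

assert spec.is_tensor
assert not (spec.is_rowwise or spec.is_colwise or spec.is_block)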

Float8ScaleGranularity

class max.nn.float8_config.Float8ScaleGranularity(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies the granularity of the quantization scale factor.

Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.

BLOCK

BLOCK = 'block'

Per-block scaling.

COLWISE

COLWISE = 'colwise'

Per-column scaling.

ROWWISE

ROWWISE = 'rowwise'

Per-row scaling.

TENSOR

TENSOR = 'tensor'

Per-tensor scaling.
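
To make the four granularities concrete, here is a small sketch of how many scale factors each implies for a 2-D weight of shape (rows, cols); the counting convention is the usual one for scaled quantization in general, not something this module specifies:

from max.nn.float8_config import Float8ScaleGranularity

def num_scales(granularity, rows, cols, block=(128, 128)):
    """Rough count of scale factors implied by each granularity."""
    if granularity is Float8ScaleGranularity.TENSOR:
        return 1          # one scale for the whole tensor
    if granularity is Float8ScaleGranularity.ROWWISE:
        return rows       # one scale per row
    if granularity is Float8ScaleGranularity.COLWISE:
        return cols       # one scale per column
    # BLOCK: one scale per (block_rows x block_cols) tile.
    br, bc = block
    return -(-rows // br) * -(-cols // bc)

print(num_scales(Float8ScaleGranularity.ROWWISE, 4096, 14336))  # 4096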

Float8ScaleOrigin

class max.nn.float8_config.Float8ScaleOrigin(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies whether the quantization scale is determined statically or dynamically.

DYNAMIC

DYNAMIC = 'dynamic'

Scales are computed at runtime based on the input data.

STATIC

STATIC = 'static'

Scales are pre-computed and loaded with the model weights.
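
A minimal sketch of dispatching on the origin, paraphrasing the two cases above:

from max.nn.float8_config import Float8ScaleOrigin

def describe(origin: Float8ScaleOrigin) -> str:
    """Summarize where the scale factors come from for a given origin."""
    if origin is Float8ScaleOrigin.STATIC:
        return "scales are read from the checkpoint alongside the weights"
    return "scales are computed at runtime from the input data"

print(describe(Float8ScaleOrigin.DYNAMIC))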

Float8WeightScaleSpec

class max.nn.float8_config.Float8WeightScaleSpec(granularity, dtype, block_size=None)

Specifies how weights are scaled for float8 quantization.

Parameters:

block_size

block_size: tuple[int, int] | None = None

The block size, as a tuple[int, int], used for block-wise scaling.

dtype

dtype: DType

The DType of the weight scale factor(s).

granularity

granularity: Float8ScaleGranularity

The Float8ScaleGranularity at which the weight scale factor is applied.

is_block

property is_block: bool

Whether the weight scale granularity is block-wise.

is_colwise

property is_colwise: bool

Whether the weight scale granularity is column-wise.

is_rowwise

property is_rowwise: bool

Whether the weight scale granularity is row-wise.

is_tensor

property is_tensor: bool

Whether the weight scale granularity is per-tensor.
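
A short sketch of a block-wise weight scale spec (the 128x128 block size is a common choice shown for illustration, not a value this module prescribes):

from max.dtype import DType
from max.nn.float8_config import (
    Float8ScaleGranularity,
    Float8WeightScaleSpec,
)

# One float32 scale per 128x128 tile of each weight matrix.
weight_scale = Float8WeightScaleSpec(
    granularity=Float8ScaleGranularity.BLOCK,
    dtype=DType.float32,
    block_size=(128, 128),
)

assert weight_scale.is_block and not weight_scale.is_tensor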

parse_float8_config()

max.nn.float8_config.parse_float8_config(huggingface_config, state_dict, dtype, state_dict_name_prefix='', ignored_modules_prefix='model.')

Parses a Float8Config from a HuggingFace config, dispatching to the appropriate format-specific parser based on the quantization method specified in the config.

Parameters:

Return type:

Float8Config | None
