Python module
float8_config
Float8 configuration parsing utilities for models.
Float8Config
class max.nn.float8_config.Float8Config(input_scale, weight_scale, mlp_in_float8, attn_qkv_in_float8, embedding_output_dtype=None, quant_method=None)
Configures float8 quantization settings for a layer or model section.
Parameters:
- input_scale (Float8InputScaleSpec)
- weight_scale (Float8WeightScaleSpec)
- mlp_in_float8 (set[int])
- attn_qkv_in_float8 (set[int])
- embedding_output_dtype (DType | None)
- quant_method (str | None)
attn_qkv_in_float8
Set of layer indices with attention QKV projections in float8.
QKV projections are quantized per layer on an all-or-nothing basis: either all of {q,k,v,o}_proj are in float8, or all are in bfloat16.
embedding_output_dtype
The DType of the output from the embedding layer.
input_scale
input_scale: Float8InputScaleSpec
Float8InputScaleSpec for input activation scaling.
is_dynamic
property is_dynamic: bool
Returns True if this config's input scale is dynamic.
is_static
property is_static: bool
Returns True if this config's input scale is static.
mlp_in_float8
Set of layer indices with MLPs in float8.
MLPs are quantized per layer on an all-or-nothing basis: either all of gate_proj, down_proj, and up_proj are in float8, or all are in bfloat16.
quant_method
The quantization method used (e.g., “fbgemm_fp8”).
weight_scale
weight_scale: Float8WeightScaleSpec
Float8WeightScaleSpec for weight scaling.
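Configs are normally produced by parse_float8_config(), but a minimal sketch of building one by hand is shown below. It assumes the classes accept keyword arguments matching the signatures above and that DType is importable from max.dtype; the scale dtypes, layer count, and quantization method are illustrative placeholders.

```python
from max.dtype import DType
from max.nn.float8_config import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

config = Float8Config(
    # Dynamic row-wise scaling for input activations.
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        origin=Float8ScaleOrigin.DYNAMIC,
        dtype=DType.float32,
    ),
    # Row-wise scales for the weights.
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        dtype=DType.float32,
    ),
    # Quantize the MLP and attention QKV projections of all 32 layers.
    mlp_in_float8=set(range(32)),
    attn_qkv_in_float8=set(range(32)),
    embedding_output_dtype=DType.bfloat16,
    quant_method="fbgemm_fp8",
)

assert config.is_dynamic  # the input scale origin is DYNAMIC
```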
Float8InputScaleSpec
class max.nn.float8_config.Float8InputScaleSpec(granularity, origin, dtype, activation_scale_ub=None, block_size=None)
Specifies how input activations are scaled for float8 quantization.
Parameters:
- granularity (Float8ScaleGranularity)
- origin (Float8ScaleOrigin)
- dtype (DType)
- activation_scale_ub (float | None)
- block_size (tuple[int, int] | None)
activation_scale_ub
An optional upper bound for dynamic activation scaling.
block_size
The block size, as a tuple[int, int], used for block-wise scaling.
dtype
dtype: DType
The DType of the input scale factor(s).
granularity
granularity: Float8ScaleGranularity
The Float8ScaleGranularity at which the input scale factor is applied.
is_block
property is_block: bool
Whether the input scale granularity is block-wise.
is_colwise
property is_colwise: bool
Whether the input scale granularity is column-wise.
is_rowwise
property is_rowwise: bool
Whether the input scale granularity is row-wise.
is_tensor
property is_tensor: bool
Whether the input scale granularity is per-tensor.
origin
origin: Float8ScaleOrigin
The Float8ScaleOrigin (static or dynamic) of the input scale factor.
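As a short sketch (the values are illustrative, and activation_scale_ub is only meaningful for dynamic scaling), the convenience properties mirror the constructor arguments:

```python
from max.dtype import DType
from max.nn.float8_config import (
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
)

# Dynamic per-row activation scales with an upper bound for the scaling.
spec = Float8InputScaleSpec(
    granularity=Float8ScaleGranularity.ROWWISE,
    origin=Float8ScaleOrigin.DYNAMIC,
    dtype=DType.float32,
    activation_scale_ub=1200.0,
)

assert spec.is_rowwise and not spec.is_tensor
assert spec.origin is Float8ScaleOrigin.DYNAMIC
```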
Float8ScaleGranularity
class max.nn.float8_config.Float8ScaleGranularity(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies the granularity of the quantization scale factor.
Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.
BLOCK
BLOCK = 'block'
Per-block scaling.
COLWISE
COLWISE = 'colwise'
Per-column scaling.
ROWWISE
ROWWISE = 'rowwise'
Per-row scaling.
TENSOR
TENSOR = 'tensor'
Per-tensor scaling.
Float8ScaleOrigin
class max.nn.float8_config.Float8ScaleOrigin(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Specifies whether the quantization scale is determined statically or dynamically.
DYNAMIC
DYNAMIC = 'dynamic'
Scales are computed at runtime based on the input data.
STATIC
STATIC = 'static'
Scales are pre-computed and loaded with the model weights.
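Both enums use string values, so (assuming standard enum value lookup) members can be recovered from strings found in a checkpoint's quantization config:

```python
from max.nn.float8_config import Float8ScaleGranularity, Float8ScaleOrigin

# Members round-trip from their string values.
granularity = Float8ScaleGranularity("rowwise")
origin = Float8ScaleOrigin("static")

assert granularity is Float8ScaleGranularity.ROWWISE
assert origin is Float8ScaleOrigin.STATIC
```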
Float8WeightScaleSpec
class max.nn.float8_config.Float8WeightScaleSpec(granularity, dtype, block_size=None)
Specifies how weights are scaled for float8 quantization.
Parameters:
- granularity (Float8ScaleGranularity)
- dtype (DType)
- block_size (tuple[int, int] | None)
block_size
The block size, as a tuple[int, int], used for block-wise scaling.
dtype
dtype: DType
The DType of the weight scale factor(s).
granularity
granularity: Float8ScaleGranularity
The Float8ScaleGranularity at which the weight scale factor is applied.
is_block
property is_block: bool
Whether the weight scale granularity is block-wise.
is_colwise
property is_colwise: bool
Whether the weight scale granularity is column-wise.
is_rowwise
property is_rowwise: bool
Whether the weight scale granularity is row-wise.
is_tensor
property is_tensor: bool
Whether the weight scale granularity is per-tensor.
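A hedged sketch of a block-wise weight scale spec; the 128x128 block size and float32 scale dtype are illustrative placeholders:

```python
from max.dtype import DType
from max.nn.float8_config import Float8ScaleGranularity, Float8WeightScaleSpec

# Block-wise weight scales over 128x128 tiles.
weight_scale = Float8WeightScaleSpec(
    granularity=Float8ScaleGranularity.BLOCK,
    dtype=DType.float32,
    block_size=(128, 128),
)

assert weight_scale.is_block and not weight_scale.is_rowwise
```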
parse_float8_config()
max.nn.float8_config.parse_float8_config(huggingface_config, state_dict, dtype, state_dict_name_prefix='', ignored_modules_prefix='model.')
Parses a Float8Config from a HuggingFace config by dispatching to a format-specific parser, selected according to the quantization method declared in the HuggingFace config.
Parameters:
- huggingface_config
- state_dict
- dtype (DType)
- state_dict_name_prefix (str)
- ignored_modules_prefix (str)
Return type:
Float8Config | None
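A hedged usage sketch, assuming hf_config and state_dict have already been loaded elsewhere (for example, via transformers.AutoConfig and the checkpoint's weights); the dtype is illustrative:

```python
from max.dtype import DType
from max.nn.float8_config import parse_float8_config

float8_config = parse_float8_config(
    huggingface_config=hf_config,  # assumed: a loaded HuggingFace config
    state_dict=state_dict,         # assumed: the checkpoint's state dict
    dtype=DType.bfloat16,
)

if float8_config is None:
    # A None result indicates no float8 quantization was parsed from the config.
    print("no float8 config detected")
else:
    print("quant method:", float8_config.quant_method)
```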