Mojo function

is_valid_dtypes_and_scale_factor_vec_size

is_valid_dtypes_and_scale_factor_vec_size(ab_dtype: DType, sf_dtype: DType, sf_vec_size: Int, c_dtype: DType) -> Bool

Check whether the given data types and scale-factor vector size (sf_vec_size) form a supported combination.

Valid combinations (from NVIDIA CuTe DSL grouped_blockscaled_gemm.py):

  • MXF8: Float8E5M2/Float8E4M3FN + Float8E8M0FNU + sf_vec_size=32
  • MXF4: Float4E2M1FN + Float8E8M0FNU + sf_vec_size=32
  • NVF4: Float4E2M1FN + Float8E8M0FNU/Float8E4M3FN + sf_vec_size=16
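The list above works as a lookup table of (ab_dtype, sf_dtype, sf_vec_size) triples. The sketch below models that table in Python rather than Mojo. It makes several assumptions: `DType` is a hypothetical enum standing in for Mojo's `DType`, `is_valid_combo` is a hypothetical helper (not this library's API), and c_dtype is left out because the list above places no constraint on it. Any output-type checks the real function performs are not modeled here.

```python
from enum import Enum, auto


class DType(Enum):
    """Hypothetical stand-in for Mojo's DType, with only the members used here."""
    FLOAT8_E5M2 = auto()
    FLOAT8_E4M3FN = auto()
    FLOAT4_E2M1FN = auto()
    FLOAT8_E8M0FNU = auto()


# (ab_dtype, sf_dtype, sf_vec_size) triples transcribed from the list above.
_VALID_COMBOS = {
    # MXF8: 8-bit A/B with E8M0 scales, one scale per 32 elements
    (DType.FLOAT8_E5M2, DType.FLOAT8_E8M0FNU, 32),
    (DType.FLOAT8_E4M3FN, DType.FLOAT8_E8M0FNU, 32),
    # MXF4: 4-bit A/B with E8M0 scales, one scale per 32 elements
    (DType.FLOAT4_E2M1FN, DType.FLOAT8_E8M0FNU, 32),
    # NVF4: 4-bit A/B with E8M0 or E4M3 scales, one scale per 16 elements
    (DType.FLOAT4_E2M1FN, DType.FLOAT8_E8M0FNU, 16),
    (DType.FLOAT4_E2M1FN, DType.FLOAT8_E4M3FN, 16),
}


def is_valid_combo(ab_dtype: DType, sf_dtype: DType, sf_vec_size: int) -> bool:
    """Return True if the triple appears in the table of valid combinations."""
    return (ab_dtype, sf_dtype, sf_vec_size) in _VALID_COMBOS
```

Keeping the rules in one set mirrors the documented list directly. For example, an FP8 operand with sf_vec_size=16 is rejected, and so is an E4M3 scale factor paired with sf_vec_size=32.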

Args:

  • ab_dtype (DType): The data type of the A and B matrices.
  • sf_dtype (DType): The data type of the scale factors.
  • sf_vec_size (Int): The scale-factor vector size (16 or 32).
  • c_dtype (DType): The data type of the output matrix.
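In block-scaled formats, sf_vec_size is the number of consecutive elements that share one scale factor. The arithmetic sketch below counts the scale factors for an m x k operand. It assumes the blocks run along the K (reduction) dimension and ignores any padding or tiled scale-factor layout the kernel may use. `num_scale_factors` is a hypothetical helper, not part of this library.

```python
def num_scale_factors(m: int, k: int, sf_vec_size: int) -> int:
    """Count scale factors for an m x k operand, assuming one scale factor per
    sf_vec_size consecutive elements along K and no layout padding."""
    if sf_vec_size not in (16, 32):
        raise ValueError("sf_vec_size must be 16 or 32")
    # Integer ceiling division: a partial trailing block still needs a scale.
    blocks_per_row = (k + sf_vec_size - 1) // sf_vec_size
    return m * blocks_per_row
```

For a 128 x 256 operand, sf_vec_size=32 (MXF8/MXF4) gives 128 * 8 scale factors, and sf_vec_size=16 (NVF4) gives 128 * 16. This is why the smaller vector size produces finer-grained scaling at the cost of more scale-factor storage.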

Returns:

Bool: True if the combination is valid, False otherwise.

Was this page helpful?