Mojo function
is_valid_dtypes_and_scale_factor_vec_size
is_valid_dtypes_and_scale_factor_vec_size(ab_dtype: DType, sf_dtype: DType, sf_vec_size: Int, c_dtype: DType) -> Bool
Checks whether the given dtypes and sf_vec_size form a valid combination.
Valid combinations of A/B dtype, scale-factor dtype, and sf_vec_size (from NVIDIA CuTe DSL grouped_blockscaled_gemm.py):
- MXF8: Float8E5M2/Float8E4M3FN + Float8E8M0FNU + sf_vec_size=32
- MXF4: Float4E2M1FN + Float8E8M0FNU + sf_vec_size=32
- NVF4: Float4E2M1FN + Float8E8M0FNU/Float8E4M3FN + sf_vec_size=16
Args:
- ab_dtype (DType): The data type of A and B matrices.
- sf_dtype (DType): The data type of scale factors.
- sf_vec_size (Int): The vector size of scale factors (16 or 32).
- c_dtype (DType): The data type of the output matrix.
Returns:
Bool: True if the combination is valid, False otherwise.
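The rule table above amounts to a small lookup. Below is a minimal Python sketch of that lookup, not the Mojo implementation. Dtypes are represented by hypothetical lowercase strings loosely modeled on Mojo DType names (an assumption for illustration). c_dtype is accepted for signature parity but not validated, because this page does not list which output dtypes are allowed.

```python
# Each entry: (allowed A/B dtypes, allowed scale-factor dtypes, required sf_vec_size).
# Dtype names are hypothetical string stand-ins for Mojo DType values.
VALID_COMBINATIONS = [
    ({"float8_e5m2", "float8_e4m3fn"}, {"float8_e8m0fnu"}, 32),    # MXF8
    ({"float4_e2m1fn"}, {"float8_e8m0fnu"}, 32),                   # MXF4
    ({"float4_e2m1fn"}, {"float8_e8m0fnu", "float8_e4m3fn"}, 16),  # NVF4
]


def is_valid_dtypes_and_scale_factor_vec_size(ab_dtype, sf_dtype, sf_vec_size, c_dtype):
    """Return True if (ab_dtype, sf_dtype, sf_vec_size) matches a row of the table."""
    # c_dtype is intentionally unchecked: its constraints are not documented here.
    del c_dtype
    return any(
        ab_dtype in ab_set and sf_dtype in sf_set and sf_vec_size == vec_size
        for ab_set, sf_set, vec_size in VALID_COMBINATIONS
    )


if __name__ == "__main__":
    # NVF4 with E4M3 scale factors at vector size 16: listed, so valid.
    print(is_valid_dtypes_and_scale_factor_vec_size("float4_e2m1fn", "float8_e4m3fn", 16, "bfloat16"))
    # FP8 data at vector size 16: no MXF8 row uses 16, so invalid.
    print(is_valid_dtypes_and_scale_factor_vec_size("float8_e4m3fn", "float8_e8m0fnu", 16, "bfloat16"))
```

Under this model, Float4E2M1FN data with Float8E4M3FN scale factors is accepted only at sf_vec_size=16 (NVF4). FP8 data is rejected at sf_vec_size=16 because the only FP8 row (MXF8) requires 32.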