Mojo function
validate_grouped_gemm_constraints
validate_grouped_gemm_constraints[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]()
Validate grouped GEMM configuration constraints.
Constraints from NVIDIA CuTe DSL grouped_blockscaled_gemm.py:
- MMA tiler M: 128 or 256
- MMA tiler N: 128 or 256
- Cluster M/N: power of 2, <=4 per axis (for scale factor (SF) multicast)
- Total cluster size: <=16
- 16-byte alignment on contiguous dimensions
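
The constraints listed above can be sketched as plain runtime checks. This is only an illustration, not the library's implementation: the real function takes a `BlockScaledMatmulConfig` as a compile-time parameter, and its field names are not shown on this page, so the helper names and the `Int` arguments below (`mma_m`, `mma_n`, `cluster_m`, `cluster_n`, `extent`, `bits_per_element`) are hypothetical. The alignment check assumes that "16-byte alignment" means the contiguous dimension spans a multiple of 128 bits.

```mojo
# Hypothetical helpers; none of these names come from the library.

fn is_power_of_two(x: Int) -> Bool:
    # A positive integer is a power of two when exactly one bit is set.
    return x > 0 and (x & (x - 1)) == 0


fn tiler_and_cluster_ok(
    mma_m: Int, mma_n: Int, cluster_m: Int, cluster_n: Int
) -> Bool:
    # MMA tiler M and N must each be 128 or 256.
    var tiler_ok = (mma_m == 128 or mma_m == 256) and (
        mma_n == 128 or mma_n == 256
    )
    # Each cluster axis must be a power of 2 and at most 4.
    var axis_ok = (
        is_power_of_two(cluster_m)
        and cluster_m <= 4
        and is_power_of_two(cluster_n)
        and cluster_n <= 4
    )
    # Total cluster size must not exceed 16.
    var total_ok = cluster_m * cluster_n <= 16
    return tiler_ok and axis_ok and total_ok


fn contiguous_dim_aligned(extent: Int, bits_per_element: Int) -> Bool:
    # Assumption: the contiguous dimension must cover a multiple of
    # 16 bytes (128 bits). Bits are used so sub-byte types can be expressed.
    return (extent * bits_per_element) % 128 == 0


fn main():
    # 128x256 tiler with a 2x2 cluster passes every check.
    print(tiler_and_cluster_ok(128, 256, 2, 2))
    # A tiler M of 64 is rejected.
    print(tiler_and_cluster_ok(64, 128, 1, 1))
    # A cluster axis of 8 exceeds the per-axis limit.
    print(tiler_and_cluster_ok(256, 128, 8, 1))
    # 64 elements of 4 bits each span 32 bytes, a multiple of 16.
    print(contiguous_dim_aligned(64, 4))
```

With each axis capped at 4, the product of the two axes can never exceed 16, so in this sketch the total-size check is implied by the per-axis check. It is kept separate only to mirror the list above.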