Mojo function

validate_grouped_gemm_constraints

validate_grouped_gemm_constraints[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]()

Validate grouped GEMM configuration constraints.

Constraints from NVIDIA CuTe DSL grouped_blockscaled_gemm.py:

  • MMA tiler M: 128 or 256
  • MMA tiler N: 128 or 256
  • Cluster M/N: power of 2, <= 4 per axis (required for scale-factor (SF) multicast)
  • Total cluster size (cluster M × cluster N): <= 16
  • 16-byte alignment on contiguous dimensions
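
The constraints listed above can be modelled with a short standalone check. The following is a minimal Python sketch, not the Mojo implementation: the `GroupedGemmConfig` class, its field names, and the `constraint_violations` helper are all hypothetical and are not part of `BlockScaledMatmulConfig`. It also assumes that "16-byte alignment" means the extent of the contiguous dimension, in bits, is a multiple of 128.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GroupedGemmConfig:
    """Hypothetical stand-in for the config fields these constraints touch."""
    mma_tiler_m: int
    mma_tiler_n: int
    cluster_m: int
    cluster_n: int
    contiguous_elems: int  # number of elements along the contiguous dimension
    elem_bits: int         # width of one element in bits (e.g. 4 or 8)


def is_power_of_two(x: int) -> bool:
    """True for 1, 2, 4, 8, ...; False for zero, negatives, and other values."""
    return x > 0 and (x & (x - 1)) == 0


def constraint_violations(cfg: GroupedGemmConfig) -> list:
    """Return a human-readable message for every violated constraint."""
    problems = []

    # MMA tiler M and N must each be 128 or 256.
    if cfg.mma_tiler_m not in (128, 256):
        problems.append(f"MMA tiler M must be 128 or 256, got {cfg.mma_tiler_m}")
    if cfg.mma_tiler_n not in (128, 256):
        problems.append(f"MMA tiler N must be 128 or 256, got {cfg.mma_tiler_n}")

    # Each cluster axis must be a power of 2 and at most 4.
    for axis, size in (("M", cfg.cluster_m), ("N", cfg.cluster_n)):
        if not is_power_of_two(size) or size > 4:
            problems.append(f"Cluster {axis} must be a power of 2 and <= 4, got {size}")

    # Total cluster size is the product of the two axes.
    total = cfg.cluster_m * cfg.cluster_n
    if total > 16:
        problems.append(f"Total cluster size must be <= 16, got {total}")

    # Assumed reading of 16-byte alignment: the contiguous extent in bits
    # must be a multiple of 16 * 8 = 128.
    if (cfg.contiguous_elems * cfg.elem_bits) % 128 != 0:
        problems.append("Contiguous dimension is not 16-byte aligned")

    return problems


if __name__ == "__main__":
    ok = GroupedGemmConfig(128, 256, 2, 1, contiguous_elems=4096, elem_bits=8)
    bad = GroupedGemmConfig(64, 128, 3, 8, contiguous_elems=100, elem_bits=4)
    print(constraint_violations(ok))
    for message in constraint_violations(bad):
        print(message)
```

Collecting every violation instead of stopping at the first makes a rejected configuration easier to diagnose. A compile-time validator would more likely fail on the first unmet constraint.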