Mojo module
grouped_block_scaled_matmul
Host-side (CPU) entry points for grouped block-scaled SM100 matmul.
Supports multiple GEMM operations with variable problem sizes per group. Uses TMATensorTileArray for per-block updatable TMA descriptors.
This module implements grouped block-scaled GEMM following the architecture of NVIDIA CuTe DSL grouped_blockscaled_gemm.py:
- Creates template TMA descriptors from caller-provided TileTensors
- Creates TMATensorTileArray with one tensormap per block
- Launches GroupedBlockScaledMatmulKernel with per-group pointers
All tensor arguments use TileTensor with compile-time dtype constraints. Callers must provide template tensors in the correct shapes:
- A/B/C templates: 3D (1, M/N, K/N) with batch=1
- Scale factor templates: 5D (1, groups_M/N, groups_K, SF_ATOM_M[0], SF_ATOM_M[1] * SF_ATOM_K)
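The shape rules above can be checked before descriptor creation. Below is a minimal Python sketch of such a check; the function name and the tuple-based shapes are illustrative stand-ins for TileTensor layouts, not part of this module's API:

```python
def check_template_shapes(a, b, c, sfa, sfb):
    """Check the template shape rules described above.

    Shapes are plain tuples standing in for TileTensor layouts
    (illustrative only, not the module's API).
    """
    # A/B/C templates: 3D with a leading batch dimension of 1.
    for name, shape in (("a", a), ("b", b), ("c", c)):
        if len(shape) != 3 or shape[0] != 1:
            raise ValueError(f"{name} template must be 3D with batch=1, got {shape}")
    # Scale-factor templates: 5D with a leading dimension of 1.
    for name, shape in (("sfa", sfa), ("sfb", sfb)):
        if len(shape) != 5 or shape[0] != 1:
            raise ValueError(f"{name} template must be 5D with leading 1, got {shape}")
```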
Usage:

```mojo
# Per-group pointers as TileTensor[DType.uint64, ...]
var a_ptrs = TileTensor(ptr, tile_row_major[num_groups, 1]())
...

# Problem sizes as TileTensor[DType.int32, ...]
var problem_sizes = TileTensor(ptr, tile_row_major[num_groups, 4]())

# 3D template TileTensors for TMA descriptor creation
var a_template = TileTensor(a_ptr, tile_row_major[1, M, K]())
...

grouped_block_scaled_matmul[...](
    a_ptrs, b_ptrs, c_ptrs, sfa_ptrs, sfb_ptrs,
    problem_sizes, num_groups, total_tiles,
    a_template, b_template, c_template,
    sfa_template, sfb_template, ctx
)
```

Functions
- `grouped_block_scaled_matmul`: Launch grouped block-scaled FP8 matmul kernel on SM100.
- `grouped_smem_size`: Calculate shared memory size for the grouped block-scaled kernel.
- `validate_grouped_gemm_constraints`: Validate grouped GEMM configuration constraints.
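As a rough illustration of where the `total_tiles` argument comes from, a host-side helper could count output tiles across all groups from the `(num_groups, 4)` problem-size rows. This Python sketch assumes a `(M, N, K, ...)` row layout and hypothetical `block_m`/`block_n` output-tile extents; it is not the module's actual scheduling logic:

```python
def total_output_tiles(problem_sizes, block_m=128, block_n=128):
    """Count output tiles over all groups.

    problem_sizes: iterable of (M, N, K, *) rows, mirroring the
    (num_groups, 4) int32 TileTensor described above (assumed layout).
    block_m/block_n: hypothetical output-tile extents.
    """
    total = 0
    for m, n, *_ in problem_sizes:
        # ceil-div: each group contributes ceil(M/BM) * ceil(N/BN) tiles.
        total += -(-m // block_m) * -(-n // block_n)
    return total
```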