
Mojo module

grouped_block_scaled_matmul

CPU entry points for grouped block-scaled SM100 matmul.

Supports multiple GEMM operations with variable problem sizes per group. Uses TMATensorTileArray for per-block updatable TMA descriptors.

This module implements grouped block-scaled GEMM following the architecture of NVIDIA CuTe DSL grouped_blockscaled_gemm.py:

  1. Creates template TMA descriptors from caller-provided TileTensors
  2. Creates TMATensorTileArray with one tensormap per block
  3. Launches GroupedBlockScaledMatmulKernel with per-group pointers
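
The three steps above can be sketched roughly as follows. Only `TMATensorTileArray`, `GroupedBlockScaledMatmulKernel`, and `ctx` are named by this module; the constructor and launch arguments shown are illustrative placeholders, not the actual signatures:

```mojo
# 1. Template TMA descriptors are built from the caller-provided
#    TileTensors (one per operand: A, B, C, SFA, SFB).
# 2. One updatable tensormap per thread block, so each block can
#    re-point its descriptors at the current group's buffers.
var tma_array = TMATensorTileArray(...)  # illustrative signature
# 3. Kernel launch with per-group device pointers and problem sizes.
ctx.enqueue_function[GroupedBlockScaledMatmulKernel](...)  # illustrative
```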

All tensor arguments use TileTensor with compile-time dtype constraints. Callers must provide template tensors in the correct shapes:

  • A/B/C templates: 3D with batch=1 — A is (1, M, K), B is (1, N, K), C is (1, M, N)
  • Scale factor templates: 5D (1, groups_M/N, groups_K, SF_ATOM_M[0], SF_ATOM_M[1] * SF_ATOM_K)
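
As a concrete illustration of the 5D scale-factor shape, a template could be constructed as below. The pointer name and the dimension variables are illustrative placeholders standing in for values from the kernel's block-scaling configuration, not part of this module's API:

```mojo
# Illustrative only: groups_M, groups_K, and the SF_ATOM_* values come
# from the block-scaling configuration; the leading batch dim is fixed at 1.
var sfa_template = TileTensor(
    sfa_ptr,
    tile_row_major[
        1, groups_M, groups_K, SF_ATOM_M[0], SF_ATOM_M[1] * SF_ATOM_K
    ](),
)
```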

Usage:

```mojo
# Per-group pointers as TileTensor[DType.uint64, ...]
var a_ptrs = TileTensor(ptr, tile_row_major[num_groups, 1]())
...

# Problem sizes as TileTensor[DType.int32, ...]
var problem_sizes = TileTensor(ptr, tile_row_major[num_groups, 4]())

# 3D template TileTensors for TMA descriptor creation
var a_template = TileTensor(a_ptr, tile_row_major[1, M, K]())
...

grouped_block_scaled_matmul[...](
    a_ptrs, b_ptrs, c_ptrs, sfa_ptrs, sfb_ptrs,
    problem_sizes, num_groups, total_tiles,
    a_template, b_template, c_template,
    sfa_template, sfb_template, ctx
)
```
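
Filling `problem_sizes` might look like the sketch below. The meaning of the four int32 columns is an assumption here — presumably (M, N, K, L) per group, following the CuTe DSL grouped GEMM reference this module is modeled on; verify against the kernel source. The `m_sizes`/`n_sizes`/`k_sizes` lists are hypothetical:

```mojo
# Assumption: each row of problem_sizes is (M, N, K, L) for one group,
# with L = 1 since the templates use batch=1. Verify column order
# against GroupedBlockScaledMatmulKernel before relying on this.
for g in range(num_groups):
    problem_sizes[g, 0] = m_sizes[g]  # M for group g
    problem_sizes[g, 1] = n_sizes[g]  # N for group g
    problem_sizes[g, 2] = k_sizes[g]  # K for group g
    problem_sizes[g, 3] = 1           # L (batch), fixed at 1
```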

Functions
