Mojo module
grouped_block_scaled_matmul
CPU entry points for grouped block-scaled SM100 matmul.
Supports multiple GEMM operations with variable problem sizes per group. Uses TMATensorTileArray for per-block updatable TMA descriptors.
This module implements grouped block-scaled GEMM following the architecture of NVIDIA CuTe DSL grouped_blockscaled_gemm.py:
- Creates template TMA descriptors from the first group
- Creates TMATensorTileArray with one tensormap per block
- Launches GroupedBlockScaledMatmulKernel with per-group pointers
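Functionally, a grouped GEMM runs one independent matmul per group, where each group may have its own problem size. A minimal plain-Python reference of that semantics (names and shapes here are illustrative, not the module's API):

```python
def grouped_matmul_reference(a_list, b_list):
    """Reference semantics of a grouped GEMM: one matmul per group.

    a_list[g] is an (M_g x K_g) matrix, b_list[g] is (K_g x N_g);
    shapes may differ across groups, which is what "variable problem
    sizes per group" means. The real kernel fuses all groups into a
    single launch instead of looping on the host like this.
    """
    outs = []
    for a, b in zip(a_list, b_list):
        m, k, n = len(a), len(b), len(b[0])
        c = [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
             for i in range(m)]
        outs.append(c)
    return outs
```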
Usage:

```mojo
# Per-group pointers (device addresses)
var a_ptrs = ...    # (num_groups, 1) with uint64 addresses
var b_ptrs = ...    # (num_groups, 1)
var c_ptrs = ...    # (num_groups, 1)
var sfa_ptrs = ...  # (num_groups, 1)
var sfb_ptrs = ...  # (num_groups, 1)

# Problem sizes per group
var problem_sizes = ...  # (num_groups, 4) with [M, N, K, L]

grouped_block_scaled_matmul[...](
    a_ptrs, b_ptrs, c_ptrs, sfa_ptrs, sfb_ptrs,
    problem_sizes, num_groups, ctx
)
```

Functions
- compute_total_tiles: Compute the total number of tiles across all groups.
- grouped_block_scaled_matmul: Launch the grouped block-scaled FP8 matmul kernel on SM100.
- grouped_smem_size: Calculate the shared memory size for the grouped block-scaled kernel.
- validate_grouped_gemm_constraints: Validate grouped GEMM configuration constraints.
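The total tile count sums each group's output tile grid. A hedged Python sketch of that arithmetic (the real block tile shape `(BM, BN)` comes from the kernel configuration; the 128x128 values and the treatment of the batch dimension `L` below are assumptions for illustration):

```python
# Hypothetical block tile shape; the actual values come from the
# matmul config and are not specified on this page.
BM, BN = 128, 128

def compute_total_tiles(problem_sizes, bm=BM, bn=BN):
    """Sum ceil(M/bm) * ceil(N/bn) over all groups.

    problem_sizes: list of (M, N, K, L) tuples, one per group,
    mirroring the (num_groups, 4) layout described above. Each of
    the L batch slices is assumed to contribute its own tile grid.
    """
    total = 0
    for (m, n, k, l) in problem_sizes:
        # -(-x // y) is ceiling division
        total += -(-m // bm) * -(-n // bn) * l
    return total

# Two groups with different problem sizes:
# (256, 256): 2 * 2 = 4 tiles; (300, 128): 3 * 1 = 3 tiles; total 7.
print(compute_total_tiles([(256, 256, 512, 1), (300, 128, 512, 1)]))  # -> 7
```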