
Mojo module

grouped_block_scaled_matmul_kernel

Grouped block-scaled SM100 matmul kernel for multiple GEMM problems.

This kernel extends the block_scaled_matmul_kernel to support grouped GEMM with variable problem sizes per group. It uses:

  1. GroupedTileScheduler: For linear tile iteration across groups
  2. TMATensorTileArray: For per-block updatable TMA descriptors
  3. Dynamic tensormap updates: When transitioning between groups
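To make the linear tile iteration concrete, here is a minimal Python sketch of the core idea behind the scheduler: map a single linear tile index onto a (group, tile-within-group) pair across variable-sized problems. The helper names (`tiles_per_group`, `locate_tile`) are hypothetical and not part of the GroupedTileScheduler API; this only illustrates the indexing scheme.

```python
def tiles_per_group(problem_sizes, tile_m, tile_n):
    # Number of output tiles each group's (M, N) shape produces.
    return [
        ((m + tile_m - 1) // tile_m) * ((n + tile_n - 1) // tile_n)
        for m, n, _k in problem_sizes
    ]

def locate_tile(linear_idx, tile_counts):
    # Map a linear tile index to (group_idx, tile_within_group).
    for group_idx, count in enumerate(tile_counts):
        if linear_idx < count:
            return group_idx, linear_idx
        linear_idx -= count
    raise IndexError("linear tile index out of range")

# Three GEMM problems with different (M, N, K) shapes.
problems = [(256, 256, 128), (128, 512, 256), (384, 128, 64)]
counts = tiles_per_group(problems, tile_m=128, tile_n=128)  # [4, 4, 3]
print(locate_tile(0, counts))   # (0, 0)
print(locate_tile(4, counts))   # (1, 0) -- first tile of group 1
print(locate_tile(10, counts))  # (2, 2)
```

A tile whose group index differs from the previous tile's is exactly the "group transition" case that triggers a tensormap update.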

Architecture (aligned with NVIDIA CuTe DSL grouped_blockscaled_gemm.py):

  • TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions
  • MMA warp: Consumes input tiles, performs block-scaled MMA
  • Epilogue warps: Initializes C tensormap, handles C group transitions
  • Named barrier synchronization between warps for tensormap init

Key differences from block_scaled_matmul_kernel.mojo:

  1. TMA descriptors are per-block (TMATensorTileArray) not grid constants
  2. SMEM tensormap buffers for dynamic updates (5 x 128 bytes)
  3. GroupedWorkInfo provides group_idx, k_tile_count, group_changed
  4. When group_changed=True, tensormaps are updated before loading tiles
  5. K-loop uses per-group k_tile_count instead of global K dimension
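Putting differences 3-5 together, the main loop consumes a stream of work items, refreshes the tensormaps only when the group changes, and runs the K-loop with that group's tile count. The following Python sketch is illustrative only: `GroupedWorkInfo`'s fields mirror those named above, but `run_kernel`, `update_tensormaps`, and `load_and_mma` are hypothetical stand-ins for the device-side logic.

```python
from dataclasses import dataclass

@dataclass
class GroupedWorkInfo:
    group_idx: int
    k_tile_count: int     # per-group K extent, in tiles
    group_changed: bool   # True when this tile starts a new group

def run_kernel(work_items, update_tensormaps, load_and_mma):
    # Persistent main loop: tensormaps are updated only on group
    # transitions, then the K-loop uses the per-group tile count.
    updates = 0
    for work in work_items:
        if work.group_changed:
            update_tensormaps(work.group_idx)
            updates += 1
        for k in range(work.k_tile_count):
            load_and_mma(work.group_idx, k)
    return updates

# Two tiles from group 0 (4 K-tiles each), then one from group 1 (2 K-tiles).
items = [
    GroupedWorkInfo(0, 4, True),
    GroupedWorkInfo(0, 4, False),
    GroupedWorkInfo(1, 2, True),
]
mmas = []
n_updates = run_kernel(items, lambda g: None, lambda g, k: mmas.append((g, k)))
print(n_updates, len(mmas))  # 2 10
```

Note that the second tile of group 0 pays no tensormap-update cost, which is why the scheduler iterates tiles within a group contiguously.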

comptime values

NUM_TENSORMAPS

comptime NUM_TENSORMAPS = 5

TMA_DESCRIPTOR_SIZE

comptime TMA_DESCRIPTOR_SIZE = 128
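These two constants determine the shared-memory footprint reserved for dynamic tensormap updates: one 128-byte TMA descriptor for each of the five tensormaps (A, B, SFA, SFB, and C, matching the warp roles described above). A quick check of the arithmetic:

```python
NUM_TENSORMAPS = 5         # A, B, SFA, SFB, C
TMA_DESCRIPTOR_SIZE = 128  # bytes per TMA descriptor

smem_bytes = NUM_TENSORMAPS * TMA_DESCRIPTOR_SIZE
print(smem_bytes)  # 640
```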

