Mojo module
grouped_block_scaled_matmul_kernel
Grouped block-scaled SM100 matmul kernel for multiple GEMM problems.
This kernel extends the block_scaled_matmul_kernel to support grouped GEMM with variable problem sizes per group. It uses:
- GroupedTileScheduler: For linear tile iteration across groups
- TMATensorTileArray: For per-block updatable TMA descriptors
- Dynamic tensormap updates: When transitioning between groups
Architecture (aligned with NVIDIA CuTe DSL grouped_blockscaled_gemm.py):
- TMA warp: Initializes A/B/SFA/SFB tensormaps, handles group transitions
- MMA warp: Consumes input tiles, performs block-scaled MMA
- Epilogue warps: Initializes C tensormap, handles C group transitions
- Named barrier synchronization between warps for tensormap init
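The warp-specialized handshake above can be sketched with ordinary threads and a barrier. This is an illustrative Python model, not the kernel's Mojo API: the names `tma_warp`, `mma_warp`, and `epilogue_warp` are stand-ins, and a `threading.Barrier` plays the role of the named barrier.

```python
import threading

# Hypothetical model: the TMA "warp" initializes the A/B/SFA/SFB
# tensormaps, the epilogue "warp" initializes C, and the MMA "warp"
# only consumes after both have arrived at the shared barrier.
ready = threading.Barrier(3)
tensormaps = {}
results = []

def tma_warp():
    for name in ("A", "B", "SFA", "SFB"):
        tensormaps[name] = "initialized"
    ready.wait()  # named-barrier arrive

def epilogue_warp():
    tensormaps["C"] = "initialized"  # epilogue owns the C tensormap
    ready.wait()

def mma_warp():
    ready.wait()  # blocks until both init warps have arrived
    results.append(sorted(tensormaps))

threads = [threading.Thread(target=f)
           for f in (tma_warp, epilogue_warp, mma_warp)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results[0])  # ['A', 'B', 'C', 'SFA', 'SFB']
```

The barrier guarantees the MMA consumer never reads a descriptor that has not been initialized, which is the same ordering property the named barriers enforce between warps in the kernel.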
Key differences from block_scaled_matmul_kernel.mojo:
- TMA descriptors are per-block (TMATensorTileArray) not grid constants
- SMEM tensormap buffers for dynamic updates (5 x 128 bytes)
- GroupedWorkInfo provides group_idx, k_tile_count, group_changed
- When group_changed=True, tensormaps are updated before loading tiles
- K-loop uses per-group k_tile_count instead of global K dimension
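The differences above can be modeled with a small Python sketch of the scheduler's linear tile walk. The names (`WorkInfo`, `linear_tiles`) are hypothetical stand-ins for `GroupedWorkInfo` and `GroupedTileScheduler`; the point is how `group_changed` marks where tensormap updates would occur and how `k_tile_count` varies per group.

```python
from dataclasses import dataclass

# Hypothetical model of GroupedWorkInfo-style iteration: tiles are
# walked linearly across groups with variable problem sizes, and
# group_changed flags the transitions that trigger tensormap updates.
@dataclass
class WorkInfo:
    group_idx: int
    tile_idx: int
    k_tile_count: int
    group_changed: bool

def linear_tiles(tiles_per_group, k_tiles_per_group):
    prev_group = -1
    for g, n_tiles in enumerate(tiles_per_group):
        for t in range(n_tiles):
            yield WorkInfo(g, t, k_tiles_per_group[g], g != prev_group)
            prev_group = g

# Three groups with different output-tile and K-tile counts.
work = list(linear_tiles([2, 1, 3], [4, 8, 2]))
updates = sum(w.group_changed for w in work)  # one update per group entered
print(updates)               # 3
print(work[2].k_tile_count)  # 8  (per-group K, not a global K dimension)
```

Because tiles within a group share descriptors, tensormap updates happen only at group boundaries, not per tile.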
Comptime values
NUM_TENSORMAPS
comptime NUM_TENSORMAPS = 5
TMA_DESCRIPTOR_SIZE
comptime TMA_DESCRIPTOR_SIZE = 128
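These two comptime values determine the size of the SMEM tensormap buffers mentioned above: five descriptors (A, B, SFA, SFB, C), each a 128-byte TMA descriptor. A quick check of the arithmetic:

```python
# Sizing implied by the comptime values: 5 descriptors x 128 bytes each.
NUM_TENSORMAPS = 5
TMA_DESCRIPTOR_SIZE = 128
smem_bytes = NUM_TENSORMAPS * TMA_DESCRIPTOR_SIZE
print(smem_bytes)  # 640
```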
Structs
- GroupedBlockScaledMatmulKernel: Grouped block-scaled matmul kernel with dynamic tensormap updates.
- GroupedTensormapManager: Manages tensormap SMEM state and updates for grouped GEMM.
- GroupedTensormapSmem: Shared memory pointers for tensormap descriptors.
Functions
- is_valid_dtypes_and_scale_factor_vec_size: Check whether the dtypes and sf_vec_size form a valid combination.
- is_valid_mma_tiler_and_cluster_shape: Check whether the MMA tiler and cluster shape are valid.