Mojo function

grouped_block_scaled_matmul

grouped_block_scaled_matmul[
    c_type: DType, c_layout: Layout,
    a_type: DType, a_layout: Layout,
    b_type: DType, b_layout: Layout,
    sfa_dtype: DType, sfa_layout: Layout,
    sfb_dtype: DType, sfb_layout: Layout,
    transpose_b: Bool,
    max_groups: Int,
    *,
    config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b],
    elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None,
    register_based_epilogue: Bool = True
](
    a_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin],
    b_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin],
    c_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin],
    sfa_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin],
    sfb_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin],
    problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin],
    num_groups: Int,
    total_tiles: Int,
    a_template: LayoutTensor[a_type, a_layout, MutAnyOrigin],
    b_template: LayoutTensor[b_type, b_layout, MutAnyOrigin],
    c_template: LayoutTensor[c_type, c_layout, MutAnyOrigin],
    sfa_template: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin],
    sfb_template: LayoutTensor[sfb_dtype, sfb_layout, MutAnyOrigin],
    ctx: DeviceContext
)

Launches a grouped block-scaled FP8 matmul kernel on SM100.

Computes C[g] = scale(A[g]) @ scale(B[g]) for g in range(num_groups), where each group can have different M, N, K dimensions.
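The per-group computation can be sketched as a NumPy reference, treating the FP8 values and UE8M0 scale factors as plain floats. The per-block scale granularity (block_k, one scale per block of K elements) and the (M, K // block_k) scale-factor shapes are illustrative assumptions, not taken from this page:

```python
import numpy as np

def grouped_block_scaled_matmul_ref(a_list, b_list, sfa_list, sfb_list,
                                    block_k=32):
    """Reference for C[g] = scale(A[g]) @ scale(B[g]) per group.

    a_list[g]:   (M, K) values (FP8 in the real kernel, floats here)
    b_list[g]:   (N, K), since transpose_b must be True
    sfa_list[g]: (M, K // block_k) per-block scale factors (assumed layout)
    sfb_list[g]: (N, K // block_k) per-block scale factors (assumed layout)
    """
    out = []
    for a, b, sfa, sfb in zip(a_list, b_list, sfa_list, sfb_list):
        # Expand each per-block scale factor across its block_k columns,
        # then apply it before the matmul.
        a_scaled = a * np.repeat(sfa, block_k, axis=1)
        b_scaled = b * np.repeat(sfb, block_k, axis=1)
        out.append(a_scaled @ b_scaled.T)  # (M, N) output for this group
    return out
```

Each group may have different M, N, K, which is why the inputs are lists of per-group arrays rather than one batched tensor.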

Parameters:

  • c_type (DType): Output element type.
  • c_layout (Layout): Output tensor layout.
  • a_type (DType): A matrix element type (FP8).
  • a_layout (Layout): A tensor layout.
  • b_type (DType): B matrix element type (FP8).
  • b_layout (Layout): B tensor layout.
  • sfa_dtype (DType): A scaling factor type (F8-UE8M0).
  • sfa_layout (Layout): A scaling factor tensor layout.
  • sfb_dtype (DType): B scaling factor type (F8-UE8M0).
  • sfb_layout (Layout): B scaling factor tensor layout.
  • transpose_b (Bool): Whether B is transposed (must be True).
  • max_groups (Int): Maximum number of groups (compile-time bound).
  • config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
  • elementwise_compute_lambda_fn (Optional): Optional epilogue lambda for element-wise operations on output. Applied after matmul, before writing to global memory.
  • register_based_epilogue (Bool): If True (default), apply epilogue in registers. If False, use SMEM-based epilogue path.
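To illustrate what the epilogue hook does, here is a hedged Python sketch: the lambda is an elementwise transform applied to the accumulator before the result is written out, which lets operations like a bias-add or activation be fused into the matmul. The relu_bias_epilogue helper is hypothetical, for illustration only:

```python
import numpy as np

def relu_bias_epilogue(bias):
    # Hypothetical epilogue: bias-add followed by ReLU, applied
    # elementwise to the matmul output (analogous in role to
    # elementwise_compute_lambda_fn).
    return lambda c: np.maximum(c + bias, 0.0)

def matmul_with_epilogue(a, b, epilogue=None):
    c = a @ b.T  # transpose_b == True convention: B is (N, K)
    # Apply the epilogue before "writing back" the result.
    return epilogue(c) if epilogue is not None else c
```

In the real kernel the same transform runs either in registers or via shared memory, depending on register_based_epilogue.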

Args:

  • a_ptrs (LayoutTensor): Per-group A matrix pointers (max_groups, 1).
  • b_ptrs (LayoutTensor): Per-group B matrix pointers (max_groups, 1).
  • c_ptrs (LayoutTensor): Per-group C matrix pointers (max_groups, 1).
  • sfa_ptrs (LayoutTensor): Per-group A scaling factor pointers (max_groups, 1).
  • sfb_ptrs (LayoutTensor): Per-group B scaling factor pointers (max_groups, 1).
  • problem_sizes (LayoutTensor): Per-group problem sizes (max_groups, 4) as [M, N, K, L].
  • num_groups (Int): Actual number of groups (runtime value <= max_groups).
  • total_tiles (Int): Total tiles across all groups (computed by caller).
  • a_template (LayoutTensor): Template A tensor for TMA descriptor creation.
  • b_template (LayoutTensor): Template B tensor for TMA descriptor creation.
  • c_template (LayoutTensor): Template C tensor for TMA descriptor creation.
  • sfa_template (LayoutTensor): Template SFA tensor for TMA descriptor creation.
  • sfb_template (LayoutTensor): Template SFB tensor for TMA descriptor creation.
  • ctx (DeviceContext): Device context for kernel launch.
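Since total_tiles must be computed by the caller, here is one plausible sketch of that computation from the [M, N, K, L] rows of problem_sizes, assuming the kernel tiles the output into block_m x block_n blocks per batch slice L. The tile shape would come from the BlockScaledMatmulConfig in practice; the values used here are assumptions:

```python
from math import ceil

def compute_total_tiles(problem_sizes, num_groups, block_m, block_n):
    """Count output tiles across all active groups.

    problem_sizes: sequence of (M, N, K, L) rows, one per group.
    block_m, block_n: assumed output-tile shape (in the real API this
    would be derived from the matmul config).
    """
    total = 0
    for g in range(num_groups):
        m, n, _k, l = problem_sizes[g]
        # K does not affect the output tile count; each batch slice L
        # contributes a full grid of M x N output tiles.
        total += ceil(m / block_m) * ceil(n / block_n) * l
    return total
```

Only the first num_groups rows are counted, matching the runtime num_groups <= max_groups contract.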

Raises:

An error if configuration constraints are violated (for example, if transpose_b is not True).
