Mojo function

grouped_block_scaled_matmul

grouped_block_scaled_matmul[transpose_b: Bool, max_groups: Int, *, config: BlockScaledMatmulConfig[config.a_type, config.b_type, config.c_type, config.sfa_dtype, config.sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None](a_ptrs: TileTensor[DType.uint64, a_ptrs.LayoutType, a_ptrs.origin, address_space=a_ptrs.address_space, linear_idx_type=a_ptrs.linear_idx_type, element_size=a_ptrs.element_size], b_ptrs: TileTensor[DType.uint64, b_ptrs.LayoutType, b_ptrs.origin, address_space=b_ptrs.address_space, linear_idx_type=b_ptrs.linear_idx_type, element_size=b_ptrs.element_size], c_ptrs: TileTensor[DType.uint64, c_ptrs.LayoutType, c_ptrs.origin, address_space=c_ptrs.address_space, linear_idx_type=c_ptrs.linear_idx_type, element_size=c_ptrs.element_size], sfa_ptrs: TileTensor[DType.uint64, sfa_ptrs.LayoutType, sfa_ptrs.origin, address_space=sfa_ptrs.address_space, linear_idx_type=sfa_ptrs.linear_idx_type, element_size=sfa_ptrs.element_size], sfb_ptrs: TileTensor[DType.uint64, sfb_ptrs.LayoutType, sfb_ptrs.origin, address_space=sfb_ptrs.address_space, linear_idx_type=sfb_ptrs.linear_idx_type, element_size=sfb_ptrs.element_size], problem_sizes: TileTensor[DType.int32, problem_sizes.LayoutType, problem_sizes.origin, address_space=problem_sizes.address_space, linear_idx_type=problem_sizes.linear_idx_type, element_size=problem_sizes.element_size], num_groups: Int, total_tiles: Int, a_template: TileTensor[config.a_type, a_template.LayoutType, a_template.origin, address_space=a_template.address_space, linear_idx_type=a_template.linear_idx_type, element_size=a_template.element_size], b_template: TileTensor[config.b_type, b_template.LayoutType, b_template.origin, address_space=b_template.address_space, linear_idx_type=b_template.linear_idx_type, element_size=b_template.element_size], c_template: TileTensor[config.c_type, c_template.LayoutType, c_template.origin, address_space=c_template.address_space, linear_idx_type=c_template.linear_idx_type, element_size=c_template.element_size], sfa_template: TileTensor[config.sfa_dtype, sfa_template.LayoutType, sfa_template.origin, address_space=sfa_template.address_space, linear_idx_type=sfa_template.linear_idx_type, element_size=sfa_template.element_size], sfb_template: TileTensor[config.sfb_dtype, sfb_template.LayoutType, sfb_template.origin, address_space=sfb_template.address_space, linear_idx_type=sfb_template.linear_idx_type, element_size=sfb_template.element_size], ctx: DeviceContext)

Launches a grouped block-scaled FP8 matmul kernel on SM100.

Computes C[g] = scale(A[g]) @ scale(B[g]) for g in range(num_groups), where each group can have different M, N, K dimensions.
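The per-group semantics can be sketched with a plain-precision NumPy reference. This is a hypothetical simplification, not the kernel's implementation: per-group scalar scale factors stand in for the per-block FP8 scale factors (sfa/sfb), and transpose_b=True is modeled by multiplying with B transposed.

```python
import numpy as np

def grouped_scaled_matmul_ref(a_list, b_list, sfa_list, sfb_list):
    # Reference semantics only: each group g computes
    # C[g] = scale(A[g]) @ scale(B[g]), with its own (M, N, K).
    out = []
    for a, b, sfa, sfb in zip(a_list, b_list, sfa_list, sfb_list):
        # transpose_b=True: B is stored as (N, K), so multiply by B^T.
        out.append((a * sfa) @ (b * sfb).T)
    return out

# Three groups with different M, N, K dimensions.
rng = np.random.default_rng(0)
shapes = [(4, 8, 16), (2, 6, 32), (8, 8, 8)]  # (M, N, K) per group
a_list = [rng.standard_normal((m, k), dtype=np.float32) for m, n, k in shapes]
b_list = [rng.standard_normal((n, k), dtype=np.float32) for m, n, k in shapes]
sfa_list = [np.float32(0.5)] * 3   # stand-ins for per-block scale factors
sfb_list = [np.float32(2.0)] * 3
c_list = grouped_scaled_matmul_ref(a_list, b_list, sfa_list, sfb_list)
print([c.shape for c in c_list])  # per-group output shapes (M, N)
```

Each output C[g] has shape (M_g, N_g); the real kernel additionally packs per-group pointers and problem sizes into the `*_ptrs` and `problem_sizes` tensors described below.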

Parameters:

  • transpose_b (Bool): Whether B is transposed (must be True).
  • max_groups (Int): Maximum number of groups (compile-time bound).
  • config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
  • elementwise_compute_lambda_fn (Optional): Optional epilogue lambda for element-wise operations on output. Applied after matmul, before writing to global memory.

Args:

  • a_ptrs (TileTensor): Per-group A matrix pointers (max_groups, 1).
  • b_ptrs (TileTensor): Per-group B matrix pointers (max_groups, 1).
  • c_ptrs (TileTensor): Per-group C matrix pointers (max_groups, 1).
  • sfa_ptrs (TileTensor): Per-group A scaling factor pointers (max_groups, 1).
  • sfb_ptrs (TileTensor): Per-group B scaling factor pointers (max_groups, 1).
  • problem_sizes (TileTensor): Per-group problem sizes (max_groups, 4) as [M, N, K, L].
  • num_groups (Int): Actual number of groups (runtime value <= max_groups).
  • total_tiles (Int): Total tiles across all groups (computed by caller).
  • a_template (TileTensor): Template A tensor (1, M, K) for TMA descriptor creation.
  • b_template (TileTensor): Template B tensor (1, N, K) for TMA descriptor creation.
  • c_template (TileTensor): Template C tensor (1, M, N) for TMA descriptor creation.
  • sfa_template (TileTensor): Template SFA tensor for TMA descriptor creation.
  • sfb_template (TileTensor): Template SFB tensor for TMA descriptor creation.
  • ctx (DeviceContext): Device context for kernel launch.

Raises:

An error if configuration constraints are violated.
