Mojo function
grouped_block_scaled_matmul
```mojo
grouped_block_scaled_matmul[
    transpose_b: Bool,
    max_groups: Int,
    *,
    config: BlockScaledMatmulConfig[config.a_type, config.b_type, config.c_type, config.sfa_dtype, config.sfb_dtype, transpose_b],
    elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None
](
    a_ptrs: TileTensor[DType.uint64, a_ptrs.LayoutType, a_ptrs.origin, address_space=a_ptrs.address_space, linear_idx_type=a_ptrs.linear_idx_type, element_size=a_ptrs.element_size],
    b_ptrs: TileTensor[DType.uint64, b_ptrs.LayoutType, b_ptrs.origin, address_space=b_ptrs.address_space, linear_idx_type=b_ptrs.linear_idx_type, element_size=b_ptrs.element_size],
    c_ptrs: TileTensor[DType.uint64, c_ptrs.LayoutType, c_ptrs.origin, address_space=c_ptrs.address_space, linear_idx_type=c_ptrs.linear_idx_type, element_size=c_ptrs.element_size],
    sfa_ptrs: TileTensor[DType.uint64, sfa_ptrs.LayoutType, sfa_ptrs.origin, address_space=sfa_ptrs.address_space, linear_idx_type=sfa_ptrs.linear_idx_type, element_size=sfa_ptrs.element_size],
    sfb_ptrs: TileTensor[DType.uint64, sfb_ptrs.LayoutType, sfb_ptrs.origin, address_space=sfb_ptrs.address_space, linear_idx_type=sfb_ptrs.linear_idx_type, element_size=sfb_ptrs.element_size],
    problem_sizes: TileTensor[DType.int32, problem_sizes.LayoutType, problem_sizes.origin, address_space=problem_sizes.address_space, linear_idx_type=problem_sizes.linear_idx_type, element_size=problem_sizes.element_size],
    num_groups: Int,
    total_tiles: Int,
    a_template: TileTensor[config.a_type, a_template.LayoutType, a_template.origin, address_space=a_template.address_space, linear_idx_type=a_template.linear_idx_type, element_size=a_template.element_size],
    b_template: TileTensor[config.b_type, b_template.LayoutType, b_template.origin, address_space=b_template.address_space, linear_idx_type=b_template.linear_idx_type, element_size=b_template.element_size],
    c_template: TileTensor[config.c_type, c_template.LayoutType, c_template.origin, address_space=c_template.address_space, linear_idx_type=c_template.linear_idx_type, element_size=c_template.element_size],
    sfa_template: TileTensor[config.sfa_dtype, sfa_template.LayoutType, sfa_template.origin, address_space=sfa_template.address_space, linear_idx_type=sfa_template.linear_idx_type, element_size=sfa_template.element_size],
    sfb_template: TileTensor[config.sfb_dtype, sfb_template.LayoutType, sfb_template.origin, address_space=sfb_template.address_space, linear_idx_type=sfb_template.linear_idx_type, element_size=sfb_template.element_size],
    ctx: DeviceContext,
)
```
Launches a grouped block-scaled FP8 matmul kernel on SM100 (NVIDIA Blackwell).
Computes C[g] = scale(A[g]) @ scale(B[g]) for g in range(num_groups), where each group can have different M, N, K dimensions.
Parameters:
- transpose_b (Bool): Whether B is transposed (must be True).
- max_groups (Int): Maximum number of groups (compile-time bound).
- config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
- elementwise_compute_lambda_fn (Optional): Optional epilogue lambda for element-wise operations on the output, applied after the matmul and before writing to global memory.
Args:
- a_ptrs (TileTensor): Per-group A matrix pointers (max_groups, 1).
- b_ptrs (TileTensor): Per-group B matrix pointers (max_groups, 1).
- c_ptrs (TileTensor): Per-group C matrix pointers (max_groups, 1).
- sfa_ptrs (TileTensor): Per-group A scaling factor pointers (max_groups, 1).
- sfb_ptrs (TileTensor): Per-group B scaling factor pointers (max_groups, 1).
- problem_sizes (TileTensor): Per-group problem sizes (max_groups, 4) as [M, N, K, L].
- num_groups (Int): Actual number of groups (runtime value <= max_groups).
- total_tiles (Int): Total number of tiles across all groups (computed by the caller).
- a_template (TileTensor): Template A tensor (1, M, K) for TMA descriptor creation.
- b_template (TileTensor): Template B tensor (1, N, K) for TMA descriptor creation.
- c_template (TileTensor): Template C tensor (1, M, N) for TMA descriptor creation.
- sfa_template (TileTensor): Template SFA tensor for TMA descriptor creation.
- sfb_template (TileTensor): Template SFB tensor for TMA descriptor creation.
- ctx (DeviceContext): Device context for kernel launch.
Raises:
If configuration constraints are violated.
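The per-group computation C[g] = scale(A[g]) @ scale(B[g]) can be illustrated with a NumPy reference. This is a sketch of the math only, not the Mojo API: the function name, the scale-block size of 32 along K, and the test shapes are illustrative assumptions, and B is stored as (N, K) to mirror transpose_b=True.

```python
import numpy as np

SF_BLOCK = 32  # assumed scale-factor granularity along K (illustrative only)

def grouped_block_scaled_matmul_ref(a_list, b_list, sfa_list, sfb_list):
    """Reference semantics: C[g] = scale(A[g]) @ scale(B[g]).T per group.

    Each group g has its own (M, K) A, (N, K) B (transpose_b=True), and
    per-block scale factors of shape (rows, K // SF_BLOCK) that are
    broadcast over K-blocks before the plain matmul.
    """
    outputs = []
    for a, b, sfa, sfb in zip(a_list, b_list, sfa_list, sfb_list):
        # Expand per-block scales to per-element scales along K.
        a_scaled = a * np.repeat(sfa, SF_BLOCK, axis=1)[:, : a.shape[1]]
        b_scaled = b * np.repeat(sfb, SF_BLOCK, axis=1)[:, : b.shape[1]]
        outputs.append(a_scaled @ b_scaled.T)  # B is (N, K), so transpose
    return outputs

# Two groups with different M, N, K, as the grouped kernel allows.
rng = np.random.default_rng(0)
shapes = [(4, 8, 64), (2, 3, 32)]  # (M, N, K) per group
a_list = [rng.standard_normal((m, k)) for m, n, k in shapes]
b_list = [rng.standard_normal((n, k)) for m, n, k in shapes]
# Unit scale factors: the result reduces to an ordinary per-group matmul.
sfa_list = [np.ones((m, k // SF_BLOCK)) for m, n, k in shapes]
sfb_list = [np.ones((n, k // SF_BLOCK)) for m, n, k in shapes]
c_list = grouped_block_scaled_matmul_ref(a_list, b_list, sfa_list, sfb_list)
assert [c.shape for c in c_list] == [(4, 8), (2, 3)]
```

In the actual kernel the per-group operands are passed indirectly via the `*_ptrs` tensors and `problem_sizes`, while the `*_template` tensors only seed TMA descriptor creation; the sketch above folds all of that into plain Python lists to isolate the arithmetic.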