Mojo function
grouped_block_scaled_matmul
grouped_block_scaled_matmul[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, sfb_layout: Layout, transpose_b: Bool, max_groups: Int, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True](a_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], b_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], c_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], sfa_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], sfb_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int, total_tiles: Int, a_template: LayoutTensor[a_type, a_layout, MutAnyOrigin], b_template: LayoutTensor[b_type, b_layout, MutAnyOrigin], c_template: LayoutTensor[c_type, c_layout, MutAnyOrigin], sfa_template: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], sfb_template: LayoutTensor[sfb_dtype, sfb_layout, MutAnyOrigin], ctx: DeviceContext)
Launches a grouped block-scaled FP8 matmul kernel on SM100.
Computes C[g] = scale(A[g]) @ scale(B[g]) for g in range(num_groups), where each group can have different M, N, K dimensions.
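The per-group semantics above can be sketched with a small NumPy reference. This is an illustrative sketch only, not the kernel API: scale factors are applied elementwise here rather than in the F8-UE8M0 block-scaled format, and B is stored transposed (transpose_b=True), so each B[g] has shape (N, K).

```python
import numpy as np

def grouped_block_scaled_matmul_ref(a_list, b_list, sfa_list, sfb_list):
    """Reference semantics: C[g] = scale(A[g]) @ scale(B[g]) per group.

    Sketch only: scales are applied elementwise (the kernel uses
    F8-UE8M0 block scaling), and each B[g] is transposed, shape (N, K).
    """
    c_list = []
    for a, b, sfa, sfb in zip(a_list, b_list, sfa_list, sfb_list):
        scaled_a = a.astype(np.float32) * sfa   # scale(A[g]), shape (M, K)
        scaled_b = b.astype(np.float32) * sfb   # scale(B[g]), shape (N, K)
        c_list.append(scaled_a @ scaled_b.T)    # (M, K) @ (K, N) -> (M, N)
    return c_list

# Each group can have different M, N, K dimensions:
a0, b0 = np.ones((2, 4)), np.ones((3, 4))       # group 0: M=2, N=3, K=4
a1, b1 = np.ones((5, 8)), np.ones((6, 8))       # group 1: M=5, N=6, K=8
cs = grouped_block_scaled_matmul_ref(
    [a0, a1], [b0, b1],
    [np.ones((2, 4)), np.ones((5, 8))],         # simplified per-element SFA
    [np.ones((3, 4)), np.ones((6, 8))],         # simplified per-element SFB
)
```

In the real kernel, the per-group shapes are passed through `problem_sizes` as [M, N, K, L] rows, and the tensors themselves through the per-group pointer arrays.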
Parameters:
- c_type (DType): Output element type.
- c_layout (Layout): Output tensor layout.
- a_type (DType): A matrix element type (FP8).
- a_layout (Layout): A tensor layout.
- b_type (DType): B matrix element type (FP8).
- b_layout (Layout): B tensor layout.
- sfa_dtype (DType): A scaling factor type (F8-UE8M0).
- sfa_layout (Layout): A scaling factor tensor layout.
- sfb_dtype (DType): B scaling factor type (F8-UE8M0).
- sfb_layout (Layout): B scaling factor tensor layout.
- transpose_b (Bool): Whether B is transposed (must be True).
- max_groups (Int): Maximum number of groups (compile-time bound).
- config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
- elementwise_compute_lambda_fn (Optional): Optional epilogue lambda for element-wise operations on the output, applied after the matmul and before writing to global memory.
- register_based_epilogue (Bool): If True (default), apply the epilogue in registers; if False, use the SMEM-based epilogue path.
Args:
- a_ptrs (LayoutTensor): Per-group A matrix pointers, shape (max_groups, 1).
- b_ptrs (LayoutTensor): Per-group B matrix pointers, shape (max_groups, 1).
- c_ptrs (LayoutTensor): Per-group C matrix pointers, shape (max_groups, 1).
- sfa_ptrs (LayoutTensor): Per-group A scaling factor pointers, shape (max_groups, 1).
- sfb_ptrs (LayoutTensor): Per-group B scaling factor pointers, shape (max_groups, 1).
- problem_sizes (LayoutTensor): Per-group problem sizes, shape (max_groups, 4), as [M, N, K, L].
- num_groups (Int): Actual number of groups (runtime value <= max_groups).
- total_tiles (Int): Total number of tiles across all groups (computed by the caller).
- a_template (LayoutTensor): Template A tensor for TMA descriptor creation.
- b_template (LayoutTensor): Template B tensor for TMA descriptor creation.
- c_template (LayoutTensor): Template C tensor for TMA descriptor creation.
- sfa_template (LayoutTensor): Template SFA tensor for TMA descriptor creation.
- sfb_template (LayoutTensor): Template SFB tensor for TMA descriptor creation.
- ctx (DeviceContext): Device context for kernel launch.
Raises:
An error if configuration constraints are violated.