Mojo function
blackwell_block_scaled_matmul_tma_umma_warp_specialized
blackwell_block_scaled_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, sfb_layout: Layout, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: OptionalReg[UInt32] = None](c_tensor: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_tensor: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_tensor: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales_tensor: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], b_scales_tensor: LayoutTensor[sfb_dtype, sfb_layout, MutAnyOrigin], ctx: DeviceContext)
Launches a block-scaled FP8 matmul kernel on SM100 (Blackwell). Computes C = scale(A) @ scale(B), where A and B are FP8 matrices with per-block scaling factors following MXFP8 conventions.
Parameters:

- c_type (DType): Output element type.
- c_layout (Layout): Output tensor layout.
- a_type (DType): A matrix element type (FP8).
- a_layout (Layout): A matrix layout.
- b_type (DType): B matrix element type (FP8).
- b_layout (Layout): B matrix layout.
- sfa_dtype (DType): A scaling factor type (F8-UE8M0).
- sfa_layout (Layout): A scaling factor layout.
- sfb_dtype (DType): B scaling factor type (F8-UE8M0).
- sfb_layout (Layout): B scaling factor layout.
- transpose_b (Bool): Whether B is transposed (must be True).
- config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
- elementwise_compute_lambda_fn (OptionalReg): Optional epilogue lambda.
- register_based_epilogue (Bool): Whether to use the register-based epilogue.
- pdl_level (PDLLevel): Programmatic dependent launch level.
- max_profiled_tiles_per_SM (OptionalReg): Optional profiling tile count.
Args:

- c_tensor (LayoutTensor): Output tensor.
- a_tensor (LayoutTensor): A matrix tensor.
- b_tensor (LayoutTensor): B matrix tensor.
- a_scales_tensor (LayoutTensor): A scaling factors.
- b_scales_tensor (LayoutTensor): B scaling factors.
- ctx (DeviceContext): Device context for kernel launch.
Raises:
If configuration constraints are violated.
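A minimal sketch of invoking this kernel, assuming the tensors and device context have already been set up. The dtype choices, config fields, and construction of `cfg` are illustrative assumptions, not prescribed by this page; consult the `BlockScaledMatmulConfig` documentation for the actual configuration API.

```mojo
# Sketch only: tensor setup is omitted, and the exact DType names for the
# FP8 and UE8M0 scale types are assumptions.
alias cfg = BlockScaledMatmulConfig[
    DType.float8_e4m3fn,   # a_type: FP8 input
    DType.float8_e4m3fn,   # b_type: FP8 input
    DType.bfloat16,        # c_type: output
    DType.float8_e8m0fnu,  # sfa_dtype: UE8M0 scale exponents
    DType.float8_e8m0fnu,  # sfb_dtype: UE8M0 scale exponents
    True,                  # transpose_b: must be True
]()

# c, a, b, a_scales, b_scales are LayoutTensors; ctx is a DeviceContext.
blackwell_block_scaled_matmul_tma_umma_warp_specialized[
    transpose_b=True,
    config=cfg,
](c, a, b, a_scales, b_scales, ctx)
```

Since `transpose_b` must be `True`, B is expected in row-major K-contiguous form; the call raises if the configuration constraints are violated.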