Mojo function
blackwell_block_scaled_matmul_tma_umma_warp_specialized
blackwell_block_scaled_matmul_tma_umma_warp_specialized[transpose_b: Bool, *, config: BlockScaledMatmulConfig[config.a_type, config.b_type, config.c_type, config.sfa_dtype, config.sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: Optional[UInt32] = None](c_tensor: TileTensor[c_tensor.dtype, c_tensor.LayoutType, c_tensor.origin, address_space=c_tensor.address_space, linear_idx_type=c_tensor.linear_idx_type, element_shape_types=c_tensor.element_shape_types], a_tensor: TileTensor[a_tensor.dtype, a_tensor.LayoutType, a_tensor.origin, address_space=a_tensor.address_space, linear_idx_type=a_tensor.linear_idx_type, element_shape_types=a_tensor.element_shape_types], b_tensor: TileTensor[b_tensor.dtype, b_tensor.LayoutType, b_tensor.origin, address_space=b_tensor.address_space, linear_idx_type=b_tensor.linear_idx_type, element_shape_types=b_tensor.element_shape_types], a_scales_tensor: TileTensor[a_scales_tensor.dtype, a_scales_tensor.LayoutType, a_scales_tensor.origin, address_space=a_scales_tensor.address_space, linear_idx_type=a_scales_tensor.linear_idx_type, element_shape_types=a_scales_tensor.element_shape_types], b_scales_tensor: TileTensor[b_scales_tensor.dtype, b_scales_tensor.LayoutType, b_scales_tensor.origin, address_space=b_scales_tensor.address_space, linear_idx_type=b_scales_tensor.linear_idx_type, element_shape_types=b_scales_tensor.element_shape_types], ctx: DeviceContext, alpha: Float32 = 1)
Launch block-scaled FP8 matmul kernel on SM100.
Computes C = scale(A) @ scale(B), where A and B are FP8 matrices with per-block scaling factors following MXFP8 conventions.
When config.AB_swapped is True, the kernel internally swaps the A and B operands (along with their scale factors) and transposes the output, which improves performance when M is small.
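The per-block scaling semantics can be illustrated with a small pure-Python reference sketch. It assumes a scale-factor block size of 32 along K (the MXFP8 convention) and takes B already transposed, matching transpose_b being True; the function and variable names are illustrative, not part of the kernel's API:

```python
# Reference semantics for a block-scaled matmul: each run of 32 elements
# along K in A (and in B^T) shares one scale factor, per the MXFP8
# convention. Pure-Python sketch; block size and names are illustrative,
# not the kernel's actual tiling.

SF_BLOCK = 32  # scale-factor granularity along K (MXFP8 convention)

def block_scaled_matmul(a, a_scales, b_t, b_scales, alpha=1.0):
    """C[m][n] = alpha * sum_k (A[m][k] * sfa[m][k//32]) * (B^T[n][k] * sfb[n][k//32])."""
    m_dim, k_dim = len(a), len(a[0])
    n_dim = len(b_t)
    c = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0
            for k in range(k_dim):
                acc += (a[m][k] * a_scales[m][k // SF_BLOCK]) * \
                       (b_t[n][k] * b_scales[n][k // SF_BLOCK])
            c[m][n] = alpha * acc
    return c
```

The kernel fuses this scaling into the tensor-core MMA rather than materializing scale(A) and scale(B); the sketch only pins down the math being computed.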
Parameters:
- transpose_b (Bool): Whether B is transposed (must be True).
- config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
- elementwise_compute_lambda_fn (Optional): Optional epilogue lambda.
- register_based_epilogue (Bool): Whether to use the register-based epilogue.
- pdl_level (PDLLevel): Programmatic dependent launch level.
- max_profiled_tiles_per_SM (Optional): Optional profiling tile count.
Args:
- c_tensor (TileTensor): Output tensor.
- a_tensor (TileTensor): A matrix tensor.
- b_tensor (TileTensor): B matrix tensor.
- a_scales_tensor (TileTensor): A scaling factors.
- b_scales_tensor (TileTensor): B scaling factors.
- ctx (DeviceContext): Device context for kernel launch.
- alpha (Float32): Scalar scale factor applied to the output.
Raises:
An error if configuration constraints are violated.
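The AB_swapped optimization relies on the transpose identity (A @ B)^T = B^T @ A^T: swapping the operands (and their scale factors) and transposing the result leaves the output unchanged. A tiny self-contained sketch of that identity, with illustrative names:

```python
# Transpose identity behind AB_swapped: A @ B == ((B^T) @ (A^T))^T.
# Swapping operands lets a kernel tuned for large-M tiles handle a
# small-M problem by running it as a large-N problem instead.

def transpose(m):
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

a = [[1.0, 2.0, 3.0]]                      # M=1: the small-M case AB_swapped targets
b = [[4.0, 5.0], [6.0, 7.0], [8.0, 9.0]]
assert matmul(a, b) == transpose(matmul(transpose(b), transpose(a)))
```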