Mojo function

blackwell_block_scaled_matmul_tma_umma_warp_specialized

blackwell_block_scaled_matmul_tma_umma_warp_specialized[transpose_b: Bool, *, config: BlockScaledMatmulConfig[config.a_type, config.b_type, config.c_type, config.sfa_dtype, config.sfb_dtype, transpose_b], elementwise_compute_lambda_fn: Optional[elementwise_compute_lambda_type] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: Optional[UInt32] = None](c_tensor: TileTensor[c_tensor.dtype, c_tensor.LayoutType, c_tensor.origin, address_space=c_tensor.address_space, linear_idx_type=c_tensor.linear_idx_type, element_shape_types=c_tensor.element_shape_types], a_tensor: TileTensor[a_tensor.dtype, a_tensor.LayoutType, a_tensor.origin, address_space=a_tensor.address_space, linear_idx_type=a_tensor.linear_idx_type, element_shape_types=a_tensor.element_shape_types], b_tensor: TileTensor[b_tensor.dtype, b_tensor.LayoutType, b_tensor.origin, address_space=b_tensor.address_space, linear_idx_type=b_tensor.linear_idx_type, element_shape_types=b_tensor.element_shape_types], a_scales_tensor: TileTensor[a_scales_tensor.dtype, a_scales_tensor.LayoutType, a_scales_tensor.origin, address_space=a_scales_tensor.address_space, linear_idx_type=a_scales_tensor.linear_idx_type, element_shape_types=a_scales_tensor.element_shape_types], b_scales_tensor: TileTensor[b_scales_tensor.dtype, b_scales_tensor.LayoutType, b_scales_tensor.origin, address_space=b_scales_tensor.address_space, linear_idx_type=b_scales_tensor.linear_idx_type, element_shape_types=b_scales_tensor.element_shape_types], ctx: DeviceContext, alpha: Float32 = 1)

Launch block-scaled FP8 matmul kernel on SM100.

Computes C = scale(A) @ scale(B) where A and B are FP8 matrices with per-block scaling factors following MXFP8 conventions.
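The block-scaled product can be illustrated with a plain-Python sketch (an illustration of the math only, not the Mojo API). It assumes the MXFP8 convention of one shared scale factor per 32-element block along the K dimension, and that B is supplied transposed (as `transpose_b` requires), so each output element accumulates products of scaled A-row and B-row elements:

```python
# Illustrates the block-scaled matmul math (not the Mojo API).
# Assumption: one shared scale per 32-element block along K (MXFP8 convention).
BLOCK_K = 32  # scale-block size along the K dimension

def block_scaled_matmul(a, a_scales, b, b_scales, alpha=1.0):
    """C[m][n] = alpha * sum_k (a[m][k] * sfa) * (b[n][k] * sfb),
    with b given row-major transposed (transpose_b=True): b is N x K."""
    M, K = len(a), len(a[0])
    N = len(b)
    c = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                blk = k // BLOCK_K  # which scale block this k falls in
                acc += (a[m][k] * a_scales[m][blk]) * (b[n][k] * b_scales[n][blk])
            c[m][n] = alpha * acc
    return c
```

For example, with K = 64 each row of A and B carries two scale factors, one per 32-wide block.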

When config.AB_swapped is True, internally swaps A and B operands (along with their scale factors) and transposes the output for better performance when M is small.
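The swap is result-preserving because of a transpose identity: with `transpose_b` the kernel computes C = A @ Bᵀ, and swapping the operands computes B @ Aᵀ, whose transpose is the same C. A minimal sketch (an illustration of the identity only, not the kernel's internal code):

```python
# Demonstrates why swapping A and B preserves the result (illustration only):
# with transpose_b=True the kernel computes C = A @ B^T; swapping the operands
# computes B @ A^T, and transposing that product recovers the same C.
def matmul_bt(a, b):
    """C = A @ B^T for row-major nested lists (a: M x K, b: N x K)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
            for row_a in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

a = [[1.0, 2.0], [3.0, 4.0]]               # M x K
b = [[5.0, 6.0], [7.0, 8.0], [9.0, 0.0]]   # N x K
direct = matmul_bt(a, b)                   # M x N
swapped = transpose(matmul_bt(b, a))       # transpose of the N x M product
assert direct == swapped
```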

Parameters:

  • transpose_b (Bool): Whether B is transposed; must be True.
  • config (BlockScaledMatmulConfig): Block-scaled matmul configuration.
  • elementwise_compute_lambda_fn (Optional): Optional elementwise epilogue lambda applied to the output.
  • register_based_epilogue (Bool): Whether to use a register-based epilogue. Defaults to True.
  • pdl_level (PDLLevel): Programmatic dependent launch (PDL) level.
  • max_profiled_tiles_per_SM (Optional): Optional maximum number of tiles to profile per SM.

Args:

  • c_tensor (TileTensor): Output C tensor.
  • a_tensor (TileTensor): A matrix tensor.
  • b_tensor (TileTensor): B matrix tensor.
  • a_scales_tensor (TileTensor): Scaling factors for A.
  • b_scales_tensor (TileTensor): Scaling factors for B.
  • ctx (DeviceContext): Device context for kernel launch.
  • alpha (Float32): Scalar scale factor applied to the output. Defaults to 1.

Raises:

An error if configuration constraints are violated.