Skip to main content

Mojo struct

BlackwellBlockwiseFP8MatmulKernel

struct BlackwellBlockwiseFP8MatmulKernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, b_scales_layout: TensorLayout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple(SIMD(1))]

Blockwise FP8 matmul kernel with register-based accumulation.

This kernel implements per-K-iteration scaling in CUDA cores:

  1. Load warp: TMA loads A, B, A-scales to SMEM
  2. MMA warp: Standard MMA (partial to TMEM)
  3. Epilogue warp: TMEM read → scale → register accumulate → output

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_expected_bytes

comptime a_expected_bytes = ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK) * size_of[a_type]())

a_scales_expected_bytes

comptime a_scales_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM * size_of[a_scales_type]())

a_swizzle_elems

comptime a_swizzle_elems = (config.a_swizzle.bytes() // size_of[a_type]())

a_tile_dim0

comptime a_tile_dim0 = compute_tma_tile_dims[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[0]

a_tma_load_size

comptime a_tma_load_size = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0 * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_swizzle_elems)

a_tma_rows

comptime a_tma_rows = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0

accum_dims

comptime accum_dims = get_accumulator_dims[c_smem_dim1=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN, block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()

accum_pipeline_consumer_arv_count

comptime accum_pipeline_consumer_arv_count = compute_accum_barrier_counts[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[1]

accum_pipeline_producer_arv_count

comptime accum_pipeline_producer_arv_count = compute_accum_barrier_counts[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[0]

accum_type

comptime accum_type = DType.float32

AccumTensor

comptime AccumTensor = TmemTensor[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].tmem_accum_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]

Accumulator

comptime Accumulator = BlockwiseFP8Accumulator[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims[0], BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims[1], BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].is_lower_required, config.block_tile_shape, config.mma_shape, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE]

ADescLayout

comptime ADescLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_swizzle_elems], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_swizzle_elems], ComptimeInt[1]]

AScalesLayout

comptime AScalesLayout = Layout[ComptimeInt[1], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM], ComptimeInt[1]]

AScalesTmaOp

comptime AScalesTmaOp = TMATensorTile[a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout]()]

ATileLayout

comptime ATileLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[1]]

ATmaOp

comptime ATmaOp = TMATensorTile[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout]()]

b_expected_bytes

comptime b_expected_bytes = ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK) * size_of[b_type]())

b_swizzle_elems

comptime b_swizzle_elems = (config.b_swizzle.bytes() // size_of[b_type]())

b_tile_dim0

comptime b_tile_dim0 = compute_tma_tile_dims[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[1]

b_tma_load_size

comptime b_tma_load_size = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0 * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_swizzle_elems)

b_tma_rows

comptime b_tma_rows = BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0

BDescLayout

comptime BDescLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_swizzle_elems], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_swizzle_elems], ComptimeInt[1]]

BK

comptime BK = config.block_tile_shape[2]

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

BScalesTile

comptime BScalesTile = TileTensor[b_scales_type, b_scales_layout, ImmutAnyOrigin]

BTileLayout

comptime BTileLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BK], ComptimeInt[1]]

BTmaOp

comptime BTmaOp = TMATensorTile[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BDescLayout]()]

c_swizzle_elems

comptime c_swizzle_elems = (config.c_swizzle.bytes() // size_of[c_type]())

c_tile_dim0

comptime c_tile_dim0 = compute_tma_tile_dims[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[2]

CDescLayout

comptime CDescLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_swizzle_elems], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_swizzle_elems], ComptimeInt[1]]

clc_consumer_arv_count

comptime clc_consumer_arv_count = compute_clc_barrier_counts[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[1]

clc_producer_arv_count

comptime clc_producer_arv_count = compute_clc_barrier_counts[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[0]

clc_throttle_consumer_arv_count

comptime clc_throttle_consumer_arv_count = compute_clc_barrier_counts[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[3]

clc_throttle_producer_arv_count

comptime clc_throttle_producer_arv_count = compute_clc_barrier_counts[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_SIZE, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group]()[2]

CLUSTER_M

comptime CLUSTER_M = config.cluster_shape[0]

CLUSTER_N

comptime CLUSTER_N = config.cluster_shape[1]

CLUSTER_SIZE

comptime CLUSTER_SIZE = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M * BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N)

Context

comptime Context = KernelContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_clc_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N]

cta_group

comptime cta_group = config.cta_group

CTileLayout

comptime CTileLayout = Layout[ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].c_tile_dim0], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN], ComptimeInt[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN], ComptimeInt[1]]

CTmaOp

comptime CTmaOp = TMATensorTile[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CDescLayout]()]

EPILOGUE_THREADS

comptime EPILOGUE_THREADS = (4 * WARP_SIZE)

EpilogueCtx

comptime EpilogueCtx = EpilogueWarpContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]

EpilogueHandle

comptime EpilogueHandle = EpilogueWarp[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]

input_expected_bytes

comptime input_expected_bytes = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group * ((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_expected_bytes + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].b_expected_bytes) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].a_scales_expected_bytes))

InputTilePipeline

comptime InputTilePipeline = InputTilePipeline[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.num_group_pipeline_stages, config.k_group_size]

is_lower_required

comptime is_lower_required = is_lower_fragment_required[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, config.block_tile_shape]()

MMA_K

comptime MMA_K = config.mma_shape[2]

MMA_M

comptime MMA_M = config.mma_shape[0]

MMA_N

comptime MMA_N = config.mma_shape[1]

MMA_THREADS

comptime MMA_THREADS = WARP_SIZE

MmaCtx

comptime MmaCtx = MmaWarpContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]

MmaHandle

comptime MmaHandle = MmaWarp[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS]

MmaOp

comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_group_pipeline_stages

comptime num_group_pipeline_stages = (BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_pipeline_stages // config.k_group_size)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_output_warps

comptime num_output_warps = 4

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

NUM_THREADS

comptime NUM_THREADS = (((BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SCHEDULER_THREADS + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TMA_LOAD_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_THREADS) + BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].EPILOGUE_THREADS)

NUM_TMEM_COLS

comptime NUM_TMEM_COLS = 512

opc

comptime opc = OutputPipelineConfig(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_accum_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].stage_stride_cols, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group)

OutputM

comptime OutputM = config.output_tile_shape[0]

OutputN

comptime OutputN = config.output_tile_shape[1]

OutputPipeline

comptime OutputPipeline = OutputTilePipeline[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc]

Scheduler

comptime Scheduler = TileScheduler[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_clc_pipeline_stages, Index[Int, Int, Int, dtype=DType.uint32](config.cluster_shape[0], config.cluster_shape[1], config.cluster_shape[2]), config.raster_order, config.block_swizzle_size]

SCHEDULER_THREADS

comptime SCHEDULER_THREADS = WARP_SIZE

SmemType

comptime SmemType = BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config]

stage_stride_cols

comptime stage_stride_cols = (512 // config.num_accum_pipeline_stages)

TilePayload

comptime TilePayload = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, IndexList(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.BM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.BN, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.BK, __list_literal__=Tuple()), IndexList(1, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.BM, __list_literal__=Tuple()), BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.num_pipeline_stages]

TileWriterType

comptime TileWriterType = BlockwiseFP8TileWriter[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputM, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].OutputN, DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims[0], BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].accum_dims[1], block_tile_shape=config.block_tile_shape, mma_shape=config.mma_shape, is_lower_frag_required=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].is_lower_required, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, num_output_stages=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_output_stages, num_output_warps=4, c_swizzle=config.c_swizzle]

TMA_LOAD_THREADS

comptime TMA_LOAD_THREADS = WARP_SIZE

Tmem

comptime Tmem = TmemAllocation[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc.cta_group]

tmem_accum_layout

comptime tmem_accum_layout = Layout.row_major(BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].MMA_N)

TmemDealloc

comptime TmemDealloc = TmemDeallocBarrier[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].opc.cta_group]

Methods

load_input_tiles

static load_input_tiles[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, a_scales_tma_origin: ImmutOrigin, tiles_origin: MutOrigin, //](a_loader: TileLoader[a_tma_origin, a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group], b_loader: TileLoader[b_tma_origin, b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BDescLayout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group], a_scales_loader: ScalesLoader[a_scales_tma_origin, a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group], tiles: InputProducerStage[tiles_origin, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, 
b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.num_group_pipeline_stages, config.k_group_size], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int], iter_idx: Int, elect_one_cta: Bool)

Load A, B, and A-scales tiles using TMA.

Args:

  • a_loader (TileLoader): TileLoader for A matrix.
  • b_loader (TileLoader): TileLoader for B matrix.
  • a_scales_loader (ScalesLoader): ScalesLoader for A-scales.
  • tiles (InputProducerStage): InputProducerStage context with encapsulated tile access.
  • peer_cta_coord (Tuple): Peer CTA coordinates for multicast.
  • work_tile_coord (Tuple): Current work tile M/N coordinates.
  • iter_idx (Int): K iteration index.
  • elect_one_cta (Bool): Whether this is the elected CTA in the cluster.

mma

static mma[tiles_origin: MutOrigin, //](tiles: InputConsumerStage[tiles_origin, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].TilePayload, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].SmemType.Core.num_group_pipeline_stages, config.k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b], accum_tensor: TmemTensor[DType.float32, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].tmem_accum_layout, cta_group=BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group])

Execute standard MMA operations (partial results to TMEM).

For blockwise FP8, each K iteration writes a fresh partial to TMEM. The epilogue accumulates across K in registers, not TMEM. Therefore init_c is always True (unlike standard matmul).

Args:

  • tiles (InputConsumerStage): Input consumer stage with A, B, A-scales tiles.
  • mma_op (MmaOpSM100_SS): The MMA operator.
  • accum_tensor (TmemTensor): Typed TMEM tensor view for the accumulator stage.

validate_config

static validate_config()

Validate configuration constraints at compile time.

init_barriers

static init_barriers(ctx: KernelContext[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].num_clc_pipeline_stages, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].cta_group, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_M, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CLUSTER_N], a_tma_op: TMATensorTile[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, 
transpose_b, config, cluster_shape].BDescLayout]()], c_tma_op: TMATensorTile[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CDescLayout]()], a_scales_tma_op: TMATensorTile[a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout]()], input_barriers: SMemArray[SharedMemBarrier, (BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Core.num_group_pipeline_stages * 2)], accum_barriers: SMemArray[SharedMemBarrier, (BlockwiseFP8Smem[a_type, b_type, c_type, a_scales_type, transpose_b, config=config].Core.num_accum_pipeline_stages * 2)], clc_throttle: SMemArray[SharedMemBarrier, (config.num_clc_pipeline_stages * 2)], clc_full: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], clc_empty: SMemArray[SharedMemBarrier, config.num_clc_pipeline_stages], 
tmem_dealloc: SMemArray[SharedMemBarrier, 1])

Initialize barriers and prefetch TMA descriptors.

run

static run(a_tma_op: TMATensorTile[a_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ATileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].ADescLayout]()], b_tma_op: TMATensorTile[b_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BTileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].BDescLayout]()], c_tma_op: TMATensorTile[c_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CTileLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, 
a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].CDescLayout]()], a_scales_tma_op: TMATensorTile[a_scales_type, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout.rank, _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout](), _to_index_list[BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout.rank, BlackwellBlockwiseFP8MatmulKernel[a_type, b_type, c_type, a_scales_type, b_scales_type, b_scales_layout, transpose_b, config, cluster_shape].AScalesLayout]()], cluster_dim: StaticTuple[Int32, 3], num_iters: Int, b_scales: TileTensor[b_scales_type, b_scales_layout, ImmutAnyOrigin], problem_shape: StaticTuple[Int32, 3])

Kernel entry point for blockwise FP8 matmul.

Was this page helpful?