Mojo struct
HopperMatmulSM90Kernel
struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_smem_layout: Layout, block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Int, num_threads: Int = 128, transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = 3, b_swizzle: TensorMapSwizzle = 3, c_swizzle: TensorMapSwizzle = 0, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = 1, pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, hilbert_swizzle: Bool = False]
Hopper SM90 Matmul kernel with structured shared memory management.
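The sketch below shows how the kernel type might be parameterized. The dtypes, problem dimensions, tile shapes, and shared-memory layout are illustrative assumptions only (a real configuration must satisfy validate_constraints() and the WGMMA/TMA shape rules), and the import path of the kernel itself is elided.

```mojo
from layout import Layout
from utils import StaticTuple
from utils.index import IndexList

# Import of HopperMatmulSM90Kernel elided; the module path depends on the
# kernel library layout.

# Hypothetical parameterization (illustrative values, not a validated config).
alias MyKernel = HopperMatmulSM90Kernel[
    a_type=DType.bfloat16,
    b_type=DType.bfloat16,
    c_type=DType.bfloat16,
    a_layout=Layout.row_major(4096, 4096),
    b_layout=Layout.row_major(4096, 4096),  # with transpose_b=True, B is typically stored as (N, K)
    c_layout=Layout.row_major(4096, 4096),
    c_smem_layout=Layout.row_major(128, 128),
    block_tile_shape=IndexList[3](128, 256, 64),  # (BM, BN, BK)
    wgmma_shape=IndexList[3](64, 256, 16),        # per-WGMMA (M, N, K)
    cluster_shape=StaticTuple[Int32, 3](1, 1, 1),
    num_pipeline_stages=4,
    num_threads=384  # 1 producer + 2 consumer warp groups
]
```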
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
a_smem_layout
alias a_smem_layout = tile_layout_k_major[a_type, block_tile_shape[0], block_tile_shape[2], a_swizzle]()
accum_type
alias accum_type = get_accum_type[a_type]()
AccumRegTileType
alias AccumRegTileType = LayoutTensor[get_accum_type[a_type](), Layout.row_major((((block_tile_shape[0] // wgmma_shape[0]) // ((num_threads // 128) - 1)) * (block_tile_shape[1] // wgmma_shape[1])), ((wgmma_shape[0] * wgmma_shape[1]) // 128)), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[get_accum_type[a_type](), simd_width_of[get_accum_type[a_type]()]()]]()]
b_smem_layout
alias b_smem_layout = tile_layout_k_major[b_type, block_tile_shape[1], block_tile_shape[2], b_swizzle]()
BK
alias BK = block_tile_shape[2]
BM
alias BM = block_tile_shape[0]
BN
alias BN = block_tile_shape[1]
c_frag_size
alias c_frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // 128)
num_consumer
alias num_consumer = ((num_threads // 128) - 1)
num_consumer_threads
alias num_consumer_threads = (((num_threads // 128) - 1) * 128)
num_m_mmas
alias num_m_mmas = ((block_tile_shape[0] // wgmma_shape[0]) // ((num_threads // 128) - 1))
num_n_mmas
alias num_n_mmas = (block_tile_shape[1] // wgmma_shape[1])
SMem
alias SMem = HopperMatmulSM90Kernel_SMem[a_type, tile_layout_k_major[a_type, block_tile_shape[0], block_tile_shape[2], a_swizzle](), b_type, tile_layout_k_major[b_type, block_tile_shape[1], block_tile_shape[2], b_swizzle](), c_type, c_smem_layout, num_pipeline_stages, ((num_threads // 128) - 1), Int(cluster_size[cluster_shape]())]
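As a worked example of how the derived aliases above combine, assume block_tile_shape = (128, 256, 64), wgmma_shape = (64, 256, 16), and num_threads = 384 (illustrative numbers, not defaults):

```mojo
# Illustrative arithmetic only; not code from the kernel.
#   BM, BN, BK           = 128, 256, 64
#   num_consumer         = 384 // 128 - 1   = 2    consumer warp groups
#   num_consumer_threads = 2 * 128          = 256
#   num_m_mmas           = (128 // 64) // 2 = 1    M-direction WGMMAs per consumer
#   num_n_mmas           = 256 // 256       = 1    N-direction WGMMAs
#   c_frag_size          = 64 * 256 // 128  = 128  accumulator elements per thread
```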
Methods
num_regs
validate_constraints
static validate_constraints()
Validate common constraints for all kernel variants.
async_load_AB_tma
static async_load_AB_tma[a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, /, *, num_k_iters: Int, tile_shape: IndexList[3], cluster_dims: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), use_partitioned_multicast: Bool = False](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], a_smem_iter: LayoutTensorIter[a_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], b_smem_iter: LayoutTensorIter[b_type, tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], m_coord: UInt, n_coord: UInt, k_coord: UInt, rank_n: UInt, rank_m: UInt, ring_buffer: RingBuffer[num_pipeline_stages, num_consumers, cluster_size], mut write_pipeline_states: PipelineState[num_pipeline_stages])
Load A and B tiles using TMA (Tensor Memory Accelerator).
async_load_AB_cpasync
static async_load_AB_cpasync[a_mem_layout: Layout, b_mem_layout: Layout, //, *, pipeline_stages: Int, swizzle_mode: TensorMapSwizzle, cp_size: Int, num_k_iters: Int, tile_shape: IndexList[3]](a: LayoutTensor[a_type, a_mem_layout, MutableAnyOrigin], b: LayoutTensor[b_type, b_mem_layout, MutableAnyOrigin], block_idx_m: UInt, block_idx_n: UInt, a_smem_iter: LayoutTensorIter[a_type, tile_layout_k_major[a_type, block_tile_shape.__getitem__[3, DType.int64, Int](0), block_tile_shape.__getitem__[3, DType.int64, Int](2), a_swizzle](), MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], b_smem_iter: LayoutTensorIter[b_type, tile_layout_k_major[b_type, block_tile_shape.__getitem__[3, DType.int64, Int](1), block_tile_shape.__getitem__[3, DType.int64, Int](2), b_swizzle](), MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], ring_buffer: RingBuffer[num_pipeline_stages, num_consumers, cluster_size], mut write_pipeline_states: PipelineState[num_pipeline_stages])
Load A and B tiles using cp.async for unaligned memory access.
async_copy_with_bound_check
static async_copy_with_bound_check[dtype: DType, thread_layout: Layout, swizzle_mode: TensorMapSwizzle](src: LayoutTensor[dtype, layout, MutableAnyOrigin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst: LayoutTensor[dtype, layout, MutableAnyOrigin, address_space=AddressSpace(3), element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Helper function for cp.async copies with bounds checking.
finalize_kernel
static finalize_kernel()
Common finalization for all kernel variants.
run
static run[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], a: LayoutTensor[a_type, a_layout, MutableAnyOrigin], b: LayoutTensor[b_type, b_layout, MutableAnyOrigin], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin], lut_ptr: DeviceBuffer[DType.uint32])
run_persistent
static run_persistent[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, grid_shape: IndexList[2], schedule: MatmulSchedule](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin], problem_shape: IndexList[3])
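For intuition about the persistent variant, here is a hedged sketch of the general persistent-kernel scheme (inferred from the grid_shape and schedule parameters, not the actual MatmulSchedule implementation): a fixed grid of thread blocks stays resident, and each block loops over several output tiles. The helper and numbers below are hypothetical.

```mojo
from math import ceildiv


# Hypothetical tile-count arithmetic: with M = N = 4096, BM = 128, BN = 256,
# there are ceildiv(4096, 128) * ceildiv(4096, 256) = 32 * 16 = 512 output
# tiles; a persistent grid of, say, 132 blocks has each block process
# ceildiv(512, 132) = 4 tiles (the last wave is partially full).
fn tiles_per_block(m: Int, n: Int, bm: Int, bn: Int, num_blocks: Int) -> Int:
    return ceildiv(ceildiv(m, bm) * ceildiv(n, bn), num_blocks)
```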
run_unaligned
static run_unaligned[c_desc_layout: Layout, c_tma_layout: Layout, pipeline_stages: Int = 7](c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], a: LayoutTensor[a_type, a_layout, MutableAnyOrigin], b: LayoutTensor[b_type, b_layout, MutableAnyOrigin], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin])
Kernel variant that uses cp.async to load A and B when the K dimension's alignment doesn't meet TMA requirements.
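As a rough illustration of the dispatch condition, the sketch below assumes the usual 16-byte TMA alignment requirement on the contiguous (K) extent; the helper is hypothetical and not the library's actual check.

```mojo
from sys import size_of


# Hedged sketch: TMA bulk copies generally require the contiguous (K) extent
# to cover a multiple of 16 bytes; when that fails, the cp.async path in
# run_unaligned is used instead.
fn k_meets_tma_alignment[dtype: DType](k: Int) -> Bool:
    return (k * size_of[dtype]()) % 16 == 0
```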
run_splitk
static run_splitk[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, splits: Int, raster_order: RasterOrder](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin], workspace_buffer: NDBuffer[get_accum_type[a_type](), 3, MutableAnyOrigin], locks_ptr: UnsafePointer[NoneType], problem_shape: IndexList[3])
Split-K variant of the kernel for better load balancing on small problems.
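For intuition (illustrative numbers, not the library's scheduling code): split-K partitions the K dimension across `splits` thread blocks that accumulate partial results into the workspace buffer, which are then reduced into C using the lock buffer.

```mojo
from math import ceildiv


# Hypothetical helper showing how much K-work each split handles; e.g. with
# K = 8192, BK = 64, splits = 4, each split runs ceildiv(8192 // 64, 4) = 32
# K-iterations and produces a partial C tile for the later reduction.
fn k_iters_per_split(k: Int, bk: Int, splits: Int) -> Int:
    return ceildiv(ceildiv(k, bk), splits)
```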
run_grouped
static run_grouped[a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_smem_layout, c_desc_layout], a_offsets: NDBuffer[DType.uint32, 1, MutableAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutableAnyOrigin], c: LayoutTensor[c_type, c_layout, MutableAnyOrigin])
Grouped matmul variant for MoE (Mixture of Experts) models.
This variant handles multiple experts where each expert processes a subset of tokens. The a_offsets array indicates token boundaries for each expert.
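A small sketch of the offsets convention described above, with hypothetical values; the helper is illustrative and not part of the API.

```mojo
# With 3 experts handling 10, 0, and 6 tokens respectively, a_offsets holds
# cumulative token boundaries [0, 10, 10, 16]: expert i owns rows
# [a_offsets[i], a_offsets[i + 1]) of A, and expert_ids selects the weight
# matrix used for B.
fn expert_token_count(a_offsets: List[UInt32], expert: Int) -> Int:
    # Number of tokens (rows of A) owned by `expert`.
    return Int(a_offsets[expert + 1]) - Int(a_offsets[expert])
```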