Mojo struct

KBufferRDNA

struct KBufferRDNA[cache_dtype: DType, gmem_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]

K buffer: holds a (BN, depth) DRAM tile reference, a register load_tile that staggers DMA across BK strips, an mma_tile for the current K fragment, and a (BN, BK) LDS region for the staged strip.

Fields

  • load_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].LoadTileType):
  • mma_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType):
  • smem_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].SmemTileType):
  • gmem_tile (TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin]):
  • strip_idx (Int):
  • load_tile_id (Int):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

load_layout

comptime load_layout = row_major[(num_stages * (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_mmas * KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_k_tiles)), KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].simd_width]()

LoadTileType

comptime LoadTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = RDNA_MMA_K

MMA_M

comptime MMA_M = RDNA_MMA_M

MMA_N

comptime MMA_N = RDNA_MMA_N

mma_tile_layout

comptime mma_tile_layout = row_major[KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_mmas, 16]()

MMATileType

comptime MMATileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (16 * group_size))

num_mmas

comptime num_mmas = ceildiv(WN, 16)

num_warps_n

comptime num_warps_n = (BN // WN)

simd_width

comptime simd_width = simd_width_of[cache_dtype]()

smem_layout

comptime smem_layout = row_major[BN, BK]()

SmemTileType

comptime SmemTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
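For concreteness, the tile counts above can be sketched in Python. All parameter values below (BN, WN, BK, group_size, num_stages, simd_width) are illustrative assumptions, not values taken from this page:

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division, matching Mojo's ceildiv."""
    return -(-a // b)

# Hypothetical example parameters (assumptions, not from the docs).
BN, WN, BK = 128, 64, 32   # block-N, warp-N, and block-K tile sizes
group_size = 2             # tensor_core_mma.group_size (assumed)
num_stages = 1             # single staging slot
simd_width = 8             # simd_width_of[cache_dtype]() for an assumed dtype

num_mmas = ceildiv(WN, 16)                   # MMA tiles per warp along N
num_k_tiles = ceildiv(BK, 16 * group_size)   # K strips per BK tile
num_warps_n = BN // WN                       # warps tiling the N dimension

# Rows of load_layout: one staging row per (stage, mma, k_tile) triple.
load_rows = num_stages * num_mmas * num_k_tiles

print(num_mmas, num_k_tiles, num_warps_n, load_rows)
```

With these numbers, load_layout would be row_major[4, 8], i.e. four staging rows of one simd_width-wide vector each.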

Methods

__init__

__init__(out self, gmem_tile: TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[cache_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_dtype

static get_dtype() -> DType

Returns:

DType

load_from_dram

load_from_dram(mut self)

Load the next BK strip of K from DRAM into the staging register slot. The gmem_tile's runtime row count drives the buffer descriptor's (SRD) out-of-bounds clamping for unaligned tail tiles.
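A hypothetical Python model (not the actual Mojo implementation) of how the strip_idx and load_tile_id fields might advance per call, assuming the staging slot round-robins over num_stages:

```python
class KBufferModel:
    """Toy model: load_from_dram fetches the next (BN, BK) strip of K
    and writes it into a register staging slot, cycling over num_stages
    slots. Field names mirror the struct; the behavior is an assumption."""

    def __init__(self, depth: int, BK: int, num_stages: int):
        self.num_strips = -(-depth // BK)  # ceildiv: strips along depth
        self.num_stages = num_stages
        self.strip_idx = 0      # next DRAM strip to fetch
        self.load_tile_id = 0   # staging slot the next load writes to

    def load_from_dram(self):
        strip, slot = self.strip_idx, self.load_tile_id
        # ... issue the buffer load for strip `strip` into slot `slot`;
        # out-of-bounds rows on the tail strip are clamped by the SRD ...
        self.strip_idx += 1
        self.load_tile_id = (self.load_tile_id + 1) % self.num_stages
        return strip, slot

buf = KBufferModel(depth=96, BK=32, num_stages=2)
calls = [buf.load_from_dram() for _ in range(3)]
print(calls)  # (strip, slot) pairs: [(0, 0), (1, 1), (2, 0)]
```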

get_mma_tile

get_mma_tile(self) -> KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType

Returns:

KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

Write the staging register slot tile_id to LDS, distributing the (BN, BK) tile across threads using the same row_major(_thread_rows, _thread_cols) layout as load_from_dram.
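One plausible reading of that distribution, sketched in Python: each thread stores simd_width contiguous elements, so the threads form a row_major(_thread_rows, _thread_cols) grid that steps down the (BN, BK) tile. The parameter values and the exact grid derivation are illustrative assumptions:

```python
# Illustrative parameters (assumptions, not from the docs).
BN, BK = 32, 16
num_threads, simd_width = 64, 4

thread_cols = BK // simd_width             # threads across one BK row
thread_rows = num_threads // thread_cols   # rows covered per step
steps = BN // thread_rows                  # iterations to cover all BN rows

def thread_coords(tid: int, step: int):
    """(row, col) of the first element thread `tid` writes on `step`."""
    r, c = divmod(tid, thread_cols)
    return step * thread_rows + r, c * simd_width

# Under this model every element of the (BN, BK) tile is written once.
covered = {
    (thread_coords(t, s)[0], thread_coords(t, s)[1] + v)
    for t in range(num_threads)
    for s in range(steps)
    for v in range(simd_width)
}
print(len(covered) == BN * BK)  # True: full, non-overlapping coverage
```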

load_from_shared

load_from_shared[k_mma: Int](self)

Load the k_mma fragment of K from shared memory (LDS) into registers; the copy is wave-cooperative.

K is the A operand under swap_a_b. RDNA WMMA A maps a_frag[v] = A[lane % 16, v], so lane selects key (K row) and element selects depth (K column).
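The lane mapping above can be checked with a small Python sketch. The 16x16 tile contents and the wave32 lane count are illustrative assumptions; the mapping itself, a_frag[v] = A[lane % 16, v], is the one stated above:

```python
# A 16x16 K tile with recognizable values: element (row, col) = row*100 + col.
K_TILE = [[row * 100 + col for col in range(16)] for row in range(16)]

def a_frag(lane: int, num_elems: int = 16):
    """RDNA WMMA A-operand fragment: a_frag[v] = A[lane % 16, v].
    With K as the A operand (swap_a_b), `lane` selects the key (K row)
    and `v` walks depth (K column). Assumes wave32, so lanes 16..31
    replicate rows 0..15."""
    return [K_TILE[lane % 16][v] for v in range(num_elems)]

print(a_frag(0) == a_frag(16))  # True: lanes 0 and 16 hold the same row
print(a_frag(3)[:4])            # row 3, depth 0..3: [300, 301, 302, 303]
```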