Mojo struct
KBufferRDNA
struct KBufferRDNA[cache_dtype: DType, gmem_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]
K buffer: holds a (BN, depth) DRAM tile reference, a register load_tile that staggers DMA across BK strips, an mma_tile for the current K fragment, and a (BN, BK) LDS region for the staged strip.
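The methods below (documented further down on this page) are typically chained into a staged load pipeline. As an illustrative sketch only: the surrounding kernel, the names `mma_op`, `k_gmem_tile`, `k_smem_ptr`, and `num_bk_strips`, and the use of `barrier()` from the `gpu` module are assumptions, not part of this API page.

```mojo
var k_buf = KBufferRDNA[mma_op, BN, WN, BK, depth, num_threads](
    k_gmem_tile, k_smem_ptr
)

# Prefetch the first BK strip of K into the register staging slot.
k_buf.load_from_dram()

for _ in range(num_bk_strips):
    # Stage the registers into LDS, then start the next DRAM load.
    k_buf.copy_to_shared()
    barrier()
    k_buf.load_from_dram()

    # Consume the staged (BN, BK) strip fragment by fragment.
    @parameter
    for k_mma in range(k_buf.num_k_tiles):
        k_buf.load_from_shared[k_mma]()
        var k_frag = k_buf.get_mma_tile()
        # ... issue the WMMA with k_frag as the A operand (swap_a_b) ...
    barrier()
```

Overlapping the next `load_from_dram` with consumption of the current LDS strip is what the register staging slot enables; with `num_stages > 1`, `copy_to_shared[tile_id]` selects which staging slot to drain.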
Fields

- load_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].LoadTileType)
- mma_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType)
- smem_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].SmemTileType)
- gmem_tile (TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin])
- strip_idx (Int)
- load_tile_id (Int)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
load_layout
comptime load_layout = row_major[(num_stages * (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_mmas * KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_k_tiles)), KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].simd_width]()
LoadTileType
comptime LoadTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
MMA_K
comptime MMA_K = RDNA_MMA_K
MMA_M
comptime MMA_M = RDNA_MMA_M
MMA_N
comptime MMA_N = RDNA_MMA_N
mma_tile_layout
comptime mma_tile_layout = row_major[KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_mmas, 16]()
MMATileType
comptime MMATileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (16 * group_size))
num_mmas
comptime num_mmas = ceildiv(WN, 16)
num_warps_n
comptime num_warps_n = (BN // WN)
simd_width
comptime simd_width = simd_width_of[cache_dtype]()
smem_layout
comptime smem_layout = row_major[BN, BK]()
SmemTileType
comptime SmemTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
Methods
__init__
__init__(out self, gmem_tile: TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[cache_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_dtype
load_from_dram
load_from_dram(mut self)
Load the next BK strip of K from DRAM into the staging register slot. The gmem_tile's runtime row count drives SRD OOB clamping for unaligned tail tiles.
get_mma_tile
get_mma_tile(self) -> KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType
Returns:
KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
Write the staging register slot tile_id to LDS, distributing the (BN, BK) tile across threads using the same row_major(_thread_rows, _thread_cols) layout as load_from_dram.
load_from_shared
load_from_shared[k_mma: Int](self)
Load the k_mma fragment from SMEM (LDS) into registers; the load is wave-cooperative.
K is the A operand under swap_a_b. RDNA WMMA A maps a_frag[v] = A[lane % 16, v], so lane selects key (K row) and element selects depth (K column).
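Restating the documented mapping a_frag[v] = A[lane % 16, v] as a tiny helper may make the index roles explicit. The function name and its use are hypothetical, not part of this struct:

```mojo
# Under swap_a_b the A tile is K, so for the RDNA WMMA A operand the
# lane picks which key (K row) a thread holds, and the fragment element
# index v walks depth (the K column).
fn k_fragment_coords(lane: Int, v: Int) -> Tuple[Int, Int]:
    var key_row = lane % 16  # key (K row) selected by this lane
    var depth_col = v        # depth (K column) within the fragment
    return (key_row, depth_col)
```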