IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

KBufferRDNA

struct KBufferRDNA[cache_dtype: DType, gmem_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]

K buffer: holds a (BN, depth) DRAM tile reference, a register load_tile that staggers DMA across BK strips, an mma_tile for the current K fragment, and a (BN, BK) LDS region for the staged strip.

Fields

  • load_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].LoadTileType)
  • mma_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].MMATileType)
  • smem_tile (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].SmemTileType)
  • gmem_tile (TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin])
  • strip_idx (Int)
  • load_tile_id (Int)

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

load_layout

comptime load_layout = row_major[(num_stages * (KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_mmas * KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_k_tiles)), KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].simd_width]()

LoadTileType

comptime LoadTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = RDNA_MMA_K

MMA_M

comptime MMA_M = RDNA_MMA_M

MMA_N

comptime MMA_N = RDNA_MMA_N

mma_tile_layout

comptime mma_tile_layout = row_major[KBufferRDNA[tensor_core_mma, BN, WN, BK, depth, num_threads, num_stages].num_mmas, 16]()

MMATileType

comptime MMATileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (16 * group_size))

num_mmas

comptime num_mmas = ceildiv(WN, 16)

num_warps_n

comptime num_warps_n = (BN // WN)
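The derived constants above are plain integer arithmetic over the struct parameters, so they can be checked by hand. The following is a minimal Python model of that arithmetic, not Mojo and not part of the library. The parameter values are illustrative assumptions, `group_size` stands in for the value supplied by `tensor_core_mma`, and `simd_width` is passed in directly because the real value depends on `cache_dtype` and the target.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division on non-negative integers."""
    return -(-a // b)


def k_buffer_constants(BN, WN, BK, group_size, num_stages=1, simd_width=8):
    """Model of the comptime members, using the formulas listed in this section.

    simd_width=8 is an assumed placeholder for simd_width_of[cache_dtype]().
    """
    num_k_tiles = ceildiv(BK, 16 * group_size)
    num_mmas = ceildiv(WN, 16)
    num_warps_n = BN // WN
    return {
        "num_k_tiles": num_k_tiles,
        "num_mmas": num_mmas,
        "num_warps_n": num_warps_n,
        # row_major shapes, written as (rows, cols)
        "load_layout": (num_stages * num_mmas * num_k_tiles, simd_width),
        "mma_tile_layout": (num_mmas, 16),
        "smem_layout": (BN, BK),
    }


# Hypothetical configuration, chosen only to make the numbers easy to follow.
consts = k_buffer_constants(BN=64, WN=32, BK=32, group_size=1)
for name, value in consts.items():
    print(name, value)
```

With these assumed values, each warp covers two 16-wide MMA tiles along N, the BK strip splits into two K tiles, and the register load tile has `1 * 2 * 2 = 4` rows of `simd_width` elements.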

simd_width

comptime simd_width = simd_width_of[cache_dtype]()

smem_layout

comptime smem_layout = row_major[BN, BK]()

SmemTileType

comptime SmemTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

Methods

__init__

def __init__(out self, gmem_tile: TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[cache_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_dtype

static def get_dtype() -> DType

Returns:

DType

load_from_dram

def load_from_dram(mut self)

Load the next BK strip of K from DRAM into the staging register slot. The gmem_tile's runtime row count drives SRD OOB clamping for unaligned tail tiles.
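The strip-wise load with tail clamping can be pictured with a small Python model. This is a sketch under stated assumptions, not the library's implementation: it assumes strips are taken along the depth axis at column offset `strip_idx * BK`, and that rows at or beyond the runtime row count read back as zero. The function name `load_strip` and the `valid_rows` parameter are hypothetical.

```python
def load_strip(gmem, valid_rows, BN, BK, strip_idx):
    """Return the (BN, BK) strip starting at column strip_idx * BK.

    gmem holds only the valid rows. Rows at or beyond valid_rows are
    modelled as zeros (assumed out-of-bounds behaviour).
    """
    col0 = strip_idx * BK
    strip = []
    for r in range(BN):
        if r < valid_rows:
            strip.append(list(gmem[r][col0:col0 + BK]))
        else:
            strip.append([0] * BK)  # clamped tail row
    return strip


# A tail tile: the block is BN = 4 rows tall, but only 3 rows exist.
depth, BN, BK, valid_rows = 4, 4, 2, 3
gmem = [[r * 10 + c for c in range(depth)] for r in range(valid_rows)]

for strip_idx in range(depth // BK):
    print(strip_idx, load_strip(gmem, valid_rows, BN, BK, strip_idx))
```

The point of the model is that the caller never branches on the tail: the same load runs for every tile, and the missing rows simply contribute zeros.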

get_mma_tile

def get_mma_tile(self) -> Self.MMATileType

Returns:

Self.MMATileType

copy_to_shared

def copy_to_shared[tile_id: Int = 0](self)

Write the staging register slot tile_id to LDS, distributing the (BN, BK) tile across threads using the same row_major(_thread_rows, _thread_cols) layout as load_from_dram.
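A row-major thread layout of this kind can be modelled in a few lines of Python. The definitions of `_thread_rows` and `_thread_cols` are not given in this page, so the sketch assumes each thread moves one `simd_width`-wide vector, with `thread_cols = BK // simd_width` and `thread_rows = num_threads // thread_cols`. The helper names are hypothetical.

```python
def thread_coords(tid, BK, simd_width):
    """Map a thread id to (row, first column) in a row-major thread grid."""
    thread_cols = BK // simd_width  # assumed: one vector per thread per row
    row, col = divmod(tid, thread_cols)
    return row, col * simd_width


def tile_owners(BN, BK, num_threads, simd_width):
    """Return {(row, col): tid} for every element of the (BN, BK) tile."""
    thread_cols = BK // simd_width
    thread_rows = num_threads // thread_cols
    owners = {}
    # When BN exceeds thread_rows, the thread grid is repeated down the tile.
    for base in range(0, BN, thread_rows):
        for tid in range(num_threads):
            r, c = thread_coords(tid, BK, simd_width)
            for v in range(simd_width):
                owners[(base + r, c + v)] = tid
    return owners


# Hypothetical configuration: 128 threads, 8 elements per vector.
owners = tile_owners(BN=64, BK=32, num_threads=128, simd_width=8)
print(len(owners) == 64 * 32)   # every element is written exactly once
print(thread_coords(5, 32, 8))  # thread 5 -> row 1, columns 8..15
```

Using the same mapping for the DRAM load and the LDS store means each thread writes exactly the elements it loaded, so no cross-thread exchange is needed between the two steps.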

load_from_shared

def load_from_shared[k_mma: Int](self)

Loads an MMA fragment from shared memory (LDS) into registers, cooperatively across the wave.

K is the A operand under swap_a_b. RDNA WMMA A maps a_frag[v] = A[lane % 16, v], so lane selects key (K row) and element selects depth (K column).
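The lane-to-element mapping `a_frag[v] = A[lane % 16, v]` can be checked with a short Python model. This is a sketch only: the 32-lane wave and the 16-element fragment length are assumptions for illustration, and `a_fragment` is a hypothetical helper.

```python
def a_fragment(A, lane, frag_len=16):
    """Fragment held by one lane: row (lane % 16) of A, elements 0..frag_len-1."""
    row = lane % 16  # lane selects the key (K row)
    return [A[row][v] for v in range(frag_len)]  # v selects depth (K column)


# 16 keys x 16 depth values, encoded so that A[key][d] = key * 100 + d.
A = [[key * 100 + d for d in range(16)] for key in range(16)]

print(a_fragment(A, 3)[:4])                     # [300, 301, 302, 303]
print(a_fragment(A, 19) == a_fragment(A, 3))    # True
```

Because the row index is `lane % 16`, lanes 16 and above in the assumed 32-lane wave hold the same key rows as lanes 0 to 15.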