Mojo struct

VBufferRDNA

struct VBufferRDNA[cache_dtype: DType, gmem_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = Int(1), num_warps_n: Int = Int(1)]

V buffer with transpose-on-LDS-write.

V is read from DRAM in row-major (BK x depth) chunks but written to LDS transposed, as a pad(depth) x BK tile blocked into simd_width-wide column groups. This layout lets each warp read its depth slice contiguously during the PV MMA.

Fields

  • load_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].LoadTileType)
  • mma_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].MMATileType)
  • smem_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].SmemTileType)
  • gmem_tile (TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin])
  • strip_idx (Int)
  • current_stage (Int)
  • remaining_rows (Int)

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

depth_tile_size

comptime depth_tile_size = min(depth, Int(128))

load_layout

comptime load_layout = row_major[(((depth // min(depth, Int(128))) * ((min(depth, Int(128)) * BK) // ((Int(4) if (depth == Int(64)) else simd_width_of[cache_dtype]()) * num_threads))) * num_stages), Int(4) if (depth == Int(64)) else simd_width_of[cache_dtype]()]()

load_width

comptime load_width = Int(4) if (depth == Int(64)) else VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width

loads_per_thread_per_depth_tile

comptime loads_per_thread_per_depth_tile = ((VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].depth_tile_size * BK) // ((Int(4) if (depth == Int(64)) else simd_width_of[cache_dtype]()) * num_threads))
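
The sizing formulas for depth_tile_size, load_width, loads_per_thread_per_depth_tile, and load_layout can be checked with a small worked example. The following is a plain Python model of that arithmetic, not Mojo kernel code. The function name load_tile_shape and the sample parameter values are hypothetical, and simd_width is passed in as an assumed value because simd_width_of[cache_dtype]() depends on the dtype and target.

```python
def load_tile_shape(depth, BK, num_threads, num_stages, simd_width):
    """Model the comptime sizing of the per-thread register staging tile.

    simd_width is an assumed stand-in for simd_width_of[cache_dtype]().
    """
    depth_tile_size = min(depth, 128)
    # load_width is 4 when depth == 64, otherwise the SIMD width.
    load_width = 4 if depth == 64 else simd_width
    loads_per_thread_per_depth_tile = (depth_tile_size * BK) // (load_width * num_threads)
    # Row count of load_layout: depth tiles * loads per tile * pipeline stages.
    rows = (depth // depth_tile_size) * loads_per_thread_per_depth_tile * num_stages
    return {
        "depth_tile_size": depth_tile_size,
        "load_width": load_width,
        "loads_per_thread_per_depth_tile": loads_per_thread_per_depth_tile,
        "load_layout_shape": (rows, load_width),
    }


# Hypothetical configuration: depth=128, BK=32, 128 threads, 1 stage, SIMD width 8.
example = load_tile_shape(depth=128, BK=32, num_threads=128, num_stages=1, simd_width=8)
```

With these sample values each thread holds a 4 x 8 staging tile, so 128 threads together cover the full 32 x 128 chunk (4 * 8 * 128 = 4096 elements).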

LoadTileType

comptime LoadTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = RDNA_MMA_K

MMA_M

comptime MMA_M = RDNA_MMA_M

mma_tile_layout

comptime mma_tile_layout = row_major[VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].warp_depth_tiles, Int(16)]()

MMATileType

comptime MMATileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_depth_tiles

comptime num_depth_tiles = (depth // Int(16))

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (Int(16) * group_size))

num_repeats

comptime num_repeats = (BK // simd_width_of[cache_dtype]())

simd_width

comptime simd_width = simd_width_of[cache_dtype]()

smem_layout

comptime smem_layout = row_major[(pad[cache_dtype, depth, depth]() * (BK // simd_width_of[cache_dtype]())), simd_width_of[cache_dtype]()]()

SmemTileType

comptime SmemTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

warp_depth_tiles

comptime warp_depth_tiles = ((depth // num_warps_n) // Int(16))
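
The tile-count members (num_depth_tiles, num_k_tiles, num_repeats, warp_depth_tiles) are simple integer formulas, so a quick numeric check may help when choosing parameters. This is a Python model of the formulas as documented above, not Mojo code. The function name tile_counts and the sample values are hypothetical, and group_size and simd_width are passed in as assumed values.

```python
def ceildiv(a, b):
    """Integer ceiling division for positive b."""
    return -(-a // b)


def tile_counts(depth, BK, num_warps_n, group_size, simd_width):
    """Model the comptime tile counts listed in the reference.

    group_size and simd_width are assumed inputs; in the struct they come
    from tensor_core_mma and simd_width_of[cache_dtype]().
    """
    return {
        "num_depth_tiles": depth // 16,
        "num_k_tiles": ceildiv(BK, 16 * group_size),
        "num_repeats": BK // simd_width,
        "warp_depth_tiles": (depth // num_warps_n) // 16,
    }


# Hypothetical configuration: depth=128, BK=32, two warps along N.
counts = tile_counts(depth=128, BK=32, num_warps_n=2, group_size=1, simd_width=8)
# mma_tile_layout is row_major[warp_depth_tiles, 16].
mma_tile_shape = (counts["warp_depth_tiles"], 16)
```

For these sample values, each warp covers 4 of the 8 depth tiles, which gives a 4 x 16 mma_tile_layout.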

Methods

__init__

def __init__(out self, gmem_tile: TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[cache_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], total_rows: OptionalReg[Int] = None)

get_dtype

static def get_dtype() -> DType

Returns:

DType

pad

static def pad[dim: Int]() -> Int

Returns:

Int

load_from_dram

def load_from_dram(mut self)

Load the next BK strip of V from DRAM into the staging register slot. Per-thread chunks are bounds-clamped via remaining_rows for the unaligned tail tile.
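
One plausible reading of the tail handling is that each call consumes up to BK rows and that remaining_rows limits how many of them are valid in the last strip. The Python sketch below models only that bookkeeping. It assumes remaining_rows starts at total_rows and drops by BK per strip, which the reference does not state, and the function name strip_row_counts is hypothetical.

```python
def strip_row_counts(total_rows, BK):
    """Return the number of in-bounds rows for each successive BK strip.

    Assumption: remaining_rows starts at total_rows and decreases by BK
    after every strip; the actual kernel bookkeeping may differ.
    """
    remaining_rows = total_rows
    counts = []
    while remaining_rows > 0:
        # Clamp the strip height so the tail strip stays in bounds.
        counts.append(min(BK, remaining_rows))
        remaining_rows -= BK
    return counts


# Hypothetical case: 70 rows of V with BK = 32 gives two full strips and a 6-row tail.
tail_example = strip_row_counts(70, 32)
```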

get_mma_tile

def get_mma_tile(self) -> Self.MMATileType

Returns:

Self.MMATileType

copy_to_shared

def copy_to_shared[tile_id: Int = Int(0)](self)

Transposes V on write: smem[depth_pos, seq_pos] is filled from load_tile[seq_pos, depth_pos], with depth_pos blocked into simd_width-wide column groups along the SMEM row axis.
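
To make the transposed, blocked layout concrete, here is a small Python model that writes a row-major (BK x depth) chunk into a buffer shaped like smem_layout, that is (padded_depth * (BK // simd_width)) rows by simd_width columns. The exact index mapping is an assumption chosen to match that shape: each simd_width-wide group of sequence positions occupies its own block of padded_depth rows, with one row per depth position. The function name transpose_write and the padded_depth argument (standing in for the result of pad) are hypothetical.

```python
def transpose_write(v_chunk, padded_depth, simd_width):
    """Write a (BK x depth) row-major chunk into a blocked, transposed buffer.

    Assumed mapping (not confirmed by the reference):
      row = (seq_pos // simd_width) * padded_depth + depth_pos
      col = seq_pos % simd_width
    """
    BK = len(v_chunk)
    depth = len(v_chunk[0])
    assert BK % simd_width == 0 and padded_depth >= depth
    rows = padded_depth * (BK // simd_width)
    smem = [[0] * simd_width for _ in range(rows)]
    for seq_pos in range(BK):
        group, col = divmod(seq_pos, simd_width)
        for depth_pos in range(depth):
            smem[group * padded_depth + depth_pos][col] = v_chunk[seq_pos][depth_pos]
    return smem


# Tiny hypothetical chunk: BK=4, depth=3, value encodes (seq_pos, depth_pos).
chunk = [[10 * s + d for d in range(3)] for s in range(4)]
smem = transpose_write(chunk, padded_depth=4, simd_width=2)
```

Under this mapping, consecutive depth positions within one group land in consecutive rows, so a warp's depth slice occupies one contiguous span of the buffer.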

load_from_shared

def load_from_shared[k_mma: Int](self)

Loads an MMA fragment from SMEM, cooperatively across the wave.

V is stored transposed in SMEM as [depth, key]. Under swap_a_b, V is the A operand, so the lane index selects the depth row and the element index selects the key column.
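
The lane and element roles can be illustrated with a minimal Python sketch over a logical [depth, key] view of SMEM. It ignores the blocked physical layout and any replication of operands across lanes, and the function name a_fragment, the base offsets, and frag_len are hypothetical parameters introduced only for illustration.

```python
def a_fragment(smem_t, depth_base, key_base, lane, frag_len):
    """Return the elements one lane holds under the described mapping.

    smem_t is a logical [depth][key] view. The lane picks the depth row;
    the element index within the fragment picks the key column.
    """
    row = smem_t[depth_base + lane]
    return [row[key_base + e] for e in range(frag_len)]


# Tiny hypothetical view: 4 depth rows x 4 keys, value encodes (depth, key).
view = [[10 * d + k for k in range(4)] for d in range(4)]
frag = a_fragment(view, depth_base=2, key_base=1, lane=1, frag_len=2)
```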