Mojo struct

VBufferRDNA

struct VBufferRDNA[cache_dtype: DType, gmem_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, num_warps_n: Int = 1]

V buffer with transpose-on-LDS-write.

V is read from DRAM in row-major (BK x depth) chunks but written to LDS as a pad(depth) x BK image, blocked into simd_width-wide column groups along the key axis, so each warp's depth slice reads contiguously during the PV MMA.
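For example (hypothetical fp16 configuration, not from the source): with BK = 32, depth = 128, and simd_width = 8, each DRAM chunk is a 32 x 128 row-major tile, and its LDS image is four pad(128) x 8 column-group blocks (one per group of 8 keys) stacked along the row axis.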

Fields

  • load_tile (Self.LoadTileType): Register staging tile that receives each BK strip as it is loaded from DRAM.
  • mma_tile (Self.MMATileType): Per-thread register fragment consumed by the PV MMA.
  • smem_tile (Self.SmemTileType): Shared-memory (LDS) tile holding the transposed V image.
  • gmem_tile (TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin]): Global-memory V tile this buffer reads from.
  • strip_idx (Int): Index of the next BK strip to load from DRAM.
  • current_stage (Int): Pipeline stage currently being filled (in [0, num_stages)).
  • remaining_rows (Int): Rows left in gmem_tile, used to bounds-clamp the unaligned tail tile.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

depth_tile_size

comptime depth_tile_size = min(depth, 128)

load_layout

comptime load_layout = row_major[Self.loads_per_thread_per_depth_tile * (depth // Self.depth_tile_size) * num_stages, Self.load_width]()

load_width

comptime load_width = 4 if (depth == 64) else Self.simd_width

loads_per_thread_per_depth_tile

comptime loads_per_thread_per_depth_tile = (Self.depth_tile_size * BK) // (Self.load_width * num_threads)
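A minimal sketch of how these quantities compose, using hypothetical parameter values (depth = 128, BK = 32, num_threads = 256, num_stages = 2, simd_width = 8 for fp16); none of these numbers come from the source:

fn main():
    # Hypothetical configuration; simd_width = 8 corresponds to fp16 on RDNA.
    var depth = 128
    var BK = 32
    var num_threads = 256
    var num_stages = 2
    var simd_width = 8

    var depth_tile_size = min(depth, 128)                   # 128
    var load_width: Int = 4 if depth == 64 else simd_width  # 8
    var loads_per_thread = depth_tile_size * BK // (load_width * num_threads)  # 2

    # load_layout is row-major with load_width columns and one row per
    # per-thread load, across all depth tiles and pipeline stages:
    var load_rows = loads_per_thread * (depth // depth_tile_size) * num_stages  # 4
    print(load_rows, "x", load_width)  # 4 x 8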

LoadTileType

comptime LoadTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = RDNA_MMA_K

MMA_M

comptime MMA_M = RDNA_MMA_M

mma_tile_layout

comptime mma_tile_layout = row_major[Self.warp_depth_tiles, 16]()

MMATileType

comptime MMATileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_depth_tiles

comptime num_depth_tiles = (depth // 16)

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (16 * group_size))

num_repeats

comptime num_repeats = (BK // Self.simd_width)
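Continuing the same hypothetical configuration (BK = 32, depth = 128, simd_width = 8), and assuming group_size = 1 from tensor_core_mma (an assumption, not from the source), the tile counts work out as:

from math import ceildiv

fn main():
    var BK = 32
    var depth = 128
    var simd_width = 8
    var group_size = 1  # assumed value of tensor_core_mma.group_size

    var num_depth_tiles = depth // 16               # 8 MMA tiles along depth
    var num_k_tiles = ceildiv(BK, 16 * group_size)  # 2 MMA steps along BK
    var num_repeats = BK // simd_width              # 4 column groups in LDS
    print(num_depth_tiles, num_k_tiles, num_repeats)  # 8 2 4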

simd_width

comptime simd_width = simd_width_of[cache_dtype]()

smem_layout

comptime smem_layout = row_major[pad[cache_dtype, depth, depth]() * Self.num_repeats, Self.simd_width]()
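A sketch of the resulting SMEM shape, assuming pad() adds a small bank-conflict pad to depth; the exact padding is an assumption, and padded_depth = 132 below is illustrative only:

fn main():
    var depth = 128
    var BK = 32
    var simd_width = 8
    # Hypothetical padded depth; the real value comes from pad[...]().
    var padded_depth = 132

    var num_repeats = BK // simd_width          # 4
    var smem_rows = padded_depth * num_repeats  # 528
    print(smem_rows, "x", simd_width)           # 528 x 8 elements of LDS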

SmemTileType

comptime SmemTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

warp_depth_tiles

comptime warp_depth_tiles = ((depth // num_warps_n) // 16)
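For example, with depth = 128 and num_warps_n = 2 (hypothetical values, as in the sketches above), each warp owns a 64-row depth slice, i.e. warp_depth_tiles = 4 sixteen-row MMA tiles.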

Methods

__init__

__init__(out self, gmem_tile: TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[cache_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], total_rows: OptionalReg[Int] = None)

get_dtype

static get_dtype() -> DType

Returns:

DType

pad

static pad[dim: Int]() -> Int

Returns:

Int

load_from_dram

load_from_dram(mut self)

Load the next BK strip of V from DRAM into the staging register slot. Per-thread chunks are bounds-clamped via remaining_rows for the unaligned tail tile.
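A minimal sketch of the tail-tile clamp, with illustrative names (rows_this_strip is not the struct's API):

fn rows_this_strip(BK: Int, remaining_rows: Int) -> Int:
    # The last strip of V may cover fewer than BK rows; clamp so the
    # per-thread DRAM reads never run past the end of the tile.
    return min(BK, remaining_rows)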

get_mma_tile

get_mma_tile(self) -> Self.MMATileType

Returns:

Self.MMATileType

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

V transpose-on-write: writes smem[depth_pos, seq_pos] from load_tile[seq_pos, depth_pos], with seq_pos blocked into simd_width-wide column groups stacked along the SMEM row axis.
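A sketch of the write-side index mapping under the assumptions above; lds_offset and padded_depth are illustrative names, not the struct's API:

fn lds_offset(seq_pos: Int, depth_pos: Int, simd_width: Int, padded_depth: Int) -> Int:
    # Each simd_width-wide group of keys becomes a padded_depth x simd_width
    # row-major block; blocks are stacked along the SMEM row axis.
    var row = (seq_pos // simd_width) * padded_depth + depth_pos
    var col = seq_pos % simd_width
    return row * simd_width + col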

load_from_shared

load_from_shared[k_mma: Int](self)

Copies from SMEM into the MMA register fragment, wave-cooperatively.

V is stored transposed in SMEM as [depth, key]. Under swap_a_b, V is the A operand, so the lane index selects the depth row and the element index selects the key column.
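On the read side, the same mapping makes each lane's fragment load contiguous; a sketch with illustrative names (fragment_offset, lane, elem, k_base, d_base are not the struct's API):

fn fragment_offset(lane: Int, elem: Int, k_base: Int, d_base: Int,
                   simd_width: Int, padded_depth: Int) -> Int:
    # lane picks the depth row, elem walks key columns; for a fixed lane the
    # simd_width consecutive keys of one column group are adjacent in LDS.
    var seq_pos = k_base + elem
    var depth_pos = d_base + lane
    var row = (seq_pos // simd_width) * padded_depth + depth_pos
    var col = seq_pos % simd_width
    return row * simd_width + col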