Mojo struct
VBufferRDNA
struct VBufferRDNA[cache_dtype: DType, gmem_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, num_warps_n: Int = 1]
V buffer with transpose-on-LDS-write.
V is read from DRAM in row-major (BK x depth) chunks but written to
LDS as a pad(depth) x BK tile, blocked into simd_width-wide column groups,
so each warp's depth slice reads contiguously during the PV MMA.
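The index math behind this layout can be modeled outside the kernel. Below is an illustrative Python sketch (not the actual implementation): it maps each element of a row-major (BK x depth) strip into the blocked (pad(depth) * num_repeats, simd_width) LDS image and checks that a fixed depth row reads back as contiguous simd_width-wide runs. The padding amount pad_elems is an assumption standing in for the struct's pad() helper, whose exact value this page does not show.

```python
# Illustrative sketch of the transpose-on-LDS-write layout (not the real kernel code).
# Assumption: pad() adds a small bank-conflict-avoiding pad; modeled here as pad_elems.
BK, depth, simd_width, pad_elems = 32, 128, 8, 8
padded_depth = depth + pad_elems          # stand-in for pad[cache_dtype, depth, depth]()
num_repeats = BK // simd_width            # column groups of simd_width keys each

def lds_coords(k, d):
    """Map V[k, d] (row-major DRAM strip) to its (row, col) in the blocked LDS image."""
    group, col = divmod(k, simd_width)    # which simd_width-wide key group, lane within it
    row = group * padded_depth + d        # groups stacked along the padded depth axis
    return row, col

# A fixed depth row d reads as num_repeats contiguous simd_width-wide runs:
d = 5
addrs = [lds_coords(k, d)[0] * simd_width + lds_coords(k, d)[1] for k in range(BK)]
for g in range(num_repeats):
    run = addrs[g * simd_width:(g + 1) * simd_width]
    assert run == list(range(run[0], run[0] + simd_width))  # contiguous within a group
print("depth row", d, "maps to", num_repeats, "contiguous runs of", simd_width)
```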
Fields

- load_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].LoadTileType)
- mma_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].MMATileType)
- smem_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].SmemTileType)
- gmem_tile (TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin])
- strip_idx (Int)
- current_stage (Int)
- remaining_rows (Int)
Implemented traits

AnyType, ImplicitlyDestructible
comptime members
depth_tile_size
comptime depth_tile_size = min(depth, 128)
load_layout
comptime load_layout = row_major[((VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].loads_per_thread_per_depth_tile * (depth // VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].depth_tile_size)) * num_stages), VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].load_width]()
load_width
comptime load_width = 4 if (depth == 64) else VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width
loads_per_thread_per_depth_tile
comptime loads_per_thread_per_depth_tile = ((VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].depth_tile_size * BK) // (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].load_width * num_threads))
LoadTileType
comptime LoadTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
MMA_K
comptime MMA_K = RDNA_MMA_K
MMA_M
comptime MMA_M = RDNA_MMA_M
mma_tile_layout
comptime mma_tile_layout = row_major[VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].warp_depth_tiles, 16]()
MMATileType
comptime MMATileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
num_depth_tiles
comptime num_depth_tiles = (depth // 16)
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (16 * group_size))
num_repeats
comptime num_repeats = (BK // VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width)
simd_width
comptime simd_width = simd_width_of[cache_dtype]()
smem_layout
comptime smem_layout = row_major[(pad[cache_dtype, depth, depth]() * VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].num_repeats), VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width]()
SmemTileType
comptime SmemTileType = TileTensor[cache_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
warp_depth_tiles
comptime warp_depth_tiles = ((depth // num_warps_n) // 16)
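For a concrete feel, here is a small Python sketch that evaluates these comptime expressions for one plausible configuration (depth=128, BK=32, num_threads=256, num_warps_n=2, num_stages=2). The values of simd_width (taken as 8, i.e. a 128-bit load of 16-bit elements) and group_size (taken as 1, from the tensor core's group_size parameter) are assumptions, since they come from the target and the TiledTensorCore parameter respectively.

```python
# Assumed configuration; simd_width and group_size are not fixed by this page.
depth, BK, num_threads, num_warps_n, num_stages = 128, 32, 256, 2, 2
simd_width, group_size = 8, 1

depth_tile_size = min(depth, 128)                               # 128
load_width = 4 if depth == 64 else simd_width                   # 8
loads_per_thread_per_depth_tile = (depth_tile_size * BK) // (load_width * num_threads)  # 2
load_rows = loads_per_thread_per_depth_tile * (depth // depth_tile_size) * num_stages   # 4
num_repeats = BK // simd_width                                  # 4
num_depth_tiles = depth // 16                                   # 8
warp_depth_tiles = (depth // num_warps_n) // 16                 # 4
num_k_tiles = -(-BK // (16 * group_size))                       # ceildiv -> 2

print(dict(depth_tile_size=depth_tile_size, load_width=load_width,
           loads_per_thread_per_depth_tile=loads_per_thread_per_depth_tile,
           load_rows=load_rows, num_repeats=num_repeats,
           num_depth_tiles=num_depth_tiles, warp_depth_tiles=warp_depth_tiles,
           num_k_tiles=num_k_tiles))
```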
Methods

__init__
__init__(out self, gmem_tile: TileTensor[cache_dtype, gmem_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[cache_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], total_rows: OptionalReg[Int] = None)
get_dtype
pad
load_from_dram
load_from_dram(mut self)
Load the next BK strip of V from DRAM into the staging register slot. Per-thread chunks are bounds-clamped via remaining_rows for the unaligned tail tile.
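The tail handling can be pictured with a short sketch. This toy Python model assumes each thread owns a fixed number of load_width-wide row-chunks of the strip (the real thread-to-chunk assignment is an implementation detail not spelled out here); chunks whose row falls past remaining_rows are clamped to zero instead of reading out of bounds.

```python
import numpy as np

# Toy model of the bounds-clamped strip load (thread-to-chunk assignment is assumed).
BK, depth, load_width, num_threads = 32, 128, 8, 256
loads_per_thread = (BK * depth) // (load_width * num_threads)   # 2 chunks per thread

def load_strip(gmem, strip_idx, remaining_rows, tid):
    """Return tid's staging registers for the BK x depth strip at strip_idx."""
    regs = np.zeros((loads_per_thread, load_width), gmem.dtype)
    for i in range(loads_per_thread):
        chunk = tid + i * num_threads                 # chunk index across the strip
        row, col = divmod(chunk * load_width, depth)  # row-major position in the strip
        if row < remaining_rows:                      # clamp: skip rows past the tail
            regs[i] = gmem[strip_idx * BK + row, col:col + load_width]
    return regs

V = np.arange(40 * 128, dtype=np.float32).reshape(40, 128)        # 40 total rows
tail = load_strip(V, strip_idx=1, remaining_rows=40 - 32, tid=0)  # 8-row tail strip
```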
get_mma_tile
get_mma_tile(self) -> VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].MMATileType
Returns:
VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].MMATileType
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
V transpose-on-write: smem[depth_pos, seq_pos] from load_tile[seq_pos, depth_pos] (with depth_pos blocked into simd_width-wide column groups along the SMEM row axis).
load_from_shared
load_from_shared[k_mma: Int](self)
SMEM->fragment, wave-cooperative.
V is transposed in SMEM as [depth, key]. Under swap_a_b V is
the A operand, so lane selects the depth row and element
selects the key column.
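Schematically, and ignoring RDNA wave-replication details that this page does not specify, the fragment gather under swap_a_b looks like the following sketch: the transposed [depth, BK] view is read back with the lane index picking the depth row and the element index picking the key column.

```python
# Schematic A-fragment gather for one 16x16 MMA tile (RDNA wave replication omitted).
# Vt is a toy stand-in for the transposed [depth, key] view held in SMEM.
MMA_M, MMA_K = 16, 16
BK, depth = 32, 128
Vt = [[d * BK + k for k in range(BK)] for d in range(depth)]

def gather_a_fragment(depth_base, k_base, lane):
    # Under swap_a_b, transposed V is the A operand:
    # the lane index selects the depth row, the element index the key column.
    row = depth_base + (lane % MMA_M)
    return [Vt[row][k_base + e] for e in range(MMA_K)]

frag = gather_a_fragment(depth_base=16, k_base=0, lane=3)  # lane 3 -> depth row 19
```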