Mojo struct

VBufferRDNA

struct VBufferRDNA[dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, mut: Bool, _mlir_origin: LITOrigin[mut._mlir_value], origin: Origin[mut=mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, num_warps_n: Int = 1]

RDNA-specific V buffer with transpose loads for Wave32 WMMA.

Fields

  • load_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].LoadTileType): Register staging tile holding the values each thread loads from global memory.
  • mma_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].MMATileType): Register tile holding the MMA fragments fed to the WMMA instructions.
  • smem_iter (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].SharedIterType): Circular iterator over the shared-memory stages.
  • global_iterator (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].GlobalTiledIteratorType): Tiled iterator over the global V tensor, advanced one (BK, depth) tile at a time.
  • global_base_tile (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].GlobalTensorType): The global V tensor this buffer reads from.
  • current_stage (Int): Index of the shared-memory stage currently in use.
  • remaining_rows (Int): Rows of V still to be loaded; used to handle masked (partial) tiles.

Implemented traits

AnyType, ImplicitlyDestructible, KVBuffer

comptime members

__del__is_trivial

comptime __del__is_trivial = True

base_layout

comptime base_layout = Layout.row_major(VBufferRDNA.pad[depth](), VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width)

depth_tile_size

comptime depth_tile_size = min(depth, 128)

GlobalTensorType

comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]

GlobalTiledIteratorType

comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[BK, depth]()[0], origin, address_space=address_space, axis=0, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, BK, depth]()]

load_width

comptime load_width = 4 if depth == 64 else VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width

loads_per_thread_per_depth_tile

comptime loads_per_thread_per_depth_tile = ((VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].depth_tile_size * BK) // (VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].load_width * num_threads))

LoadTileType

comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].loads_per_thread_per_depth_tile * (depth // VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].depth_tile_size)) * num_stages), VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = RDNA_MMA_K

MMA_M

comptime MMA_M = RDNA_MMA_M

mma_tile_layout

comptime mma_tile_layout = Layout.row_major(VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].warp_depth_tiles, 16)

MMATileType

comptime MMATileType = LayoutTensor[dtype, VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_depth_tiles

comptime num_depth_tiles = (depth // 16)

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (16 * group_size))

num_repeats

comptime num_repeats = (BK // VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].simd_width)

SharedIterType

comptime SharedIterType = LayoutTensorIter[dtype, VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]

SharedTileType

comptime SharedTileType = VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].SharedIterType.LayoutTensorType

simd_width

comptime simd_width = simd_width_of[dtype]()

smem_layout

comptime smem_layout = blocked_product(VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].base_layout, VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].tiler_layout, True)

tiler_layout

comptime tiler_layout = Layout.row_major(1, VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].num_repeats)
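
Taken together, base_layout, tiler_layout, and smem_layout describe the transposed shared-memory V tile. Below is a minimal sketch of the shape arithmetic under illustrative values (BK = 32, depth = 128, simd_width = 8 are assumptions, not values fixed by the struct), treating pad[depth]() as plain depth and assuming blocked_product repeats base_layout num_repeats times along its column axis.

```mojo
# Illustrative only: the shape arithmetic behind smem_layout. BK, depth, and
# simd_width are example values, and pad[depth]() is stood in for by plain depth.
fn main():
    alias BK = 32
    alias depth = 128
    alias simd_width = 8                  # matches the simd_width = 8 blocks in load_from_shared
    alias num_repeats = BK // simd_width  # 4

    # base_layout  ~ row_major(depth, simd_width)   -> (128, 8)
    # tiler_layout ~ row_major(1, num_repeats)      -> (1, 4)
    # smem_layout  ~ blocked_product(base, tiler)   -> (128, 32) == (depth, BK)
    print("shared V tile shape:", depth, "x", simd_width * num_repeats)
```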

warp_depth_tiles

comptime warp_depth_tiles = ((depth // num_warps_n) // 16)

Methods

__init__

__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], total_rows: OptionalReg[Int] = None)

get_dtype

static get_dtype() -> DType

Returns:

The element DType stored by this buffer (dtype).

pad

static pad[dim: Int]() -> Int

Returns:

The padded extent for dim, used as the row count of base_layout.

load_from_dram

load_from_dram(mut self)

Load V tile from global memory using Wave32-aware pattern.

Thread decomposition adapts to the number of warps (a worked sketch follows the list):

  • threads_per_row = WARP_SIZE / rows_per_warp
  • Each thread loads depth_per_thread elements from one row
  • All BK rows are covered across all warps
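
A worked sketch of this decomposition for one illustrative configuration; the concrete values (WARP_SIZE = 32, BK = 32, depth = 128, num_threads = 256) and the derivation of rows_per_warp as BK // num_warps are assumptions used only to make the arithmetic concrete.

```mojo
# Illustrative only: the load_from_dram thread decomposition for one example
# configuration. The concrete values, and rows_per_warp = BK // num_warps,
# are assumptions; threads_per_row and the per-thread element count follow
# the rules stated above.
fn main():
    alias WARP_SIZE = 32          # Wave32
    alias BK = 32                 # rows of the V tile
    alias depth = 128             # head depth (columns)
    alias num_threads = 256

    alias num_warps = num_threads // WARP_SIZE            # 8
    alias rows_per_warp = BK // num_warps                 # 4
    alias threads_per_row = WARP_SIZE // rows_per_warp    # 8
    alias depth_per_thread = depth // threads_per_row     # 16

    print("rows_per_warp    =", rows_per_warp)
    print("threads_per_row  =", threads_per_row)
    print("depth_per_thread =", depth_per_thread)
    # All BK rows are covered: num_warps * rows_per_warp == BK.
    print("rows covered     =", num_warps * rows_per_warp)
```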

get_mma_tile

get_mma_tile(self) -> VBufferRDNA[tensor_core_mma, BN, BK, depth, num_threads, num_stages, num_warps_n].MMATileType

Returns:

The buffer's MMA fragment register tile (MMATileType).

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

Copy V tile to shared memory with transpose for RDNA.

V is stored transposed in shared memory: (depth, BK). Thread decomposition matches load_from_dram so each thread writes the data it loaded to the correct transposed position.
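
A minimal sketch of the transposed write, using a toy tile and flat buffers instead of the real LayoutTensor types; the loop structure and buffer names are illustrative, not the actual implementation.

```mojo
# Illustrative only: the (k, d) -> (d, k) mapping used when V is written to
# shared memory. A tiny 4x8 tile and flat List[Int] buffers stand in for the
# real LayoutTensor tiles; the real code also handles padding and vector widths.
fn main():
    alias BK = 4        # toy number of keys (rows of the loaded tile)
    alias depth = 8     # toy head depth (columns of the loaded tile)

    var v = List[Int]()      # loaded order:  v[k * depth + d]
    var v_t = List[Int]()    # shared order:  v_t[d * BK + k]
    for i in range(BK * depth):
        v.append(i)
        v_t.append(0)

    # Each thread writes the data it loaded at (row k, column d) to the
    # transposed position (row d, column k) of the (depth, BK) shared tile.
    for k in range(BK):
        for d in range(depth):
            v_t[d * BK + k] = v[k * depth + d]

    print("V   at (k=1, d=5):", v[1 * depth + 5])     # 13
    print("V^T at (d=5, k=1):", v_t[5 * BK + 1])      # 13
```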

load_from_shared

load_from_shared[k_mma: Int](self)

Load MMA fragments from shared memory for RDNA Wave32.

The fragment mapping, step by step (a sketch follows the list):

  • RDNA WMMA A register: a_frag[v] = A[lane % 16, v]; the lane selects the row and the element index v selects the column.
  • V goes to hardware operand A (via swap_a_b); for the PV product, D = V * P.
  • Hardware A therefore maps lane = depth (a row of V^T) and element = key (a column of V^T), so a_frag[v] = V^T[depth = lane, key = v].
  • V^T is stored in shared memory as [depth, key], with the key axis split into blocks of simd_width = 8: block 0 holds keys 0..7, block 1 holds keys 8..15.
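
A minimal sketch of this indexing rule, assuming a plain row-major [depth, key] buffer as a stand-in for the blocked shared-memory layout; the buffer name, the toy sizes, and the single scalar lane loop are assumptions for illustration.

```mojo
# Illustrative only: the a_frag[v] = V^T[depth = lane % 16, key = k_mma * 16 + v]
# gather described above, with a flat row-major List[Int] standing in for the
# shared-memory V^T tile and a single lane walked in a scalar loop.
fn main():
    alias BK = 32        # key extent of the shared tile
    alias depth = 16     # toy depth: one 16-row WMMA block
    alias k_mma = 0      # which 16-wide key tile is loaded

    # Stand-in for V^T in shared memory, laid out [depth, key].
    var v_t_smem = List[Int]()
    for i in range(depth * BK):
        v_t_smem.append(i)

    var lane = 3
    for v in range(16):
        # Keys 0..7 fall in block 0 and keys 8..15 in block 1 (each block is
        # simd_width = 8 wide in the real shared-memory layout).
        var key = k_mma * 16 + v
        var a_frag_v = v_t_smem[(lane % 16) * BK + key]
        print("lane", lane, "a_frag[", v, "] =", a_frag_v)
```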
