Mojo struct
KBufferRDNA
struct KBufferRDNA[dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, origin: Origin[mut=origin.mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]
RDNA-specific K buffer for Wave32 WMMA attention.
Fields
- load_tile (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].LoadTileType)
- mma_tile (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].MMATileType)
- smem_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
- bounds (Int)
- load_tile_id (Int)
- global_iterator (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].GlobalTiledIteratorType)
Implemented traits
AnyType,
ImplicitlyDestructible,
KVBuffer
comptime members
base_layout
comptime base_layout = Layout.row_major(VariadicList(BN, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].simd_width))
GlobalTensorType
comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]
GlobalTiledIteratorType
comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[BN, BK]()[0], origin, address_space=address_space, axis=1, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked or _tile_is_masked[layout, BN, BK]()]
LoadTileType
comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(VariadicList(((num_stages * KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_mmas) * KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_k_tiles), KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].simd_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL]
MMA_K
comptime MMA_K = RDNA_MMA_K
MMA_M
comptime MMA_M = RDNA_MMA_M
MMA_N
comptime MMA_N = RDNA_MMA_N
mma_tile_layout
comptime mma_tile_layout = Layout.row_major(VariadicList(KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_mmas, 16))
MMATileType
comptime MMATileType = LayoutTensor[dtype, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (16 * group_size))
num_mmas
comptime num_mmas = ceildiv(WN, 16)
num_repeats
comptime num_repeats = (BK // KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].simd_width)
num_warps_n
comptime num_warps_n = (BN // WN)
SharedTileType
comptime SharedTileType = LayoutTensor[dtype, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]
SharedWarpTileType
comptime SharedWarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim0, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].smem_layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].smem_layout, AddressSpace.SHARED), masked=_tile_is_masked[KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].smem_layout, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim0, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim1]()]
simd_width
comptime simd_width = simd_width_of[dtype]()
smem_layout
comptime smem_layout = blocked_product(KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].base_layout, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].tiler_layout, True)
thread_layout
comptime thread_layout = Layout.row_major(VariadicList((num_threads // 4), 4))
tiler_layout
comptime tiler_layout = Layout.row_major(VariadicList(1, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_repeats))
wtile_dim0
comptime wtile_dim0 = WN
wtile_dim1
comptime wtile_dim1 = BK
Methods
__init__
__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_b_rows: OptionalReg[Int], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_dtype
load_from_dram
load_from_dram(mut self)
get_mma_tile
get_mma_tile(self) -> KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].MMATileType
Returns:
KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].MMATileType
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
load_from_shared
load_from_shared[k_mma: Int](self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!