Mojo struct

KBufferRDNA

struct KBufferRDNA[dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, mut: Bool, _mlir_origin: LITOrigin[mut._mlir_value], origin: Origin[mut=mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: Optional[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]

RDNA-specific K buffer for Wave32 WMMA attention.

Fields

  • load_tile (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].LoadTileType): Register-resident tile that stages data loaded from global memory.
  • mma_tile (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].MMATileType): Register tile holding the fragments consumed by the WMMA instructions.
  • smem_iter (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].SharedIterType): Circular iterator over the shared-memory staging buffer.
  • bounds (Int): Row bound used to guard global loads against out-of-range accesses.
  • load_tile_id (Int): Index of the stage currently targeted by global loads.
  • global_iterator (KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].GlobalTiledIteratorType): Tiled iterator over the global tensor, advancing along the K dimension (axis 1).

Implemented traits

AnyType, ImplicitlyDestructible, KVBuffer

comptime members

__del__is_trivial

comptime __del__is_trivial = True

base_layout

comptime base_layout = Layout.row_major(BN, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].simd_width)

GlobalTensorType

comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]

GlobalTiledIteratorType

comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[BN, BK]()[0], origin, address_space=address_space, axis=1, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, BN, BK]()]

LoadTileType

comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((num_stages * KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_mmas) * KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_k_tiles), KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = RDNA_MMA_K

MMA_M

comptime MMA_M = RDNA_MMA_M

MMA_N

comptime MMA_N = RDNA_MMA_N

mma_tile_layout

comptime mma_tile_layout = Layout.row_major(KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_mmas, 16)

MMATileType

comptime MMATileType = LayoutTensor[dtype, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (16 * group_size))

num_mmas

comptime num_mmas = ceildiv(WN, 16)

num_repeats

comptime num_repeats = (BK // KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].simd_width)

num_warps_n

comptime num_warps_n = (BN // WN)
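
As a worked illustration (the parameter values here are assumed for the example, not taken from the source): with BN = 64, WN = 32, BK = 32, group_size = 1, num_threads = 128, num_stages = 1, and a 16-bit dtype whose simd_width is 8, the derived members evaluate to:

    num_warps_n = 64 // 32             # = 2 warps along N
    num_mmas    = ceildiv(32, 16)      # = 2 WMMA tiles per warp
    num_k_tiles = ceildiv(32, 16 * 1)  # = 2 K steps per block tile
    num_repeats = 32 // 8              # = 4 simd_width chunks across BK
    # mma_tile_layout = row_major(2, 16)  -> (num_mmas, 16)
    # thread_layout   = row_major(32, 4)  -> (num_threads // 4, 4)
    # LoadTileType rows = num_stages * num_mmas * num_k_tiles = 1 * 2 * 2 = 4,
    # so the register staging tile is row_major(4, 8).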

SharedIterType

comptime SharedIterType = LayoutTensorIter[dtype, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]

SharedTileType

comptime SharedTileType = KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].SharedIterType.LayoutTensorType

SharedWarpTileType

comptime SharedWarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim0, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].smem_layout, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim0, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].wtile_dim1]()]

simd_width

comptime simd_width = simd_width_of[dtype]()

smem_layout

comptime smem_layout = blocked_product(KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].base_layout, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].tiler_layout, True)

thread_layout

comptime thread_layout = Layout.row_major((num_threads // 4), 4)

tiler_layout

comptime tiler_layout = Layout.row_major(1, KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].num_repeats)
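
A minimal sketch (same assumed parameter values as above) of how base_layout and tiler_layout combine: blocked_product replicates the (BN, simd_width) base block num_repeats times along the second mode, so smem_layout covers a full (BN, BK) tile.

    comptime base_layout  = Layout.row_major(64, 8)  # (BN, simd_width)
    comptime tiler_layout = Layout.row_major(1, 4)   # (1, num_repeats)
    comptime smem_layout  = blocked_product(base_layout, tiler_layout, True)
    # Resulting shape: (64, 8 * 4) == (BN, BK); warp tiles of
    # (wtile_dim0, wtile_dim1) == (WN, BK) == (32, 32) are carved out of it.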

wtile_dim0

comptime wtile_dim0 = WN

wtile_dim1

comptime wtile_dim1 = BK

Methods

__init__

__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_b_rows: OptionalReg[Int], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

Constructs the buffer over global_tile, where num_b_rows optionally bounds the number of valid rows and shared_ptr points at the shared-memory staging area.

get_dtype

static get_dtype() -> DType

Returns:

The element data type of the buffer (the dtype parameter).

load_from_dram

load_from_dram(mut self)

Loads the next tile from global memory into the register staging tile and advances the global iterator.

get_mma_tile

get_mma_tile(self) -> KBufferRDNA[tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages].MMATileType

Returns:

The MMA register tile (MMATileType) holding the fragments for the WMMA instructions.

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

Copies the staged register tile into the shared-memory buffer for stage tile_id.

load_from_shared

load_from_shared[k_mma: Int](self)

Loads the fragments for MMA step k_mma from shared memory into the MMA register tile.
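
A hedged usage sketch (the call sequence is assumed from the KVBuffer trait and the method list above, not quoted from a real kernel; barrier and the surrounding loop structure are illustrative):

    var k_buffer = KBufferRDNA[...](k_global_tile, num_b_rows, k_smem_ptr)

    for _ in range(num_block_tiles):
        k_buffer.load_from_dram()    # global memory -> register staging tile
        k_buffer.copy_to_shared()    # register staging tile -> shared memory
        barrier()                    # assumed: sync before reading shared memory

        @parameter
        for k_mma in range(k_buffer.num_k_tiles):
            k_buffer.load_from_shared[k_mma]()  # shared memory -> MMA fragments
            # ... issue the WMMA using k_buffer.get_mma_tile() ...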
