Mojo struct

QRegisterBufferRDNA

struct QRegisterBufferRDNA[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int]

Register buffer for Q. At construction, it loads each warp's (WM, depth) Q sub-tile into MMA fragments, one BK-wide strip at a time.

Fields

  • reg_tile (QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].RegisterTileType)

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

mma_dtype

comptime mma_dtype = dtype

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

mma_tile_layout

comptime mma_tile_layout = row_major[QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_mmas, 16]()

MMATileType

comptime MMATileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMA_K * k_group_size))

num_mmas

comptime num_mmas = ceildiv(WM, QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMA_M)

num_tiles

comptime num_tiles = (depth // BK)

rdna_frag_size

comptime rdna_frag_size = RDNA_AB_FRAG_SIZE

reg_dtype

comptime reg_dtype = dtype

reg_tile_layout

comptime reg_tile_layout = row_major[((QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_mmas * QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_k_tiles) * QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_tiles), 16]()
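
As a worked example of the derived sizes above, the following plain-Python sketch (not Mojo) mirrors the comptime formulas for num_mmas, num_k_tiles, num_tiles, and the two layouts. The parameter values (WM = 32, BK = 32, depth = 128, mma_shape = (16, 16, 16), k_group_size = 1) are hypothetical and chosen only for illustration, and the helper name derived_sizes is an assumption, not part of the API.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def derived_sizes(mma_shape, k_group_size, WM, BK, depth):
    """Mirror the documented comptime formulas (hypothetical helper)."""
    MMA_M = mma_shape[0]
    MMA_K = mma_shape[2]
    num_mmas = ceildiv(WM, MMA_M)
    num_k_tiles = ceildiv(BK, MMA_K * k_group_size)
    num_tiles = depth // BK  # floor division, as in the documented formula
    return {
        "num_mmas": num_mmas,
        "num_k_tiles": num_k_tiles,
        "num_tiles": num_tiles,
        # mma_tile_layout = row_major[num_mmas, 16]
        "mma_tile_shape": (num_mmas, 16),
        # reg_tile_layout = row_major[num_mmas * num_k_tiles * num_tiles, 16]
        "reg_tile_shape": (num_mmas * num_k_tiles * num_tiles, 16),
    }


# Hypothetical parameter values, for illustration only.
sizes = derived_sizes(mma_shape=(16, 16, 16), k_group_size=1, WM=32, BK=32, depth=128)
print(sizes)
```

With these values the sketch gives num_mmas = 2, num_k_tiles = 2, and num_tiles = 4, so the register tile has 16 rows of 16 elements. Because num_tiles uses floor division, a depth that is not a multiple of BK would leave a remainder that this formula does not count.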

RegisterTileType

comptime RegisterTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

simd_width

comptime simd_width = simd_width_of[dtype]()

Methods

__init__

__init__[q_layout: TensorLayout](out self, tensor: TileTensor[dtype, q_layout, ImmutAnyOrigin], valid_rows: Int)

valid_rows is the Q row bound used for out-of-bounds (OOB) clamping: group for decode, or the clamped sequence tile size for prefill.
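
One way to read the valid_rows description is sketched below in plain Python (not Mojo). The helper names, and the parameters seq_len, tile_start, and tile_rows, are hypothetical. The sketch assumes that "clamped seq tile size" means the number of tile rows that still fall inside the sequence; the source does not spell this out.

```python
def valid_rows_prefill(seq_len: int, tile_start: int, tile_rows: int) -> int:
    """Rows of a Q tile that lie inside the sequence (assumed interpretation)."""
    return max(0, min(tile_rows, seq_len - tile_start))


def valid_rows_decode(group: int) -> int:
    """For decode, the description states that the bound is simply group."""
    return group


# Hypothetical numbers: a 100-row sequence processed in 32-row tiles.
full_tile = valid_rows_prefill(seq_len=100, tile_start=64, tile_rows=32)
last_tile = valid_rows_prefill(seq_len=100, tile_start=96, tile_rows=32)
print(full_tile, last_tile, valid_rows_decode(8))
```

Under this reading, only the final, partially filled tile receives a bound smaller than the tile height.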

get_dtype

static get_dtype() -> DType

Returns:

DType

get_mma_tile

get_mma_tile[tile_idx: Int, k_idx: Int](self) -> QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMATileType

Returns the MMA fragment for the k_idx-th K strip within the tile_idx-th depth tile.
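
The relationship between (tile_idx, k_idx) and the rows of the register tile can be sketched in plain Python (not Mojo). This assumes that fragments are stored depth-tile-major, then K-strip-major, with num_mmas rows per fragment. That ordering is consistent with the documented layout sizes but is not confirmed by the source, and the helper name mma_tile_rows is hypothetical.

```python
def mma_tile_rows(tile_idx: int, k_idx: int,
                  num_mmas: int, num_k_tiles: int, num_tiles: int) -> range:
    """Row range of one MMA fragment inside the register tile.

    Assumed ordering: depth tile first, then K strip, num_mmas rows each.
    """
    if not (0 <= tile_idx < num_tiles):
        raise IndexError("tile_idx out of range")
    if not (0 <= k_idx < num_k_tiles):
        raise IndexError("k_idx out of range")
    start = (tile_idx * num_k_tiles + k_idx) * num_mmas
    return range(start, start + num_mmas)


# Hypothetical sizes: 2 MMAs per fragment, 2 K strips per tile, 4 depth tiles.
first = list(mma_tile_rows(0, 0, num_mmas=2, num_k_tiles=2, num_tiles=4))
middle = list(mma_tile_rows(1, 1, num_mmas=2, num_k_tiles=2, num_tiles=4))
last = list(mma_tile_rows(3, 1, num_mmas=2, num_k_tiles=2, num_tiles=4))
print(first, middle, last)
```

With these sizes the fragments cover all 16 rows of the register tile without overlapping, which matches the row count given by reg_tile_layout.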

Returns:

QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMATileType

get_reg_tile

get_reg_tile[stage: Int = 0](self) -> QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].RegisterTileType

Returns:

QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].RegisterTileType

zero

zero(self)