
Mojo struct

QRegisterBufferRDNA

struct QRegisterBufferRDNA[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int]

Register buffer for the Q operand. At construction, it loads each warp's (WM, depth) Q sub-tile into registers as MMA fragments, one BK-wide strip at a time.

Fields

  • reg_tile (QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].RegisterTileType)

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

mma_dtype

comptime mma_dtype = dtype

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

mma_tile_layout

comptime mma_tile_layout = row_major[QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_mmas, 16]()

MMATileType

comptime MMATileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMA_K * k_group_size))

num_mmas

comptime num_mmas = ceildiv(WM, QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMA_M)

num_tiles

comptime num_tiles = (depth // BK)

rdna_frag_size

comptime rdna_frag_size = RDNA_AB_FRAG_SIZE

reg_dtype

comptime reg_dtype = dtype

reg_tile_layout

comptime reg_tile_layout = row_major[((QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_mmas * QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_k_tiles) * QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_tiles), 16]()
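The tile counts and layout shapes above follow from a few integer formulas. The sketch below models that arithmetic in plain Python (not Mojo, and not the library's API) so the numbers can be checked by hand. The function name `q_register_buffer_shape` and the example parameter values are hypothetical; the formulas and the column count of 16 are copied from the comptime definitions on this page.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, mirroring the ceildiv used in the definitions."""
    return -(-a // b)


def q_register_buffer_shape(mma_shape, k_group_size, WM, BK, depth):
    """Model the compile-time sizes of QRegisterBufferRDNA (hypothetical helper)."""
    MMA_M = mma_shape[0]
    MMA_K = mma_shape[2]

    num_mmas = ceildiv(WM, MMA_M)                    # MMA row blocks per warp tile
    num_k_tiles = ceildiv(BK, MMA_K * k_group_size)  # K strips per BK tile
    num_tiles = depth // BK                          # BK tiles along depth (floor division)

    return {
        "num_mmas": num_mmas,
        "num_k_tiles": num_k_tiles,
        "num_tiles": num_tiles,
        # row_major[num_mmas, 16]
        "mma_tile_layout": (num_mmas, 16),
        # row_major[num_mmas * num_k_tiles * num_tiles, 16]
        "reg_tile_layout": (num_mmas * num_k_tiles * num_tiles, 16),
    }


# Example parameter values are assumptions chosen for illustration only.
shape = q_register_buffer_shape(
    mma_shape=(16, 16, 16), k_group_size=1, WM=32, BK=32, depth=128
)
print(shape)
```

With these example values, num_mmas is 2, num_k_tiles is 2, and num_tiles is 4, so the register tile has 16 rows of 16 elements and each MMA tile has 2 rows of 16 elements. Because num_tiles uses floor division, the sketch only covers the full depth when depth is a multiple of BK.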

RegisterTileType

comptime RegisterTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

simd_width

comptime simd_width = simd_width_of[dtype]()

Methods

__init__

def __init__[q_layout: TensorLayout](out self, tensor: TileTensor[dtype, q_layout, ImmutAnyOrigin], valid_rows: Int)

valid_rows is the Q row bound used for out-of-bounds (OOB) clamping: group for decode, or the clamped sequence tile size for prefill.
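One way a caller might derive the prefill bound is sketched below in plain Python. This is an illustration under stated assumptions, not the library's code: the helper name `prefill_valid_rows` and its parameters (`seq_len`, `tile_start`, `tile_rows`) are hypothetical, and the sketch assumes "clamped sequence tile size" means the number of rows of the current Q tile that fall inside the sequence.

```python
def prefill_valid_rows(seq_len: int, tile_start: int, tile_rows: int) -> int:
    """Rows of the Q tile starting at tile_start that lie inside the sequence.

    Hypothetical helper: a full tile yields tile_rows, the last partial tile
    yields the remainder, and a tile entirely past the end yields 0.
    """
    return max(0, min(tile_rows, seq_len - tile_start))


# Assumed example: a 100-row sequence processed in 32-row Q tiles.
for start in (0, 32, 64, 96):
    print(start, prefill_valid_rows(seq_len=100, tile_start=start, tile_rows=32))
```

In this example only the last tile is partial, so it is the only one whose bound is smaller than the tile height. For decode, the description above says the bound is simply group.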

get_dtype

static def get_dtype() -> DType

Returns:

DType

get_mma_tile

def get_mma_tile[tile_idx: Int, k_idx: Int](self) -> Self.MMATileType

Returns the MMA fragment for the k_idx-th K strip within the tile_idx-th depth tile.

Returns:

Self.MMATileType

get_reg_tile

def get_reg_tile[stage: Int = 0](self) -> Self.RegisterTileType

Returns:

Self.RegisterTileType

zero

def zero(self)