Mojo struct
QRegisterBufferRDNA
struct QRegisterBufferRDNA[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int]
Q register buffer: loads each warp's (WM, depth) Q sub-tile into BK-strip MMA fragments at construction.
Fields
- reg_tile (QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].RegisterTileType)
Implemented traits
comptime members
mma_dtype
comptime mma_dtype = dtype
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
mma_tile_layout
comptime mma_tile_layout = row_major[QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_mmas, 16]()
MMATileType
comptime MMATileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMA_K * k_group_size))
num_mmas
comptime num_mmas = ceildiv(WM, QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].MMA_M)
num_tiles
comptime num_tiles = (depth // BK)
rdna_frag_size
comptime rdna_frag_size = RDNA_AB_FRAG_SIZE
reg_dtype
comptime reg_dtype = dtype
reg_tile_layout
comptime reg_tile_layout = row_major[((QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_mmas * QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_k_tiles) * QRegisterBufferRDNA[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth].num_tiles), 16]()
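The counts num_mmas, num_k_tiles, and num_tiles multiply together to give the row count of reg_tile_layout. A minimal worked sketch of that arithmetic follows, written in Python rather than Mojo so it can be run anywhere. The parameter values (a 16x16x16 MMA shape, WM = 32, BK = 32, depth = 128, k_group_size = 1) and the helper name q_register_tile_counts are illustrative assumptions, not values or functions taken from any kernel.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, mirroring ceildiv in the definitions above."""
    return -(-a // b)


def q_register_tile_counts(mma_shape, k_group_size, WM, BK, depth):
    """Evaluate the comptime tile counts for one (hypothetical) parameter set."""
    MMA_M, MMA_K = mma_shape[0], mma_shape[2]
    num_mmas = ceildiv(WM, MMA_M)                    # MMA row blocks per warp tile
    num_k_tiles = ceildiv(BK, MMA_K * k_group_size)  # K strips per BK-wide tile
    num_tiles = depth // BK                          # BK-wide tiles along depth
    return {
        "num_mmas": num_mmas,
        "num_k_tiles": num_k_tiles,
        "num_tiles": num_tiles,
        # Both layouts are row-major with 16 columns, as in the definitions.
        "mma_tile_shape": (num_mmas, 16),
        "reg_tile_shape": (num_mmas * num_k_tiles * num_tiles, 16),
    }


# Hypothetical parameters chosen only for illustration.
counts = q_register_tile_counts(
    mma_shape=(16, 16, 16), k_group_size=1, WM=32, BK=32, depth=128
)
print(counts)
```

With these assumed values the register tile has 2 * 2 * 4 = 16 rows, and each MMA tile is a 2-row slice of it.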
RegisterTileType
comptime RegisterTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
simd_width
comptime simd_width = simd_width_of[dtype]()
Methods
__init__
def __init__[q_layout: TensorLayout](out self, tensor: TileTensor[dtype, q_layout, ImmutAnyOrigin], valid_rows: Int)
valid_rows is the Q row bound used for out-of-bounds (OOB) clamping: it equals group for decode and the clamped sequence tile size for prefill.
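A rough sketch of how a caller might derive valid_rows, in Python for illustration only. The helper names are hypothetical, and reading "clamped seq tile size" as min(tile_rows, seq_len - tile_start) is an assumption rather than the kernel's actual code.

```python
def valid_rows_for_decode(group: int) -> int:
    """Decode: the row bound is the group value, per the description above."""
    return group


def valid_rows_for_prefill(tile_rows: int, seq_len: int, tile_start: int) -> int:
    """Prefill (assumed reading): rows of this tile that lie inside the sequence."""
    return max(0, min(tile_rows, seq_len - tile_start))


def row_in_bounds(row: int, valid_rows: int) -> bool:
    """Rows at or beyond valid_rows are treated as out of bounds for the Q load."""
    return 0 <= row < valid_rows


# Hypothetical shapes: a 64-row tile starting at row 64 of a 100-row sequence.
tail_rows = valid_rows_for_prefill(tile_rows=64, seq_len=100, tile_start=64)
print(tail_rows)
```

In this example only the first 36 rows of the last tile are valid, so any row index from 36 upward would need clamping.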
get_dtype
get_mma_tile
def get_mma_tile[tile_idx: Int, k_idx: Int](self) -> Self.MMATileType
Returns the MMA fragment for the k_idx-th K strip within the tile_idx-th depth tile.
Returns:
Self.MMATileType
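This page does not state how (tile_idx, k_idx) maps onto rows of the register tile. One plausible row-major ordering is sketched below in Python, purely as an assumption: each fragment occupies num_mmas contiguous rows, with K strips nested inside depth tiles. The function name mma_tile_row_range is hypothetical.

```python
def mma_tile_row_range(tile_idx, k_idx, num_mmas, num_k_tiles, num_tiles):
    """Rows of the register tile covered by one MMA fragment (assumed ordering)."""
    if not (0 <= tile_idx < num_tiles and 0 <= k_idx < num_k_tiles):
        raise IndexError("tile_idx or k_idx out of range")
    # Assumption: depth tiles are outermost, K strips next, MMA rows innermost.
    start = (tile_idx * num_k_tiles + k_idx) * num_mmas
    return range(start, start + num_mmas)


# Hypothetical counts: num_mmas = 2, num_k_tiles = 2, num_tiles = 4.
rows = mma_tile_row_range(tile_idx=1, k_idx=1, num_mmas=2, num_k_tiles=2, num_tiles=4)
print(list(rows))
```

Under this assumed ordering, the fragments for all (tile_idx, k_idx) pairs partition the num_mmas * num_k_tiles * num_tiles rows of the register tile without overlap.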
get_reg_tile
def get_reg_tile[stage: Int = 0](self) -> Self.RegisterTileType
Returns:
Self.RegisterTileType
zero
def zero(self)