Mojo struct

QRegisterBuffer

struct QRegisterBuffer[dtype: DType, mma_shape: IndexList[3], WM: Int, WN: Int, BN: Int, BK: Int, depth: Int, thread_rows: Int, thread_cols: Int]

Fields

  • reg_tile (Self.RegType): Register-resident (AddressSpace.LOCAL) storage holding the Q fragments loaded by __init__.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

input_frag_size

comptime input_frag_size = num_matrix_reg[Self.MMA_M, Self.MMA_K]()

mma_dtype

comptime mma_dtype = dtype

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

num_k_tiles

comptime num_k_tiles = ceildiv(BK, Self.MMA_K)

num_mmas

comptime num_mmas = ceildiv(WM, Self.MMA_M)

num_tiles

comptime num_tiles = depth // BK
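The tile counts above can be checked with a small Python model of the comptime arithmetic. This is a sketch, not the Mojo implementation; the helper names and sample parameter values are hypothetical. ceildiv rounds up, so a partial MMA tile still counts, while num_tiles uses floor division and so presumably assumes depth is a multiple of BK.

```python
def ceildiv(a, b):
    # Integer ceiling division for positive ints (models Mojo's ceildiv).
    return -(-a // b)


def tile_counts(mma_shape, WM, BK, depth):
    """Hypothetical model of the num_mmas / num_k_tiles / num_tiles members."""
    MMA_M, _MMA_N, MMA_K = mma_shape
    num_mmas = ceildiv(WM, MMA_M)     # MMA row tiles covering the warp's WM rows
    num_k_tiles = ceildiv(BK, MMA_K)  # MMA k-steps inside one BK-wide strip
    num_tiles = depth // BK           # BK-wide strips across depth (floor division)
    return num_mmas, num_k_tiles, num_tiles
```

With mma_shape = (32, 32, 8), WM = 32, BK = 32 and depth = 128, this gives one MMA row tile, four k-steps per strip, and four strips.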

reg_dtype

comptime reg_dtype = dtype

reg_layout

comptime reg_layout = row_major[Self.num_mmas * Self.num_k_tiles * Self.num_tiles, Self.input_frag_size]()
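reg_layout stacks one row per fragment (num_mmas * num_k_tiles * num_tiles rows) with input_frag_size elements per row. A minimal Python sketch of the resulting shape and its row-major offset; input_frag_size comes from num_matrix_reg, which is not documented here, so the value 4 used below is a hypothetical placeholder.

```python
def reg_layout_shape(num_mmas, num_k_tiles, num_tiles, input_frag_size):
    # Rows: one per fragment; columns: elements held per fragment.
    return (num_mmas * num_k_tiles * num_tiles, input_frag_size)


def row_major_offset(row, col, shape):
    # Row-major: consecutive columns of one row are adjacent in storage.
    rows, cols = shape
    assert 0 <= row < rows and 0 <= col < cols
    return row * cols + col
```

For num_mmas = 1, num_k_tiles = 4, num_tiles = 4 and the hypothetical fragment size of 4, each thread would hold a 16-by-4 register tile (64 elements).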

RegType

comptime RegType = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Methods

__init__

__init__[q_tile_layout: TensorLayout](out self, q_tile: TileTensor[dtype, q_tile_layout, ImmutAnyOrigin])

Load Q tile from DRAM into registers via buffer_load intrinsics.

Each warp loads its [WM, depth] sub-tile using col-major thread distribution (matching get_warp_layout[mma_shape]), then tiles it into BK-wide strips stored in register memory.

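Args:

  • q_tile (TileTensor[dtype, q_tile_layout, ImmutAnyOrigin]): Q tile in global memory (DRAM); each warp loads its own [WM, depth] sub-tile from it.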
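The __init__ description combines two steps: a col-major mapping of the warp's threads onto the tile, and a split of the [WM, depth] sub-tile into depth // BK strips of width BK. A Python sketch of both under stated assumptions: the thread grid is modeled as thread_rows by thread_cols, which is only a guess at how get_warp_layout[mma_shape] arranges threads, and the tile is a plain list of lists rather than a TileTensor.

```python
def col_major_thread_coords(tid, thread_rows, thread_cols):
    # ASSUMPTION: col-major grid, so consecutive thread ids walk down a column.
    assert 0 <= tid < thread_rows * thread_cols
    return tid % thread_rows, tid // thread_rows


def split_into_bk_strips(sub_tile, BK):
    # sub_tile has WM rows and depth columns; returns depth // BK strips,
    # each WM rows by BK columns, in order of increasing column.
    depth = len(sub_tile[0])
    return [
        [row[s * BK:(s + 1) * BK] for row in sub_tile]
        for s in range(depth // BK)
    ]
```

In the Mojo code the strips live in register memory instead of Python lists; the sketch only shows the indexing.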

mma_tile

mma_tile[tile_idx: Int, k_idx: Int](self) -> TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Return the MMA-sized register sub-tile for BK strip tile_idx and MMA k-step k_idx within that strip.

Returns:

TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
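How (tile_idx, k_idx) selects rows of reg_layout is not documented. A hypothetical Python sketch, assuming rows are ordered by strip first, then by k-step, then by MMA row tile, so each (tile_idx, k_idx) pair owns num_mmas consecutive rows:

```python
def mma_tile_rows(tile_idx, k_idx, num_mmas, num_k_tiles, num_tiles):
    # ASSUMPTION: strip-major, then k-step, then MMA row tile ordering.
    assert 0 <= tile_idx < num_tiles and 0 <= k_idx < num_k_tiles
    start = (tile_idx * num_k_tiles + k_idx) * num_mmas
    return range(start, start + num_mmas)
```

Under this assumption every row of the register tile belongs to exactly one (tile_idx, k_idx) pair.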

scale

scale[accum_type: DType](self, scale_factor: Scalar[accum_type])

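Scale all Q register elements in place.

Casts each element from dtype (bf16 in the intended use) to accum_type (f32), multiplies by scale_factor, and casts back to dtype. This is used to pre-scale Q by 1/sqrt(d) * log2(e), where d is the head dimension (depth). Because exp(x) = exp2(x * log2(e)), the QK^T matmul then produces scores already scaled for a base-2 softmax, which removes the scaling multiply from the hot loop.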
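The pre-scaling trick relies on exp(x) = exp2(x * log2(e)). A Python sketch that compares the reference score exp(q.k / sqrt(d)) with exp2 of the dot product against a pre-scaled Q. It uses Python floats, so it ignores the bf16 rounding that the real cast round trip introduces, and the function names are illustrative only.

```python
import math


def prescale_q(q, depth):
    # Fold 1/sqrt(d) and log2(e) into Q once, outside the K loop.
    s = (1.0 / math.sqrt(depth)) * math.log2(math.e)
    return [x * s for x in q]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def score_exp_reference(q, k, depth):
    # Standard attention numerator: exp(q.k / sqrt(d)).
    return math.exp(dot(q, k) / math.sqrt(depth))


def score_exp_prescaled(q_scaled, k):
    # With pre-scaled Q, the hot loop only needs exp2(q'.k), no extra multiply.
    return 2.0 ** dot(q_scaled, k)
```

For q = [0.5, -1.0, 2.0, 0.25], k = [1.0, 0.5, 0.25, 2.0] and d = 4, both paths evaluate exp(0.5) up to floating-point rounding.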

zero

zero(self)

Set all Q register elements to zero.