
Mojo struct

QRegisterBuffer

struct QRegisterBuffer[dtype: DType, mma_shape: IndexList[3], WM: Int, WN: Int, BN: Int, BK: Int, depth: Int, thread_rows: Int, thread_cols: Int]

Fields​

  • ​reg_tile (QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].RegType):

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

input_frag_size​

comptime input_frag_size = num_matrix_reg[QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].MMA_M, QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].MMA_K]()

mma_dtype​

comptime mma_dtype = dtype

MMA_K​

comptime MMA_K = mma_shape[2]

MMA_M​

comptime MMA_M = mma_shape[0]

num_k_tiles​

comptime num_k_tiles = ceildiv(BK, QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].MMA_K)

num_mmas​

comptime num_mmas = ceildiv(WM, QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].MMA_M)

num_tiles​

comptime num_tiles = ceildiv(depth, BK)

reg_dtype​

comptime reg_dtype = dtype

reg_layout​

comptime reg_layout = row_major[((QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].num_mmas * QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].num_k_tiles) * QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].num_tiles), QRegisterBuffer[dtype, mma_shape, WM, WN, BN, BK, depth, thread_rows, thread_cols].input_frag_size]()
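To make the derived constants above concrete, the following Python sketch (a model of the arithmetic, not Mojo code) evaluates them for one hypothetical parameter choice. The parameter values and the `warp_size` argument are illustrative assumptions, and the formula used for `input_frag_size`, `(MMA_M * MMA_K) // warp_size`, is an assumption about what `num_matrix_reg` computes rather than something this page states.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, mirroring the ceildiv in the definitions above."""
    return -(-a // b)


def q_register_shape(mma_shape, WM, BK, depth, warp_size):
    """Evaluate the comptime members for one (hypothetical) parameter choice."""
    MMA_M, _, MMA_K = mma_shape          # MMA_M = mma_shape[0], MMA_K = mma_shape[2]
    num_k_tiles = ceildiv(BK, MMA_K)     # MMA steps along K inside one BK strip
    num_mmas = ceildiv(WM, MMA_M)        # MMA tiles along M inside one warp tile
    num_tiles = ceildiv(depth, BK)       # BK-wide strips across the head depth
    # ASSUMPTION: num_matrix_reg[M, K]() is modeled as elements per thread.
    input_frag_size = (MMA_M * MMA_K) // warp_size
    rows = num_mmas * num_k_tiles * num_tiles
    return {
        "num_k_tiles": num_k_tiles,
        "num_mmas": num_mmas,
        "num_tiles": num_tiles,
        "input_frag_size": input_frag_size,
        "reg_layout": (rows, input_frag_size),  # row-major (rows, cols)
    }


# Hypothetical parameters, chosen only for illustration.
shape = q_register_shape(mma_shape=(32, 32, 16), WM=32, BK=32, depth=128, warp_size=64)
print(shape)
```

Under these assumed values the register tile has 1 * 2 * 4 = 8 rows of 8 elements each, so each thread would hold 64 Q elements.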

RegType​

comptime RegType = TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Methods​

__init__​

def __init__[q_tile_layout: TensorLayout](out self, q_tile: TileTensor[dtype, q_tile_layout, ImmutAnyOrigin])

Load Q tile from DRAM into registers via buffer_load intrinsics.

Each warp loads its [WM, depth] sub-tile using col-major thread distribution (matching get_warp_layout[mma_shape]), then tiles it into BK-wide strips stored in register memory.
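The two ideas in this description, a column-major thread distribution and a split of the depth dimension into BK-wide strips, can be sketched in Python as below. This is a generic illustration only: the helper names are hypothetical, and the exact layout returned by `get_warp_layout[mma_shape]` is not specified on this page, so the lane mapping here shows the usual meaning of "column-major" rather than the library's actual layout.

```python
def col_major_thread_coords(thread_rows: int, thread_cols: int):
    """Map each lane id to (row, col), filling one column before moving to the next."""
    return [
        (lane % thread_rows, lane // thread_rows)
        for lane in range(thread_rows * thread_cols)
    ]


def bk_strips(depth: int, BK: int):
    """Half-open column ranges [start, stop) of the BK-wide strips covering depth."""
    return [(start, min(start + BK, depth)) for start in range(0, depth, BK)]


# Hypothetical sizes, chosen only for illustration.
coords = col_major_thread_coords(thread_rows=4, thread_cols=2)
strips = bk_strips(depth=128, BK=32)
print(coords)  # lanes 0..3 walk down column 0, lanes 4..7 walk down column 1
print(strips)  # four strips, matching num_tiles = ceildiv(128, 32)
```

The number of strips equals `num_tiles`, which is why `num_tiles` appears as a factor in the row count of `reg_layout`.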

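Args:

  • q_tile (TileTensor[dtype, q_tile_layout, ImmutAnyOrigin]): The Q tile in DRAM to load into registers.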

mma_tile​

def mma_tile[tile_idx: Int, k_idx: Int](self) -> TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

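Return the MMA-sized sub-tile of reg_tile for the given tile index (tile_idx, which selects a BK-wide strip) and k index (k_idx, which selects an MMA_K step within that strip).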

Returns:

TileTensor[dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
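As a rough mental model of how `tile_idx` and `k_idx` could select rows of the register tile, the Python sketch below assumes the rows of `reg_layout` are ordered tile-major, then by k step, then by MMA index. That ordering is an assumption made for illustration; this page gives only the total row count (`num_mmas * num_k_tiles * num_tiles`), not the order. The function name is hypothetical.

```python
def mma_tile_rows(tile_idx, k_idx, num_mmas, num_k_tiles, num_tiles):
    """Rows of the register tile that would back mma_tile[tile_idx, k_idx].

    ASSUMPTION: rows are ordered tile-major, then k step, then MMA index.
    """
    if not 0 <= tile_idx < num_tiles:
        raise IndexError("tile_idx out of range")
    if not 0 <= k_idx < num_k_tiles:
        raise IndexError("k_idx out of range")
    start = (tile_idx * num_k_tiles + k_idx) * num_mmas
    return range(start, start + num_mmas)


# Hypothetical sizes: 2 MMA tiles along M, 2 k steps per strip, 4 strips.
rows = list(mma_tile_rows(tile_idx=1, k_idx=1, num_mmas=2, num_k_tiles=2, num_tiles=4))
print(rows)
```

Under this model every (tile_idx, k_idx) pair maps to a disjoint block of `num_mmas` rows, and all pairs together cover the full register tile exactly once.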

scale​

def scale[accum_type: DType](self, scale_factor: Scalar[accum_type])

Scale all Q register elements in-place.

Casts each element from bf16 to f32, multiplies it by scale_factor, and casts the result back to bf16. Use this to pre-scale Q by (1/sqrt(d)) * log2e, where d is the head depth, so that the QK matmul produces already-scaled scores and the scaling step can be dropped from the hot loop.
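The cast-multiply-cast sequence and the choice of scale factor can be modeled in Python as follows. This is a sketch of the numerics only, not the Mojo implementation: the helper names are hypothetical, bf16 is emulated by rounding the upper 16 bits of an f32 with round-to-nearest-even (assumed rounding mode), and only finite inputs are handled. The log2e factor is useful because exp(x) = 2^(x * log2e), so a kernel that uses a base-2 exponential can consume the pre-scaled scores directly.

```python
import math
import struct


def round_f32(x: float) -> float:
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_bf16(x: float) -> float:
    """Round a finite value to bf16 (round-to-nearest-even), returned as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias, ties go to even
    bits &= 0xFFFF0000                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<I".replace("I", "f"), struct.pack("<I", bits))[0]


def scale_q(q_regs, scale_factor):
    """Model of scale: bf16 -> f32, multiply by scale_factor, cast back to bf16."""
    s = round_f32(scale_factor)
    return [round_bf16(round_f32(round_f32(v) * s)) for v in q_regs]


def softmax_scale(depth: int) -> float:
    """(1/sqrt(d)) * log2e, the pre-scale factor described above."""
    return (1.0 / math.sqrt(depth)) * math.log2(math.e)


scaled = scale_q([1.0, 2.0, -4.0], 0.5)
print(scaled)

# With Q pre-scaled, exp(score / sqrt(d)) becomes a plain base-2 exponential.
score, depth = 3.0, 64
print(2.0 ** (score * softmax_scale(depth)), math.exp(score / math.sqrt(depth)))
```

Because the result is rounded back to bf16, pre-scaling adds one extra rounding step to Q; the sketch makes that step explicit in `round_bf16`.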

zero​

def zero(self)