
Mojo struct

PRegisterBufferRDNA

struct PRegisterBufferRDNA[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, mma_shape: IndexList[3], k_group_size: Int]

P register buffer (post-softmax scores). Holds the accumulator in registers; copy_to_shared casts it to dtype and writes it to a [BK, BM] SMEM region, which the PV phase reads back as its A operand.

Fields

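  • reg_tile (PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType): register tile that holds the P accumulator.
  • shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]): pointer to the [BK, BM] SMEM region that copy_to_shared writes and get_mma_tile reads.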

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

mma_dtype

comptime mma_dtype = dtype

mma_frag_size

comptime mma_frag_size = RDNA_AB_FRAG_SIZE

mma_tile_layout

comptime mma_tile_layout = row_major[num_m_mmas, 16]()

MMATileType

comptime MMATileType = TileTensor[PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

output_frag_size

comptime output_frag_size = RDNA_CD_FRAG_SIZE

reg_dtype

comptime reg_dtype = accum_type_

reg_tile_layout

comptime reg_tile_layout = row_major[(num_n_mmas * num_m_mmas), 8]()

RegisterTileType

comptime RegisterTileType = TileTensor[accum_type_, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
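The two layouts above determine how many elements each thread keeps in AddressSpace.LOCAL. The rough sketch below is a plain-Python model, not the Mojo layout API. It computes row-major offsets and the per-thread element counts of the MMA tile and the register tile. The helper names and the parameter values are hypothetical and were chosen only for illustration.

```python
# Hypothetical host-side model of the comptime layouts above; plain Python,
# not the Mojo layout API. A row-major [rows, cols] layout maps (i, j) to
# the flat offset i * cols + j.

def row_major_offset(i: int, j: int, rows: int, cols: int) -> int:
    """Flat offset of element (i, j) in a row-major [rows, cols] tile."""
    if not (0 <= i < rows and 0 <= j < cols):
        raise IndexError(f"({i}, {j}) is outside a [{rows}, {cols}] tile")
    return i * cols + j

def tile_sizes(num_m_mmas: int, num_n_mmas: int) -> dict:
    """Per-thread element counts implied by the two layouts."""
    mma_tile = num_m_mmas * 16                # mma_tile_layout: row_major[num_m_mmas, 16]
    reg_tile = (num_n_mmas * num_m_mmas) * 8  # reg_tile_layout: row_major[num_n_mmas * num_m_mmas, 8]
    return {"mma_tile": mma_tile, "reg_tile": reg_tile}

# Assumed example parameters, chosen only for illustration.
print(tile_sizes(num_m_mmas=2, num_n_mmas=4))   # {'mma_tile': 32, 'reg_tile': 64}
print(row_major_offset(1, 3, rows=2, cols=16))  # 19
```

If accum_type_ is float32, the 64-element register tile in this example would occupy 256 bytes per thread.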

Methods

__init__

def __init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_mma_tile

def get_mma_tile[tile_idx: Int, k_idx: Int](self) -> Self.MMATileType

Load one MMA fragment from the P scores staged in SMEM. SMEM is indexed as key * BM + seq (key-major), and each lane reads the slot for its own (seq, key) pair.
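The following plain-Python model illustrates the key-major indexing. It is not Mojo and not this struct's API. The model stages a BM x BK block of scores at key * BM + seq and then computes P·V by reading P back through the same mapping, which is how the staged region can serve as the A operand. The helper names (smem_index, stage_p, pv_from_smem) and the sizes are hypothetical. The model does not cover the per-lane fragment layout of the real RDNA MMA operands.

```python
# Hypothetical host-side model of the key-major SMEM layout that get_mma_tile
# reads; plain Python, not the Mojo API. P[seq][key] sits at key * BM + seq,
# so consecutive seq rows for one key are contiguous.

def smem_index(seq: int, key: int, BM: int) -> int:
    """Flat SMEM offset of P[seq][key]."""
    return key * BM + seq

def stage_p(p, BM: int, BK: int):
    """Store a BM x BK block of scores key-major in a flat buffer."""
    smem = [0.0] * (BK * BM)
    for seq in range(BM):
        for key in range(BK):
            smem[smem_index(seq, key, BM)] = p[seq][key]
    return smem

def pv_from_smem(smem, v, BM: int, BK: int):
    """O[seq][d] = sum over key of P[seq][key] * V[key][d], reading P from SMEM."""
    depth = len(v[0])
    return [
        [sum(smem[smem_index(seq, key, BM)] * v[key][d] for key in range(BK))
         for d in range(depth)]
        for seq in range(BM)
    ]

# Illustrative sizes and values (assumed, not from a real kernel).
BM, BK = 4, 3
p = [[10 * seq + key for key in range(BK)] for seq in range(BM)]
v = [[1, 0], [0, 1], [1, 1]]
smem = stage_p(p, BM, BK)
print(smem[:BM])                         # key 0 for seq 0..3: [0, 10, 20, 30]
print(pv_from_smem(smem, v, BM, BK)[0])  # [2, 3]
```

Because the write and the read use the same mapping, the product matches a direct P·V. The key-major order only changes which addresses are contiguous.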

Returns:

Self.MMATileType

get_dtype

static def get_dtype() -> DType

Returns:

DType

zero

def zero(self)

get_reg_tile

def get_reg_tile[stage: Int = 0](self) -> Self.RegisterTileType

Returns:

Self.RegisterTileType

copy_to_shared

def copy_to_shared[chunk_idx: Int](self)

Cast the accumulator to dtype and write the chunk_idx-th BK-wide chunk of P to SMEM. Only the warp that owns that chunk participates.
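The following minimal Python sketch models the cast-and-store step under several assumptions. It assumes that dtype is float16, modeled here with the half-precision 'e' format of the struct module. It also assumes that chunk chunk_idx covers keys [chunk_idx * BK, (chunk_idx + 1) * BK) of a BM x BN accumulator, and that the chunk is written with the key * BM + seq layout that get_mma_tile reads. The function names are hypothetical, and the sketch does not model the restriction to the owning warp.

```python
import struct

def to_half(x: float) -> float:
    """Round x to IEEE 754 binary16 and back, standing in for the cast to dtype."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def copy_chunk_to_shared(acc, smem, chunk_idx: int, BM: int, BK: int) -> None:
    """Cast keys [chunk_idx*BK, (chunk_idx+1)*BK) of acc and store them key-major.

    acc is a list of BM accumulator rows; smem is a flat list of BK * BM slots.
    Which warp performs the write is not modeled here.
    """
    base = chunk_idx * BK
    for seq in range(BM):
        for key in range(BK):
            smem[key * BM + seq] = to_half(acc[seq][base + key])

# Illustrative sizes: BM = 2 rows, 4 keys split into chunks of BK = 2.
BM, BK = 2, 2
acc = [[1.0, 2.0, 3.0, 4.0],
       [5.0, 6.0, 7.0, 8.0]]
smem = [0.0] * (BK * BM)
copy_chunk_to_shared(acc, smem, chunk_idx=1, BM=BM, BK=BK)
print(smem)          # [3.0, 7.0, 4.0, 8.0]
print(to_half(0.1))  # 0.1 is not exact in float16, so a rounded value is printed
```

If dtype is bfloat16 instead, the rounding differs. The struct module has no bfloat16 format, so this sketch covers only the float16 case.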