
Mojo struct

PRegisterBufferRDNA

struct PRegisterBufferRDNA[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, mma_shape: IndexList[3], k_group_size: Int]

Register buffer for P, the post-softmax attention scores. The scores are held in registers in the accumulator type (accum_type_); copy_to_shared casts them to dtype and writes them to a [BK, BM] shared-memory (SMEM) region, which the PV phase reads back as the A operand.

Fields

  • reg_tile (PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType)
  • shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

mma_dtype

comptime mma_dtype = dtype

mma_frag_size

comptime mma_frag_size = RDNA_AB_FRAG_SIZE

mma_tile_layout

comptime mma_tile_layout = row_major[num_m_mmas, 16]()

MMATileType

comptime MMATileType = TileTensor[PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

output_frag_size

comptime output_frag_size = RDNA_CD_FRAG_SIZE

reg_dtype

comptime reg_dtype = accum_type_

reg_tile_layout

comptime reg_tile_layout = row_major[(num_n_mmas * num_m_mmas), 8]()
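
For orientation, a row-major layout of shape (rows, 8) places element (row, col) at linear offset row * 8 + col. The plain-Python sketch below models only that arithmetic. The function name and the example values of num_m_mmas and num_n_mmas are hypothetical, and this is not Mojo layout code.

```python
def row_major_offset(row: int, col: int, num_cols: int = 8) -> int:
    """Linear offset of (row, col) in a row-major layout with num_cols columns."""
    return row * num_cols + col


# Hypothetical parameter values, chosen only for illustration.
num_m_mmas, num_n_mmas = 2, 4
rows = num_n_mmas * num_m_mmas   # number of rows in reg_tile_layout
total_elements = rows * 8        # total elements covered by the layout

print(row_major_offset(3, 5))    # 29
print(total_elements)            # 64
```

How the (m, n) MMA indices map onto rows is not stated on this page, so the sketch leaves that mapping out.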

RegisterTileType

comptime RegisterTileType = TileTensor[accum_type_, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]

Methods

__init__

__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_mma_tile

get_mma_tile[tile_idx: Int, k_idx: Int](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].MMATileType

Loads one MMA fragment from the P scores staged in SMEM. The SMEM offset of an element is key * BM + seq; each lane reads the slot for its own (seq, key) pair.
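
The addressing rule can be checked with a small plain-Python model. This is a minimal sketch of the offset arithmetic only: the helper name smem_offset and the tile sizes are hypothetical, key is assumed to range over [0, BK) and seq over [0, BM), and the lane-to-(seq, key) mapping used by the kernel is not modeled.

```python
def smem_offset(seq: int, key: int, BM: int) -> int:
    """Offset of the (seq, key) score in the [BK, BM] staging region."""
    return key * BM + seq


BM, BK = 4, 2  # hypothetical tile sizes, for illustration only

# Fill each slot with its own (seq, key) pair so reads are easy to verify.
smem = [None] * (BK * BM)
for key in range(BK):
    for seq in range(BM):
        smem[smem_offset(seq, key, BM)] = (seq, key)

print(smem_offset(1, 1, BM))  # 5
print(smem[5])                # (1, 1)
print(smem[:BM])              # [(0, 0), (1, 0), (2, 0), (3, 0)]
```

With this rule, the BM scores that share one key sit next to each other, and stepping to the next key advances the offset by BM.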

Returns:

PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].MMATileType

get_dtype

static get_dtype() -> DType

Returns:

DType

zero

zero(self)

get_reg_tile

get_reg_tile[stage: Int = 0](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType

Returns:

PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType

copy_to_shared

copy_to_shared[chunk_idx: Int](self)

Cast accumulator → dtype and write the chunk_idx-th BK chunk of P to SMEM. Only the warp that owns that chunk participates.
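
The cast-and-stage step can be modeled in plain Python as follows. This is a sketch under stated assumptions, not the kernel's implementation: P is taken to be a BM x BN tile indexed as p[seq][key], chunk chunk_idx is assumed to cover keys chunk_idx * BK through (chunk_idx + 1) * BK - 1, float16 stands in for dtype, and warp ownership and per-lane work are ignored. The function names are hypothetical.

```python
import struct


def cast_to_f16(x: float) -> float:
    """Round x to the nearest float16 value (stand-in for the cast to dtype)."""
    return struct.unpack("e", struct.pack("e", x))[0]


def copy_chunk_to_smem(p, chunk_idx: int, BM: int, BK: int):
    """Return a flat [BK, BM] staging buffer for one BK-wide chunk of p."""
    smem = [0.0] * (BK * BM)
    for seq in range(BM):
        for key in range(BK):
            # Assumed chunking: consecutive groups of BK keys.
            value = p[seq][chunk_idx * BK + key]
            # Same addressing rule as documented for get_mma_tile.
            smem[key * BM + seq] = cast_to_f16(value)
    return smem


# Hypothetical 2 x 4 tile (BM = 2, BN = 4) split into two chunks of BK = 2 keys.
p = [[0.0, 1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0, 7.0]]
print(copy_chunk_to_smem(p, chunk_idx=1, BM=2, BK=2))  # [2.0, 6.0, 3.0, 7.0]
```

The output shows the transposition implied by the [BK, BM] region: values for the same key are stored together, one per seq.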