Mojo struct
PRegisterBufferRDNA
struct PRegisterBufferRDNA[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, mma_shape: IndexList[3], k_group_size: Int]
Register buffer for P, the post-softmax attention scores. It holds the accumulator in registers; copy_to_shared casts it to dtype and writes it to a [BK, BM] shared-memory (SMEM) region, which the PV phase (the P × V product) reads back as its A operand.
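The round trip described above can be modelled in plain Python. This is a sketch of the data movement only, not the Mojo implementation. It assumes dtype is a 16-bit float (float16 here, via the struct "e" format; bfloat16 would round differently), that seq indexes the BM query rows and key indexes the BK key columns of one chunk, and that element (seq, key) is stored at key * BM + seq, as get_mma_tile describes. The helper names `to_f16`, `stage_p_chunk`, and `read_back_as_a` are hypothetical.

```python
import struct


def to_f16(x: float) -> float:
    # Round a Python float to IEEE half precision, standing in for the cast to `dtype`.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def stage_p_chunk(p_chunk, BM, BK):
    """Write a BM x BK chunk of P (row = seq, col = key) into a flat,
    key-major [BK, BM] buffer: element (seq, key) lands at key * BM + seq."""
    smem = [0.0] * (BK * BM)
    for seq in range(BM):
        for key in range(BK):
            smem[key * BM + seq] = to_f16(p_chunk[seq][key])
    return smem


def read_back_as_a(smem, BM, BK):
    """Read the buffer back as the A operand of P x V, i.e. A[seq][key]."""
    return [[smem[key * BM + seq] for key in range(BK)] for seq in range(BM)]


BM, BK = 4, 3
# Accumulator values in full precision (float32 in the real kernel).
p = [[(s * BK + k) / 7.0 for k in range(BK)] for s in range(BM)]
smem = stage_p_chunk(p, BM, BK)
a = read_back_as_a(smem, BM, BK)
```

Storing P key-major means that, for a fixed key, consecutive seq values are adjacent in SMEM; reading it back with the same formula recovers P row by row as the A operand, now rounded to the narrower dtype.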
Fields
- reg_tile (PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType): the accumulator tile held in registers.
- shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]): the SMEM region that copy_to_shared writes P into.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
mma_dtype
comptime mma_dtype = dtype
mma_frag_size
comptime mma_frag_size = RDNA_AB_FRAG_SIZE
mma_tile_layout
comptime mma_tile_layout = row_major[num_m_mmas, 16]()
MMATileType
comptime MMATileType = TileTensor[PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
output_frag_size
comptime output_frag_size = RDNA_CD_FRAG_SIZE
reg_dtype
comptime reg_dtype = accum_type_
reg_tile_layout
comptime reg_tile_layout = row_major[(num_n_mmas * num_m_mmas), 8]()
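The two row-major layouts give one row per MMA fragment: mma_tile_layout has num_m_mmas rows of 16 elements, and reg_tile_layout has num_n_mmas * num_m_mmas rows of 8. The arithmetic below checks why 8 is plausible for the accumulator. It assumes a 16x16 output tile per MMA, wave32 execution, and that RDNA_AB_FRAG_SIZE and RDNA_CD_FRAG_SIZE equal the literal widths 16 and 8 used in these layouts; none of those values are stated on this page, and the configuration numbers are hypothetical.

```python
def row_major_offset(row: int, col: int, num_cols: int) -> int:
    # Offset of (row, col) in a row-major [rows, num_cols] layout.
    return row * num_cols + col


# Hypothetical configuration, chosen only for illustration.
num_m_mmas, num_n_mmas = 2, 4
MMA_M, MMA_N = 16, 16  # assumed 16x16 output tile per MMA
WAVE_SIZE = 32         # assumed wave32 execution

AB_FRAG = 16  # row width of mma_tile_layout (assumed == RDNA_AB_FRAG_SIZE)
CD_FRAG = 8   # row width of reg_tile_layout (assumed == RDNA_CD_FRAG_SIZE)

# Per-lane element counts implied by the two layouts.
mma_tile_elems = num_m_mmas * AB_FRAG
reg_tile_elems = num_n_mmas * num_m_mmas * CD_FRAG

# A 16x16 accumulator tile spread over 32 lanes gives 256 / 32 = 8 values per lane.
per_lane_acc = (MMA_M * MMA_N) // WAVE_SIZE
```

Under these assumptions each row of reg_tile_layout holds exactly one lane's share of one 16x16 accumulator tile.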
RegisterTileType
comptime RegisterTileType = TileTensor[accum_type_, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.LOCAL]
Methods
__init__
__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_mma_tile
get_mma_tile[tile_idx: Int, k_idx: Int](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].MMATileType
Load one MMA fragment from the SMEM-staged P scores. SMEM is laid out key-major, with element (seq, key) at offset key * BM + seq, so each lane reads its own (seq, key) slots from those offsets.
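A per-lane gather using the key * BM + seq addressing could look like the sketch below. The lane-to-element mapping is an assumption made for illustration, not taken from these docs: lane l reads row seq = tile_idx * 16 + l % 16 of the A tile and 16 consecutive keys starting at k_idx * 16. The function `lane_fragment` is hypothetical.

```python
FRAG = 16  # assumed elements per lane per A fragment (RDNA_AB_FRAG_SIZE)


def lane_fragment(smem, BM, tile_idx, k_idx, lane):
    """Hypothetical gather for get_mma_tile[tile_idx, k_idx]: `lane` reads
    one seq row of the 16x16 A tile across 16 consecutive keys."""
    seq = tile_idx * 16 + lane % 16
    key0 = k_idx * 16
    # Key-major SMEM: element (seq, key) lives at key * BM + seq.
    return [smem[(key0 + j) * BM + seq] for j in range(FRAG)]


BM, BK = 32, 32
# Fill SMEM key-major so each value encodes its own (seq, key) pair.
smem = [0] * (BK * BM)
for key in range(BK):
    for seq in range(BM):
        smem[key * BM + seq] = seq * 1000 + key

frag = lane_fragment(smem, BM, tile_idx=1, k_idx=1, lane=19)
```

Lane 19 with tile_idx=1 maps to seq 19, and k_idx=1 selects keys 16 to 31, so the fragment holds P[19][16..31]. The reads are strided by BM in SMEM because the buffer is key-major.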
Returns:
PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].MMATileType
get_dtype
zero
zero(self)
get_reg_tile
get_reg_tile[stage: Int = 0](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType
Returns:
PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType
copy_to_shared
copy_to_shared[chunk_idx: Int](self)
Cast the accumulator to dtype and write the chunk_idx-th BK-wide chunk of P to SMEM. Only the warp that owns that chunk participates.
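Which warp owns a chunk is not specified on this page. The sketch below shows one plausible rule under stated assumptions: warps tile the BN key columns of P in WN-wide strips, WN is a multiple of BK, and the warp whose strip contains the chunk's keys is the owner, while every other warp skips the write. All function names here are hypothetical.

```python
def chunk_key_range(chunk_idx, BK):
    # Keys [start, stop) of P covered by the chunk_idx-th BK-wide chunk.
    return chunk_idx * BK, (chunk_idx + 1) * BK


def owning_warp_col(chunk_idx, BK, WN):
    """Hypothetical ownership rule: the warp column whose WN-wide key strip
    contains the chunk. Requires WN to be a multiple of BK."""
    assert WN % BK == 0, "sketch assumes each chunk lies inside one warp strip"
    return (chunk_idx * BK) // WN


def participates(warp_col, chunk_idx, BK, WN):
    # In copy_to_shared[chunk_idx], non-owning warps would do nothing.
    return warp_col == owning_warp_col(chunk_idx, BK, WN)


BN, BK, WN = 128, 32, 64
owners = [owning_warp_col(c, BK, WN) for c in range(BN // BK)]
```

With these numbers, two BK chunks fall in each warp's 64-key strip, so chunks 0 and 1 belong to warp column 0 and chunks 2 and 3 to warp column 1.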