Mojo struct
PRegisterBufferRDNA
struct PRegisterBufferRDNA[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, mma_shape: IndexList[3], k_group_size: Int]
RDNA-specific P register buffer for Wave32 WMMA attention.
Fields
- reg_tile (PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType)
- shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
Implemented traits
AnyType,
ImplicitlyDestructible,
RegisterBuffer,
RegisterMMABuffer
comptime members
__del__is_trivial
comptime __del__is_trivial = True
chunk_shared_memory_layout
comptime chunk_shared_memory_layout = Layout.row_major(BK, BM)
ChunkSharedMemoryTileType
comptime ChunkSharedMemoryTileType = LayoutTensor[dtype, PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].chunk_shared_memory_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]
mma_dtype
comptime mma_dtype = dtype
mma_frag_size
comptime mma_frag_size = RDNA_AB_FRAG_SIZE
mma_tile_layout
comptime mma_tile_layout = Layout.row_major(num_m_mmas, 16)
MMATileType
comptime MMATileType = LayoutTensor[PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_dtype, PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
output_frag_size
comptime output_frag_size = RDNA_CD_FRAG_SIZE
reg_dtype
comptime reg_dtype = accum_type_
reg_tile_layout
comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), 8)
RegisterTileType
comptime RegisterTileType = LayoutTensor[accum_type_, PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
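The three layouts above are all row-major, so their shapes and linear offsets follow directly from the struct parameters. The sketch below is a plain-Python model rather than Mojo, used so the arithmetic can be run without a GPU toolchain. The helper names `layout_shapes` and `row_major_offset` and the parameter values are hypothetical, chosen only for illustration. They are not part of the library.

```python
# Plain-Python model of the row-major layouts listed above (not Mojo code).

def layout_shapes(BM, BK, num_m_mmas, num_n_mmas):
    """Return the (rows, cols) shape of each comptime layout."""
    return {
        "chunk_shared_memory_layout": (BK, BM),
        "mma_tile_layout": (num_m_mmas, 16),
        "reg_tile_layout": (num_n_mmas * num_m_mmas, 8),
    }


def row_major_offset(shape, row, col):
    """Linear offset of (row, col) in a row-major layout of the given shape."""
    rows, cols = shape
    if not (0 <= row < rows and 0 <= col < cols):
        raise IndexError(f"({row}, {col}) is outside shape {shape}")
    return row * cols + col


# Hypothetical parameter values, for illustration only.
shapes = layout_shapes(BM=16, BK=32, num_m_mmas=1, num_n_mmas=2)
chunk = shapes["chunk_shared_memory_layout"]

print(shapes["reg_tile_layout"])
print(row_major_offset(chunk, 3, 5))
```

Because `chunk_shared_memory_layout` is `(BK, BM)`, the row index of a shared-memory chunk is the key and the column index is the sequence position.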
Methods
__init__
__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_mma_tile
get_mma_tile[tile_idx: Int, k_idx: Int](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].MMATileType
Get MMA tile by loading from shared memory.
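On RDNA, each lane's WMMA B fragment holds b_frag[v] = B[v, lane % 16]. With swap_a_b, P is passed as the hardware B operand, so B must be P^T indexed as [key, seq]. In B[k = v, n = lane % 16], the fragment element index v selects the key (row) and lane % 16 selects the sequence position (column). Therefore b_frag[v] = P^T[key = v, seq = lane % 16] = P_shared[key = v, seq = lane % 16].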
Returns:
PRegisterBufferRDNA
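The fragment mapping described for get_mma_tile can be modelled in a few lines. This is a plain-Python sketch, not the Mojo implementation. It assumes a single 16x16 tile stored as `p_shared[key][seq]` and ignores the `tile_idx` and `k_idx` offsets, whose exact arithmetic is not given on this page. The function name `load_b_fragment` is hypothetical.

```python
# Plain-Python model of b_frag[v] = P_shared[key=v, seq=lane % 16].

WAVE_SIZE = 32   # Wave32: 32 lanes per wave
FRAG_ELEMS = 16  # elements per lane in this model of the B fragment


def load_b_fragment(p_shared, lane):
    """Return the 16 values one lane would hold for a single 16x16 tile."""
    if not 0 <= lane < WAVE_SIZE:
        raise ValueError("lane must be in [0, 32)")
    seq = lane % 16  # the lane selects the sequence column
    # The fragment element index v selects the key row.
    return [p_shared[v][seq] for v in range(FRAG_ELEMS)]


# Test tile in which each entry encodes its own coordinates: key * 100 + seq.
p_shared = [[key * 100 + seq for seq in range(16)] for key in range(16)]

frag_lane3 = load_b_fragment(p_shared, 3)
frag_lane19 = load_b_fragment(p_shared, 19)
print(frag_lane3[:4])
```

In this model, lanes 3 and 19 read the same column, which follows directly from the `lane % 16` term in the mapping.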
get_dtype
vectorize
vectorize(self) -> LayoutTensor[accum_type_, coalesce(LayoutTensor._compute_tile_layout[1, 8]()[1], True), MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=LayoutTensor._divide_tiles[1, 8]()[0], layout_int_type=_get_layout_type(PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].reg_tile_layout, AddressSpace.LOCAL)]
Returns:
zero
zero(self)
get_reg_tile
get_reg_tile[stage: Int = 0](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType
Returns:
PRegisterBufferRDNA
copy_to_shared
copy_to_shared[chunk_idx: Int](self)
Copy one BK chunk of P register tile to shared memory using RDNA layouts.
Each chunk covers BK = 32 keys. With two warps each handling WN = 32 keys, chunk 0 holds warp 0's data and chunk 1 holds warp 1's data. Only the owning warp writes its chunk, and it uses warp-local tile indices to avoid out-of-bounds register access.
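The ownership rule can be sketched as simple index arithmetic. This is a plain-Python model, not the Mojo implementation. The function names are hypothetical, and the formula `(chunk_idx * BK) // WN` is an assumed generalization that reduces to "chunk i belongs to warp i" when BK equals WN, which is the only case described here.

```python
# Plain-Python model of chunk ownership in copy_to_shared (not Mojo code).

def owning_warp(chunk_idx, BK=32, WN=32):
    """Warp whose keys fall in this chunk (assumed formula; see lead-in)."""
    return (chunk_idx * BK) // WN


def should_write(warp_id, chunk_idx, BK=32, WN=32):
    """Only the owning warp copies its registers to shared memory."""
    return warp_id == owning_warp(chunk_idx, BK, WN)


def warp_local_key(chunk_idx, key_in_chunk, BK=32, WN=32):
    """Map a key inside a chunk to an index within the owning warp's WN keys."""
    global_key = chunk_idx * BK + key_in_chunk
    return global_key % WN


for chunk_idx in range(2):
    writers = [w for w in range(2) if should_write(w, chunk_idx)]
    print(chunk_idx, writers)
```

Indexing registers with the warp-local key rather than the global key keeps every access inside the WN keys that the warp actually holds.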