
Mojo struct

PRegisterBufferRDNA

struct PRegisterBufferRDNA[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, mma_shape: IndexList[3], k_group_size: Int]

RDNA-specific P register buffer for Wave32 WMMA attention.
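A minimal usage sketch follows, assuming illustrative tile parameters (BM = 64, BK = 32, two 16x16x16 WMMA tiles per warp dimension) and that the struct is visible from the calling kernel; the real attention kernel derives these values from its configuration, and `stack_allocation` is just one way to obtain a shared-memory pointer:

```mojo
from gpu.memory import AddressSpace
from memory import stack_allocation
from utils.index import Index

comptime accum_type = DType.float32
comptime dtype = DType.bfloat16
comptime BM = 64  # illustrative block sizes
comptime BK = 32

fn p_buffer_lifecycle_sketch():
    # Shared-memory staging area for one BK x BM chunk of P.
    var p_smem = stack_allocation[
        BK * BM, Scalar[dtype], address_space = AddressSpace.SHARED
    ]()
    # Assumes PRegisterBufferRDNA is importable from its defining module.
    # Parameters: accum_type_, dtype, BM, BN, BK, WM, WN,
    #             num_m_mmas, num_n_mmas, mma_shape, k_group_size.
    var p = PRegisterBufferRDNA[
        accum_type, dtype, BM, 64, BK, 32, 32, 2, 2, Index(16, 16, 16), 1
    ](p_smem)
    p.zero()  # clear the P accumulators before accumulating Q @ K^T scores
```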

Fields

  • reg_tile (PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType): The per-lane register tile holding the P accumulator fragments.
  • shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to the shared-memory staging buffer that copy_to_shared writes into and get_mma_tile reads from.

Implemented traits

AnyType, ImplicitlyDestructible, RegisterBuffer, RegisterMMABuffer

comptime members

__del__is_trivial

comptime __del__is_trivial = True

chunk_shared_memory_layout

comptime chunk_shared_memory_layout = Layout.row_major(BK, BM)

ChunkSharedMemoryTileType

comptime ChunkSharedMemoryTileType = LayoutTensor[dtype, PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].chunk_shared_memory_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]

mma_dtype

comptime mma_dtype = dtype

mma_frag_size

comptime mma_frag_size = RDNA_AB_FRAG_SIZE

mma_tile_layout

comptime mma_tile_layout = Layout.row_major(num_m_mmas, 16)

MMATileType

comptime MMATileType = LayoutTensor[PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_dtype, PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

output_frag_size

comptime output_frag_size = RDNA_CD_FRAG_SIZE

reg_dtype

comptime reg_dtype = accum_type_

reg_tile_layout

comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), 8)

RegisterTileType

comptime RegisterTileType = LayoutTensor[accum_type_, PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
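
For concreteness, a worked instance of these layouts under hypothetical parameters (num_m_mmas = 2 and num_n_mmas = 2, as with WM = WN = 32 and a 16x16x16 WMMA shape):

```mojo
from layout import Layout

comptime example_reg_tile_layout = Layout.row_major(2 * 2, 8)
# -> 4 output fragments per lane, each a row of RDNA_CD_FRAG_SIZE = 8 accumulators.
comptime example_mma_tile_layout = Layout.row_major(2, 16)
# -> num_m_mmas rows, each a row of RDNA_AB_FRAG_SIZE = 16 operand values.
```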

Methods

__init__

__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

Initializes the buffer with a pointer to the shared-memory staging area used by copy_to_shared and get_mma_tile.

get_mma_tile

get_mma_tile[tile_idx: Int, k_idx: Int](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].MMATileType

Get an MMA tile by loading it from shared memory.

The RDNA WMMA B-operand fragment layout is b_frag[v] = B[v, lane % 16]. With swap_a_b, P is fed to the hardware B operand, and the operand we need is B = P^T, indexed [key, seq]. Matching B[k = element v, n = lane % 16]: the element index selects the key (row) and the lane selects the seq position (column). Hence b_frag[v] = P^T[key = v, seq = lane] = P_shared[key = v, seq = lane].

Returns:

The MMA tile (MMATileType) loaded from shared memory for the given tile_idx and k_idx.
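
A hedged sketch of the per-lane gather this mapping implies (hypothetical helper; the actual load is internal to get_mma_tile and also accounts for tile and k-group offsets):

```mojo
from gpu import lane_id
from gpu.memory import AddressSpace
from memory import UnsafePointer

fn load_b_fragment_sketch[
    dtype: DType, BM: Int
](
    p_shared: UnsafePointer[
        Scalar[dtype], address_space = AddressSpace.SHARED
    ],
) -> SIMD[dtype, 16]:
    # b_frag[v] = P^T[key = v, seq = lane % 16] = P_shared[key = v, seq = lane % 16].
    # The chunk is stored row_major(BK, BM), so key advances by a row stride of BM.
    var seq = Int(lane_id()) % 16
    var frag = SIMD[dtype, 16]()

    @parameter
    for v in range(16):  # RDNA_AB_FRAG_SIZE = 16 elements per lane
        frag[v] = p_shared[v * BM + seq]
    return frag
```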

get_dtype

static get_dtype() -> DType

Returns:

DType

vectorize

vectorize(self) -> LayoutTensor[accum_type_, coalesce(LayoutTensor._compute_tile_layout[1, 8]()[1], True), MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=LayoutTensor._divide_tiles[1, 8]()[0], layout_int_type=_get_layout_type(PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].reg_tile_layout, AddressSpace.LOCAL)]

Returns:

A vectorized view of the register tile (LayoutTensor), exposing one SIMD element of width 8 per fragment row.
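
For example, continuing the lifecycle sketch above (so `p` and `accum_type` are as defined there; `scale` is a hypothetical softmax scale), the register tile can be updated one 8-wide fragment at a time:

```mojo
var scale = Scalar[accum_type](0.125)  # hypothetical softmax scale
var frags = p.vectorize()

@parameter
for i in range(2 * 2):  # num_m_mmas * num_n_mmas fragment rows
    # Each access yields a SIMD[accum_type, 8]; the scalar broadcasts.
    frags[i, 0] = frags[i, 0] * scale
```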

zero

zero(self)

Zero-fills the register tile.

get_reg_tile

get_reg_tile[stage: Int = 0](self) -> PRegisterBufferRDNA[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, mma_shape, k_group_size].RegisterTileType

Returns:

The register tile (RegisterTileType) for the given stage.

copy_to_shared

copy_to_shared[chunk_idx: Int](self)

Copy one BK chunk of the P register tile to shared memory using RDNA layouts.

Each chunk corresponds to BK = 32 keys. With two warps each handling WN = 32 keys, chunk 0 holds warp 0's data and chunk 1 holds warp 1's data. Only the owning warp writes, using warp-local tile indices to avoid out-of-bounds register access.
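
A hedged sketch of the ownership test described above (hypothetical helper; the real method also maps register fragments into the RDNA shared-memory layout):

```mojo
from gpu import thread_idx

comptime WAVE_SIZE = 32  # Wave32: one RDNA wavefront per "warp"

fn chunk_owner_writes[chunk_idx: Int, BK: Int, WN: Int]() -> Bool:
    # Warp index along the key dimension of the block.
    var warp_n = Int(thread_idx.x) // WAVE_SIZE
    # chunk_idx * BK // WN selects the warp whose WN keys cover this
    # chunk; with BK == WN == 32 this reduces to warp_n == chunk_idx.
    return warp_n == (chunk_idx * BK) // WN
```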
