Skip to main content

Mojo struct

PRegisterBuffer

struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], k_group_size: Int, tr_load_enabled: Bool = False, num_stages: Int = 1]

Fields

  • reg_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].RegType):
  • shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

mma_dtype

comptime mma_dtype = dtype

mma_tile_layout

comptime mma_tile_layout = Layout.row_major(num_m_mmas, simd_width_of[dtype]())

MmaTileType

comptime MmaTileType = TileTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].mma_dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

reg_dtype

comptime reg_dtype = accum_type_

reg_layout

comptime reg_layout = row_major[((num_stages * num_n_mmas) * num_m_mmas), output_frag_size]()

reg_tile_layout

comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), output_frag_size)

RegType

comptime RegType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

stage_layout

comptime stage_layout = row_major[(num_n_mmas * num_m_mmas), output_frag_size]()

StageTileType

comptime StageTileType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Methods

__init__

__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_mma_tile_shared

get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MmaTileType

Returns:

`MmaTileType` — a `TileTensor` in `AddressSpace.LOCAL` holding the MMA operand tile (see the return type in the signature above).

stage_tile

stage_tile[stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].StageTileType

Return the TileTensor sub-tile for the given pipeline stage.

Returns:

`StageTileType` — the `TileTensor` sub-tile corresponding to the requested pipeline `stage`.

mma_tile

mma_tile[tile_idx: Int, k_idx: Int, stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MmaTileType

Returns a TileTensor MMA operand produced with cast+interleave via SIMD whole-vector operations.

Converts f32 accumulator rows to bf16 MMA fragments using SIMD cast, interleave, and slice — no per-element [j] indexing needed.

Returns:

`MmaTileType` — a `TileTensor` in `AddressSpace.LOCAL` holding the converted MMA operand fragments.

zero

zero[stage: Int](self)

zero(self)

copy_to_shared

copy_to_shared(self)

Was this page helpful?