Mojo struct
PRegisterBuffer
struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], k_group_size: Int, tr_load_enabled: Bool = False, num_stages: Int = 1]
Fields
- reg_tile (RegType): The register-resident tile storage.
- shared_memory_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]): Pointer to the shared-memory backing buffer.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
mma_dtype
comptime mma_dtype = dtype
mma_tile_layout
comptime mma_tile_layout = Layout.row_major(num_m_mmas, simd_width_of[dtype]())
MmaTileType
comptime MmaTileType = TileTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].mma_dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
reg_dtype
comptime reg_dtype = accum_type_
reg_layout
comptime reg_layout = row_major[((num_stages * num_n_mmas) * num_m_mmas), output_frag_size]()
reg_tile_layout
comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), output_frag_size)
RegType
comptime RegType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
stage_layout
comptime stage_layout = row_major[(num_n_mmas * num_m_mmas), output_frag_size]()
StageTileType
comptime StageTileType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
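The register-tile layouts above are simple products of the MMA tile counts: `reg_layout` is `num_stages` stacked copies of `stage_layout` (which matches `reg_tile_layout`). A small Python sketch with hypothetical parameter values, to make the arithmetic concrete:

```python
# Hypothetical parameter values for illustration only; real values come from
# the kernel configuration that instantiates PRegisterBuffer.
num_stages, num_m_mmas, num_n_mmas, output_frag_size = 2, 4, 2, 4

# reg_layout: row_major[num_stages * num_n_mmas * num_m_mmas, output_frag_size]
reg_rows = num_stages * num_n_mmas * num_m_mmas

# stage_layout / reg_tile_layout: row_major[num_n_mmas * num_m_mmas, output_frag_size]
stage_rows = num_n_mmas * num_m_mmas

# The full register tile holds num_stages back-to-back stage tiles.
assert reg_rows * output_frag_size == num_stages * stage_rows * output_frag_size
print(reg_rows, stage_rows)  # 16 8
```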
Methods
__init__
__init__(out self, shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_mma_tile_shared
get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MmaTileType
Returns:
MmaTileType
stage_tile
stage_tile[stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].StageTileType
Return the TileTensor sub-tile for the given pipeline stage.
Returns:
StageTileType
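Assuming pipeline stages are stored contiguously in the flat register tile, as the row-major `reg_layout` suggests, the element offset of a stage's sub-tile is a simple product. The helper below is hypothetical (not part of this API) and only illustrates that arithmetic:

```python
def stage_offset(stage: int, num_m_mmas: int, num_n_mmas: int,
                 output_frag_size: int) -> int:
    """Element offset of the sub-tile for `stage` in the flat register tile,
    assuming stages are stored back-to-back in row-major order."""
    return stage * num_n_mmas * num_m_mmas * output_frag_size

# With a hypothetical config num_m_mmas=4, num_n_mmas=2, output_frag_size=4,
# stage 1 starts 32 elements into the register tile.
print(stage_offset(1, 4, 2, 4))  # 32
```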
mma_tile
mma_tile[tile_idx: Int, k_idx: Int, stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled, num_stages].MmaTileType
Return a TileTensor MMA operand built with whole-vector SIMD cast-and-interleave operations.
Converts f32 accumulator rows to bf16 MMA fragments using SIMD cast, interleave, and slice; no per-element [j] indexing is needed.
Returns:
MmaTileType
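The cast-and-interleave step can be illustrated outside Mojo. The sketch below uses NumPy, with float16 standing in for bf16 (NumPy has no native bf16); the fragment size and the exact interleave pattern are illustrative assumptions, not the kernel's actual fragment layout:

```python
import numpy as np

# One f32 accumulator row (hypothetical size 8).
acc = np.arange(8, dtype=np.float32)

# Whole-vector cast to the narrower MMA dtype (float16 stands in for bf16).
narrow = acc.astype(np.float16)

# Interleave the two halves into a fragment via strided slice assignment --
# whole-vector operations, no per-element indexing loop.
frag = np.empty(8, dtype=np.float16)
frag[0::2] = narrow[:4]
frag[1::2] = narrow[4:]

print(frag.tolist())  # [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
```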
zero
zero[stage: Int](self)
Zero the register fragments for the given pipeline stage.
zero(self)
Zero the entire register tile.
copy_to_shared
copy_to_shared(self)
Copy the register tile into the shared-memory backing buffer.