Skip to main content

Mojo struct

PRegisterBuffer

struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], tr_load_enabled: Bool = False, num_stages: Int = 1, p_swizzle: Optional[Swizzle] = None]

Fields

  • reg_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].RegType):
  • smem_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].SmemTileType):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

BlockSmemType

comptime BlockSmemType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

input_frag_size

comptime input_frag_size = num_matrix_reg[mma_shape[0], mma_shape[2]]()

mma_dtype

comptime mma_dtype = dtype

MmaTileType

comptime MmaTileType = TileTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].mma_dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

reg_dtype

comptime reg_dtype = accum_type_

reg_layout

comptime reg_layout = row_major[((num_stages * num_n_mmas) * num_m_mmas), output_frag_size]()

RegType

comptime RegType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

SmemTileType

comptime SmemTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

stage_layout

comptime stage_layout = row_major[(num_n_mmas * num_m_mmas), output_frag_size]()

StageTileType

comptime StageTileType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Methods

__init__

__init__(out self, smem_tile: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_mma_tile_shared

get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].MmaTileType

Returns:

`PRegisterBuffer[...].MmaTileType`

stage_tile

stage_tile[stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].StageTileType

Return the TileTensor sub-tile for the given pipeline stage.

Returns:

`PRegisterBuffer[...].StageTileType`

mma_tile

mma_tile[tile_idx: Int, k_idx: Int, stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].MmaTileType

Return the TileTensor MMA operand for the given `tile_idx`, `k_idx`, and pipeline `stage`, built with whole-vector SIMD cast and interleave operations.

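Converts accumulator rows of type `accum_type_` (for example, f32) into MMA fragments of type `mma_dtype` (for example, bf16). It uses SIMD cast, interleave, and slice operations on whole vectors, so no per-element `[j]` indexing is needed.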

Returns:

`PRegisterBuffer[...].MmaTileType`

zero

zero[stage: Int](self)

copy_to_shared

copy_to_shared(self)

Was this page helpful?