Mojo struct
PRegisterBuffer
struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], tr_load_enabled: Bool = False, num_stages: Int = 1, p_swizzle: Optional[Swizzle] = None]
Fields
- reg_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].RegType)
- smem_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].SmemTileType)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
BlockSmemType
comptime BlockSmemType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
input_frag_size
comptime input_frag_size = num_matrix_reg[mma_shape[0], mma_shape[2]]()
mma_dtype
comptime mma_dtype = dtype
MmaTileType
comptime MmaTileType = TileTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].mma_dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
reg_dtype
comptime reg_dtype = accum_type_
reg_layout
comptime reg_layout = row_major[((num_stages * num_n_mmas) * num_m_mmas), output_frag_size]()
RegType
comptime RegType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
SmemTileType
comptime SmemTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
stage_layout
comptime stage_layout = row_major[(num_n_mmas * num_m_mmas), output_frag_size]()
StageTileType
comptime StageTileType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Methods
__init__
__init__(out self, smem_tile: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_mma_tile_shared
get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].MmaTileType
Returns:
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].MmaTileType
stage_tile
stage_tile[stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].StageTileType
Return the TileTensor sub-tile for the given pipeline stage.
Returns:
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].StageTileType
mma_tile
mma_tile[tile_idx: Int, k_idx: Int, stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].MmaTileType
TileTensor MMA operand with cast+interleave via SIMD whole-vector ops.
Converts f32 accumulator rows to bf16 MMA fragments using SIMD cast, interleave, and slice — no per-element [j] indexing needed.
Returns:
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle].MmaTileType
zero
zero[stage: Int](self)
copy_to_shared
copy_to_shared(self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!