
Mojo struct

PRegisterBuffer

struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[Int(3)], tr_load_enabled: Bool = False, num_stages: Int = Int(1), p_swizzle: Optional[Swizzle] = None, raw_fp8_cast: Bool = False]

Fields

  • reg_tile (Self.RegType)
  • smem_tile (Self.SmemTileType)

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BlockSmemType

comptime BlockSmemType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

input_frag_size

comptime input_frag_size = num_matrix_reg[mma_shape[Int(0)], mma_shape[Int(2)]]()
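The following plain-Python sketch illustrates what `input_frag_size` counts. It assumes that `num_matrix_reg[dim_1, dim_2]()` spreads a `dim_1 x dim_2` tile evenly over one warp and returns `(dim_1 * dim_2) // warp_size`. That formula and the `warp_size` argument are assumptions of this sketch, not facts stated on this page. `mma_shape` is read as `(M, N, K)`, so indices 0 and 2 select the M x K operand tile.

```python
# Hypothetical model of input_frag_size. ASSUMPTION: num_matrix_reg[M, K]()
# equals (M * K) // warp_size; this page does not confirm that formula.


def num_matrix_reg(dim_1: int, dim_2: int, warp_size: int) -> int:
    """Per-lane element count for a dim_1 x dim_2 tile spread over one warp."""
    total = dim_1 * dim_2
    # Assume the tile divides evenly across the warp's lanes.
    if total % warp_size != 0:
        raise ValueError("tile does not divide evenly across the warp")
    return total // warp_size


def input_frag_size(mma_shape: tuple, warp_size: int) -> int:
    # mma_shape is (M, N, K); the operand tile uses M (index 0) and K (index 2).
    m, _n, k = mma_shape
    return num_matrix_reg(m, k, warp_size)


# A 16x8x16 MMA on a 32-lane warp: 16 * 16 / 32 = 8 elements per lane.
print(input_frag_size((16, 8, 16), warp_size=32))   # 8
# M = K = 16 on a 64-lane wavefront halves the per-lane count.
print(input_frag_size((16, 16, 16), warp_size=64))  # 4
```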

mma_dtype

comptime mma_dtype = dtype

MmaTileType

comptime MmaTileType = TileTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].mma_dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

reg_dtype

comptime reg_dtype = accum_type_

reg_layout

comptime reg_layout = row_major[((num_stages * num_n_mmas) * num_m_mmas), output_frag_size]()
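The plain-Python sketch below makes `reg_layout` concrete. It models a row-major 2-D layout with `num_stages * num_n_mmas * num_m_mmas` rows and `output_frag_size` columns, mapping a `(row, col)` coordinate to the flat offset `row * output_frag_size + col`. The `RowMajor2D` class and the parameter values are illustrative only and are not Mojo's `row_major` or `Layout` API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RowMajor2D:
    """Illustrative stand-in for a row-major 2-D layout (not Mojo's Layout type)."""
    rows: int
    cols: int

    def size(self) -> int:
        return self.rows * self.cols

    def offset(self, row: int, col: int) -> int:
        # Row-major: consecutive columns of one row are adjacent.
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            raise IndexError("coordinate out of range")
        return row * self.cols + col


def reg_layout(num_stages: int, num_m_mmas: int, num_n_mmas: int,
               output_frag_size: int) -> RowMajor2D:
    # Mirrors the shape in the comptime declaration above.
    return RowMajor2D(num_stages * num_n_mmas * num_m_mmas, output_frag_size)


# Illustrative parameters, not taken from a real kernel configuration.
layout = reg_layout(num_stages=2, num_m_mmas=2, num_n_mmas=4, output_frag_size=4)
print(layout.rows, layout.cols, layout.size())  # 16 4 64
print(layout.offset(5, 3))                      # 5 * 4 + 3 = 23
```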

RegType

comptime RegType = TileTensor[accum_type_, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

SmemTileType

comptime SmemTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]

stage_layout

comptime stage_layout = row_major[(num_n_mmas * num_m_mmas), output_frag_size]()

StageTileType

comptime StageTileType = TileTensor[accum_type_, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Methods

__init__

def __init__(out self, smem_tile: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED])

get_mma_tile_shared

def get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> Self.MmaTileType

Returns:

Self.MmaTileType

stage_tile

def stage_tile[stage: Int = Int(0)](self) -> Self.StageTileType

Returns the TileTensor sub-tile of the register buffer for pipeline stage `stage`.

Returns:

Self.StageTileType
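`reg_layout` has exactly `num_stages` times as many rows as `stage_layout`, which hints at how `stage_tile` could partition the buffer. The plain-Python sketch below assumes the stages are stacked vertically. Under that assumption, stage `s` owns the rows from `s * num_n_mmas * num_m_mmas` up to, but not including, `(s + 1) * num_n_mmas * num_m_mmas`. Both the stacking order and the helper names are hypothetical. The sketch also returns a copied list of rows, whereas the real method presumably returns a TileTensor view.

```python
# Hypothetical model of stage_tile[stage]. ASSUMPTION: reg_layout stacks
# num_stages copies of stage_layout vertically, one contiguous row block per stage.


def stage_row_range(stage: int, num_stages: int, num_m_mmas: int,
                    num_n_mmas: int) -> range:
    if not 0 <= stage < num_stages:
        raise IndexError("stage out of range")
    rows_per_stage = num_n_mmas * num_m_mmas  # row count of stage_layout
    start = stage * rows_per_stage
    return range(start, start + rows_per_stage)


def stage_tile(reg, stage: int, num_stages: int, num_m_mmas: int, num_n_mmas: int):
    # reg is a list of rows, each output_frag_size long; returns a copy of one stage's rows.
    rows = stage_row_range(stage, num_stages, num_m_mmas, num_n_mmas)
    return reg[rows.start:rows.stop]


num_stages, num_m_mmas, num_n_mmas, frag = 2, 1, 2, 3
# Row r of the full buffer holds its flat row-major offsets r*frag .. r*frag + frag - 1.
total_rows = num_stages * num_n_mmas * num_m_mmas
reg = [[r * frag + c for c in range(frag)] for r in range(total_rows)]
print(stage_tile(reg, 0, num_stages, num_m_mmas, num_n_mmas))  # [[0, 1, 2], [3, 4, 5]]
print(stage_tile(reg, 1, num_stages, num_m_mmas, num_n_mmas))  # [[6, 7, 8], [9, 10, 11]]
```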

mma_tile

def mma_tile[tile_idx: Int, k_idx: Int, stage: Int = Int(0)](self) -> Self.MmaTileType

Returns the TileTensor MMA operand selected by `tile_idx`, `k_idx`, and `stage`, built with whole-vector SIMD cast and interleave operations.

Converts f32 accumulator rows to bf16 MMA fragments using SIMD cast, interleave, and slice operations, so no per-element `[j]` indexing is needed.

Returns:

Self.MmaTileType
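The three SIMD steps that `mma_tile` describes can be modeled in plain Python:

- `cast_bf16` rounds each f32 value to bf16.
- `interleave` mirrors the element order of Mojo's `SIMD.interleave` (`a0, b0, a1, b1, ...`).
- `slice_vec` takes a contiguous window, as `SIMD.slice` does.

Several details are hypothetical because the page does not specify them: which accumulator rows are paired, which window is sliced, and the round-to-nearest-even conversion.

```python
# Plain-Python model of cast -> interleave -> slice. ASSUMPTIONS: f32 -> bf16 uses
# round-to-nearest-even, and the row pairing / slice window below are illustrative.
import struct


def f32_to_bf16(x: float) -> float:
    """Round x to bf16 (nearest-even) and widen the result back to a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        bits16 = (bits >> 16) | 0x0040  # keep NaN a quiet NaN
    else:
        # Add 0x7FFF plus the lowest kept bit, then truncate: ties round to even.
        bits16 = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]


def cast_bf16(vec):
    return [f32_to_bf16(v) for v in vec]


def interleave(a, b):
    # Element order a0, b0, a1, b1, ... as documented for SIMD.interleave.
    if len(a) != len(b):
        raise ValueError("vectors must have equal width")
    out = []
    for x, y in zip(a, b):
        out.extend((x, y))
    return out


def slice_vec(vec, width: int, offset: int):
    # Contiguous window [offset, offset + width), like SIMD.slice[width, offset=...].
    if offset < 0 or offset + width > len(vec):
        raise IndexError("slice out of range")
    return vec[offset:offset + width]


# Hypothetical pairing: two accumulator rows feed one interleaved operand vector.
row_a = [1.0, 2.0, 3.0, 1.00390625]  # 1 + 2**-8 is a bf16 tie; it rounds to 1.0
row_b = [5.0, 6.0, 7.0, 8.0]
operand = interleave(cast_bf16(row_a), cast_bf16(row_b))
print(operand)                   # [1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 1.0, 8.0]
print(slice_vec(operand, 4, 4))  # [3.0, 7.0, 1.0, 8.0]
```

Modeling the steps on whole lists rather than single elements reflects the description above, which replaces per-element `[j]` indexing with vector-wide operations.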

zero

def zero[stage: Int](self)

copy_to_shared

def copy_to_shared(self)