Mojo struct
PRegisterBuffer
struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], tr_load_enabled: Bool = False, num_stages: Int = 1, p_swizzle: Optional[Swizzle] = None, raw_fp8_cast: Bool = False]
Fields
- reg_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].RegType)
- smem_tile (PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].SmemTileType)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
BlockSmemType
comptime BlockSmemType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
input_frag_size
comptime input_frag_size = num_matrix_reg[mma_shape[0], mma_shape[2]]()
mma_dtype
comptime mma_dtype = dtype
MmaTileType
comptime MmaTileType = TileTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].mma_dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
reg_dtype
comptime reg_dtype = accum_type_
reg_layout
comptime reg_layout = row_major[((num_stages * num_n_mmas) * num_m_mmas), output_frag_size]()
RegType
comptime RegType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
SmemTileType
comptime SmemTileType = TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]
stage_layout
comptime stage_layout = row_major[(num_n_mmas * num_m_mmas), output_frag_size]()
StageTileType
comptime StageTileType = TileTensor[accum_type_, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
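For concreteness, reg_layout is stage_layout stacked num_stages times along the row dimension. A sketch with hypothetical parameter values (not values used by any real kernel):

```mojo
# Hypothetical parameter values, for illustration only.
comptime num_stages = 2
comptime num_m_mmas = 2
comptime num_n_mmas = 2
comptime output_frag_size = 4

# stage_layout = row_major[num_n_mmas * num_m_mmas, output_frag_size]()
#              -> a 4 x 4 fragment block per pipeline stage.
# reg_layout   = row_major[num_stages * num_n_mmas * num_m_mmas, output_frag_size]()
#              -> an 8 x 4 block: the per-stage layouts stacked along rows,
#                 so stage s occupies rows [s * 4, (s + 1) * 4) of reg_tile.
```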
Methods
__init__
__init__(out self, smem_tile: TileTensor[dtype, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED])
get_mma_tile_shared
get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].MmaTileType
Returns:
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].MmaTileType
stage_tile
stage_tile[stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].StageTileType
Return the TileTensor sub-tile for the given pipeline stage.
Returns:
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].StageTileType
mma_tile
mma_tile[tile_idx: Int, k_idx: Int, stage: Int = 0](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].MmaTileType
Returns the TileTensor MMA operand, with cast and interleave performed via whole-vector SIMD ops.
Converts f32 accumulator rows to bf16 MMA fragments using SIMD cast, interleave, and slice; no per-element [j] indexing is needed.
Returns:
PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, tr_load_enabled, num_stages, p_swizzle, raw_fp8_cast].MmaTileType
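The cast-and-interleave step can be sketched with Mojo's stdlib SIMD operations. This is a minimal standalone illustration of the technique, not the kernel's actual code; the fragment widths are hypothetical:

```mojo
fn cast_interleave_sketch():
    # Two f32 accumulator fragments (widths are illustrative).
    var a = SIMD[DType.float32, 4](0.0, 1.0, 2.0, 3.0)
    var b = SIMD[DType.float32, 4](4.0, 5.0, 6.0, 7.0)

    # Whole-vector cast to bf16: no per-element loop.
    var a_bf = a.cast[DType.bfloat16]()
    var b_bf = b.cast[DType.bfloat16]()

    # Interleave pairs the lanes: (a0, b0, a1, b1, ...).
    var packed = a_bf.interleave(b_bf)

    # Slice out one MMA fragment's worth of lanes.
    var frag = packed.slice[4]()
```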
zero
zero[stage: Int](self)
copy_to_shared
copy_to_shared(self)
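Putting the methods together, a per-stage lifecycle might look like the following sketch. The parameter bindings and the surrounding MMA loop are elided, and the index names (stage, tile_idx, k_idx) are hypothetical:

```mojo
# Assuming `p_buf` is a PRegisterBuffer constructed over a shared-memory tile:
# var p_buf = PRegisterBuffer[...](smem_tile)

# Clear the accumulator registers for this pipeline stage.
p_buf.zero[stage]()

# ... accumulate into p_buf.stage_tile[stage]() via MMA ops ...

# Fetch a cast-and-interleaved MMA operand for a given tile/k index.
var frag = p_buf.mma_tile[tile_idx, k_idx, stage]()

# Optionally spill the register tile back to shared memory.
p_buf.copy_to_shared()
```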