Mojo struct
MMATileBuffers
struct MMATileBuffers[mut: Bool, dtype: DType, layout: Layout, origin: Origin[mut], address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, //, _dtype: DType, /, smem_layout: Layout, reg_tile_layout: Layout, swizzle: Swizzle, tensor_type: AnyStruct[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]], thread_layout: Layout, warp_rows: Int, warp_cols: Int, stride: Int]
Manages memory for a single matrix (A or B) in GEMM computation.
This struct encapsulates all memory handling for a matrix, including:
- Shared memory allocation and tiling
- Register buffer allocation
- Data movement between memory levels (DRAM→local→shared)
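The sketch below illustrates the intended data-movement pattern for the two operands inside a GEMM main loop. It uses a stand-in struct that only mirrors the documented `copy_to_smem` method plus an assumed register-load step; the stand-in type, the `load_from_dram` name, and the loop/barrier structure are illustrative assumptions, not part of this API.

```mojo
from gpu import barrier


struct _TileBuffersSketch:
    """Stand-in mirroring the documented interface, for illustration only."""

    fn load_from_dram(self):
        # Assumed step: global memory (DRAM) -> thread-local register tile.
        pass

    fn copy_to_smem(self):
        # Documented step: register tile -> shared-memory tile.
        pass


fn gemm_k_loop_sketch(
    a_bufs: _TileBuffersSketch,
    b_bufs: _TileBuffersSketch,
    num_k_tiles: Int,
):
    for _ in range(num_k_tiles):
        # Stage the next K tile of each operand through registers.
        a_bufs.load_from_dram()
        b_bufs.load_from_dram()
        # Publish the staged tiles to shared memory for the whole block.
        a_bufs.copy_to_smem()
        b_bufs.copy_to_smem()
        barrier()  # shared tiles must be visible before any warp reads them
        # Warp-level MMA on the smem_warp_tile slices would happen here.
        barrier()  # don't overwrite the tiles until every warp is done
```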
Fields
- smem_tile (LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]): The shared-memory tile allocated for this matrix.
- smem_warp_tile (LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]): This warp's tile of the shared-memory buffer.
- load_reg_tile (LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]): The thread-local (register) buffer that stages data loaded from global memory before it is copied to shared memory.
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial
MMARegTileType
alias MMARegTileType = LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]
SMemTileType
alias SMemTileType = LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]
Methods
__init__
__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_idx: Int, warp_k_idx: Int, block_idx: Int)
Initialize memory regions for a matrix based on warp coordinates.
Args:
- tensor (LayoutTensor): The tensor to load from global memory.
- warp_idx (Int): The warp index within the computation grid (used for MMA operations).
- warp_k_idx (Int): The warp index along the K dimension (used for MMA operations).
- block_idx (Int): The block index within the computation grid (used for warp tiling). One way to derive these coordinates from thread and block IDs is sketched below.
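The following is a minimal, hedged sketch of one possible derivation of these coordinates inside a kernel; the warp size of 32, the `WARPS_PER_BLOCK_K` tiling factor, and the variable names are illustrative assumptions, not values prescribed by this API.

```mojo
from gpu import block_idx, thread_idx

alias WARP_THREADS = 32      # assumed warp size (64 on AMD wavefront GPUs)
alias WARPS_PER_BLOCK_K = 2  # assumed number of warps tiling the K dimension


fn warp_coords_sketch():
    var linear_warp = Int(thread_idx.x) // WARP_THREADS
    var warp_mn = linear_warp // WARPS_PER_BLOCK_K  # candidate `warp_idx`
    var warp_k = linear_warp % WARPS_PER_BLOCK_K    # candidate `warp_k_idx`
    var block = Int(block_idx.x)                    # candidate `block_idx`
    # These three values would be passed to the A and B MMATileBuffers
    # constructors as warp_idx, warp_k_idx, and block_idx respectively.
    _ = warp_mn
    _ = warp_k
    _ = block
```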
copy_to_smem
copy_to_smem(self)
Copy data from thread-local memory to shared memory.
Uses structured thread cooperation to efficiently transfer data.
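As a rough illustration of what "structured thread cooperation" means, the plain-CPU sketch below partitions a flat tile round-robin across threads so the group covers it exactly once. The real method distributes work according to the struct's `thread_layout` parameter and writes to shared memory, so this is only an analogy, not the actual implementation.

```mojo
fn cooperative_copy_sketch(
    src: List[Float32],
    mut dst: List[Float32],
    thread_id: Int,
    num_threads: Int,
):
    # Each thread copies only the elements assigned to it, so the whole
    # thread group transfers the tile exactly once with no overlap.
    var i = thread_id
    while i < len(src):
        dst[i] = src[i]
        i += num_threads
```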