
Mojo struct

MMATileBuffers

struct MMATileBuffers[: DType, : DType, : IndexList[3], : Bool, : Int, : Int, : Int, : Int, : Int, : Swizzle, : Int, : Bool, : DType, : Layout, : Origin[], : AddressSpace, : Layout, : DType, : DType, : Bool, : Int, tensor_origin: ImmutableOrigin, //, smem_layout: Layout, /, tensor_type: AnyStruct[LayoutTensor[, , , address_space=, element_layout=, layout_int_type=, linear_idx_type=, masked=, alignment=]], thread_layout: Layout, block_rows: Int, warp_rows: Int, stride: Int, num_mmas: Int, mma_type: AnyStruct[AMD_MMA[, , , , , , , , , , ]]]

Manages memory for a single matrix (A or B) in GEMM computation.

This struct encapsulates all memory handling for a matrix, including:

  • Shared memory allocation and tiling
  • Register buffer allocation
  • Data movement between memory levels (DRAM→local→shared)
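A typical main loop moves each matrix down this hierarchy once per K iteration. The sketch below is illustrative only: the identifiers `a_buffers`/`b_buffers`, the loop bounds, and the `barrier()` synchronization point are assumed kernel context, not part of this API.

```mojo
# Hypothetical main-loop usage of MMATileBuffers (names and loop
# structure are illustrative, not the library's actual kernel code).
for _ in range(num_k_iters):
    # Stage the next tiles: DRAM -> thread-local registers.
    a_buffers.load_from_dram()
    b_buffers.load_from_dram()

    # Registers -> shared memory, cooperatively across the thread block.
    a_buffers.copy_to_shared()
    b_buffers.copy_to_shared()
    barrier()  # all warps must see the shared tiles before issuing MMAs

    # Shared memory -> MMA register tiles, one K-tile at a time.
    @parameter
    for k_tile in range(num_k_tiles):
        a_buffers.load_tile_from_shared[k_tile, True]()
        b_buffers.load_tile_from_shared[k_tile, False]()
```

Keeping one `MMATileBuffers` instance per operand lets the A and B pipelines share the same staging pattern while each owns its own shared-memory and register tiles.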

Fields

  • shared_mem_tile (LayoutTensor[in_type, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]): The shared memory buffer holding this matrix's block tile.
  • shared_mem_warp_tile (LayoutTensor[in_type, _compute_tile_layout[*::Int]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[::Layout,*::Int](), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]): A warp-scoped view into the shared memory tile.
  • load_reg_tile (LayoutTensor[in_type, row_major((num_k_tiles * num_mmas), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]): Thread-local register buffer that stages data loaded from global memory.
  • mma_reg_tile (StaticTuple[LayoutTensor[in_type, _compute_tile_layout[::Int,::Int]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()], num_k_tiles]): Register tiles, one per K-tile, consumed by the MMA operations.
  • gmem_iter (LayoutTensorIter[dtype, _compute_tile_layout[*::Int]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[::Layout,*::Int]()]): Iterator over this matrix's tiles in global memory.
  • global_offset (UInt): Offset of the current tile within the global tensor.
  • tensor (Pointer[tensor_type, tensor_origin]): Pointer to the source tensor in global memory.

Implemented traits

AnyType, UnknownDestructibility

Aliases

iter_type

alias iter_type = LayoutTensorIter[dtype, _compute_tile_layout[*::Int]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[::Layout,*::Int]()]

MMARegTileType

alias MMARegTileType = LayoutTensor[in_type, row_major((num_k_tiles * num_mmas), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]

SharedMemTileType

alias SharedMemTileType = LayoutTensor[in_type, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]

Methods

__init__

__init__(out self, ref [tensor_origin] tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_idx: Int, block_idx: Int)

Initialize memory regions for a matrix based on warp coordinates.

Args:

  • tensor (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]): The tensor to load from global memory.
  • warp_idx (Int): The warp index within the computation grid (used for MMA operations).
  • block_idx (Int): The block index within the computation grid (used for warp tiling).

copy_to_shared

copy_to_shared(self)

Copy data from thread-local memory to shared memory.

Uses structured thread cooperation to efficiently transfer data.

load_from_dram

load_from_dram(mut self)

Load data from global memory (DRAM) to thread-local memory.

get_reg_tile

get_reg_tile[k_tile_idx: Int](self) -> LayoutTensor[in_type, _compute_tile_layout[::Int,::Int]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=alignof[::AnyType,__mlir_type.!kgen.target]()]

Get a specific K-dimension tile from the register buffer.

Parameters:

  • k_tile_idx (Int): The K-dimension tile index.

Returns:

A tile view for the specified location in the register buffer.
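For example (illustrative: `a_buffers` and the downstream `mma` call are assumed kernel context, not part of this API):

```mojo
# Fetch the register fragment of A for K-tile 0 and pass it to an
# MMA step (hypothetical usage sketch).
var a_frag = a_buffers.get_reg_tile[0]()
mma(d_frag, a_frag, b_frag, c_frag)
```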

load_tile_from_shared

load_tile_from_shared[k_tile_idx: Int, is_a: Bool](self)

Load a tile of data from shared memory into the register buffer used by the MMA operations.
