
Mojo struct

MatmulTileWriter

struct MatmulTileWriter[dtype: DType, tensor_layout: TensorLayout, linear_idx_type: DType, tensor_element_size: Int, smem_tile_layout: TensorLayout, //, *, BM: Int, BN: Int, swizzle: TensorMapSwizzle, wgmma_shape: IndexList[3], num_consumer: Int = 1, use_tma_store: Bool = False, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, swapAB: Bool = False]

Fields​

  • tensor (MatmulTileWriter[BM=BM, BN=BN, swizzle=swizzle, wgmma_shape=wgmma_shape, num_consumer=num_consumer, use_tma_store=use_tma_store, elementwise_lambda_fn=elementwise_lambda_fn, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, swapAB=swapAB].CTensorType)
  • smem_tile (TileTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED])
  • warp_group_thread_idx (Int)
  • local_warp_group_idx (Int)
  • local_thread_idx (Int)
  • block_y (Int)
  • block_x (Int)

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

CTensorType​

comptime CTensorType = TileTensor[dtype, tensor_layout, MutAnyOrigin, linear_idx_type=linear_idx_type, element_size=tensor_element_size]

frag_size​

comptime frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // WARPGROUP_SIZE)

lambda_type​

comptime lambda_type = def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], mut SIMD[dtype, width]) capturing -> None

N

comptime N = tensor_layout.static_shape[1]

num_consumer_threads​

comptime num_consumer_threads = (num_consumer * WARPGROUP_SIZE)

num_m_mmas​

comptime num_m_mmas = ((BM // wgmma_shape[0]) // num_consumer)

num_n_mmas​

comptime num_n_mmas = (BN // wgmma_shape[1])

simd_size​

comptime simd_size = simd_width_of[dtype]()

WG_BM​

comptime WG_BM = smem_tile_layout.static_shape[0]

WG_BN​

comptime WG_BN = smem_tile_layout.static_shape[1]
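
The integer formulas for frag_size, num_consumer_threads, num_m_mmas, and num_n_mmas above can be checked by hand for a concrete configuration. The following is a minimal sketch, not code from the library: the tile sizes are hypothetical, WGMMA_M and WGMMA_N stand in for wgmma_shape[0] and wgmma_shape[1], and WARPGROUP_SIZE is defined locally under the assumption that a warp group is 4 warps of 32 threads (128 threads).

```mojo
# Hypothetical configuration; these values are illustrative assumptions,
# not defaults taken from MatmulTileWriter.
comptime WARPGROUP_SIZE = 128  # assumed: 4 warps x 32 threads per warp group

comptime BM = 128
comptime BN = 256
comptime WGMMA_M = 64  # stands in for wgmma_shape[0]
comptime WGMMA_N = 256  # stands in for wgmma_shape[1]
comptime NUM_CONSUMER = 2  # stands in for num_consumer

# The same integer formulas as the comptime members listed above.
comptime frag_size = (WGMMA_M * WGMMA_N) // WARPGROUP_SIZE
comptime num_consumer_threads = NUM_CONSUMER * WARPGROUP_SIZE
comptime num_m_mmas = (BM // WGMMA_M) // NUM_CONSUMER
comptime num_n_mmas = BN // WGMMA_N


def main():
    print("frag_size =", frag_size)  # (64 * 256) // 128 = 128
    print("num_consumer_threads =", num_consumer_threads)  # 2 * 128 = 256
    print("num_m_mmas =", num_m_mmas)  # (128 // 64) // 2 = 1
    print("num_n_mmas =", num_n_mmas)  # 256 // 256 = 1
```

Under these assumed values, each thread holds 128 accumulator elements per MMA tile, and each of the two consumer warp groups covers one 64-row MMA tile of the 128-row block. All divisions are floor divisions, so a BM that is not a multiple of wgmma_shape[0] * num_consumer would silently round num_m_mmas down.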

Methods​

__init__​

def __init__(tensor: TileTensor[dtype, tensor_layout, MutAnyOrigin, linear_idx_type=linear_idx_type, element_size=tensor_element_size], smem_tile: TileTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], warp_group_thread_idx: Int, local_warp_group_idx: Int, local_thread_idx: Int, block_y: Int, block_x: Int) -> Self

write_tile​

def write_tile[tma_rank: Int, tma_tile_shape: IndexList[tma_rank], tma_desc_shape: IndexList[tma_rank], accum_type: DType, reg_tile_layout: Layout, //](self, tma_op: TMATensorTile[dtype, tma_rank, tma_tile_shape, tma_desc_shape], reg_tile: LayoutTensor[accum_type, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])

Write output from registers to global memory.

Selects the optimized st.matrix path for bf16 output when its constraints are met; otherwise falls back to the general register-to-global path.