For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
MatmulTileWriter
struct MatmulTileWriter[dtype: DType, tensor_layout: TensorLayout, linear_idx_type: DType, tensor_element_size: Int, smem_tile_layout: TensorLayout, //, *, BM: Int, BN: Int, swizzle: TensorMapSwizzle, wgmma_shape: IndexList[3], num_consumer: Int = 1, use_tma_store: Bool = False, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, swapAB: Bool = False]
Fields
- tensor (MatmulTileWriter[BM=BM, BN=BN, swizzle=swizzle, wgmma_shape=wgmma_shape, num_consumer=num_consumer, use_tma_store=use_tma_store, elementwise_lambda_fn=elementwise_lambda_fn, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, swapAB=swapAB].CTensorType)
- smem_tile (TileTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED])
- warp_group_thread_idx (Int)
- local_warp_group_idx (Int)
- local_thread_idx (Int)
- block_y (Int)
- block_x (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
CTensorType
comptime CTensorType = TileTensor[dtype, tensor_layout, MutAnyOrigin, linear_idx_type=linear_idx_type, element_size=tensor_element_size]
frag_size
comptime frag_size = ((wgmma_shape[0] * wgmma_shape[1]) // WARPGROUP_SIZE)
lambda_type
comptime lambda_type = def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], mut SIMD[dtype, width]) capturing -> None
N
comptime N = tensor_layout.static_shape[1]
num_consumer_threads
comptime num_consumer_threads = (num_consumer * WARPGROUP_SIZE)
num_m_mmas
comptime num_m_mmas = ((BM // wgmma_shape[0]) // num_consumer)
num_n_mmas
comptime num_n_mmas = (BN // wgmma_shape[1])
simd_size
comptime simd_size = simd_width_of[dtype]()
WG_BM
comptime WG_BM = smem_tile_layout.static_shape[0]
WG_BN
comptime WG_BN = smem_tile_layout.static_shape[1]
Methods
__init__
def __init__(tensor: TileTensor[dtype, tensor_layout, MutAnyOrigin, linear_idx_type=linear_idx_type, element_size=tensor_element_size], smem_tile: TileTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], warp_group_thread_idx: Int, local_warp_group_idx: Int, local_thread_idx: Int, block_y: Int, block_x: Int) -> Self
write_tile
def write_tile[tma_rank: Int, tma_tile_shape: IndexList[tma_rank], tma_desc_shape: IndexList[tma_rank], accum_type: DType, reg_tile_layout: Layout, //](self, tma_op: TMATensorTile[dtype, tma_rank, tma_tile_shape, tma_desc_shape], reg_tile: LayoutTensor[accum_type, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])
Write output from registers to global memory.
Selects the optimized st.matrix path for bf16 when its constraints are met; otherwise, uses the general register-to-global path.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!