For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
MatmulTileWriter
struct MatmulTileWriter[dtype: DType, tensor_layout: TensorLayout, tensor_storage: TensorStorage, linear_idx_type: DType, tensor_element_size: Int, smem_tile_layout: TensorLayout, //, *, BM: Int, BN: Int, swizzle: TensorMapSwizzle, wgmma_shape: IndexList[Int(3)], num_consumer: Int = Int(1), use_tma_store: Bool = False, elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, swapAB: Bool = False]
Fields
- tensor (MatmulTileWriter[BM=BM, BN=BN, swizzle=swizzle, wgmma_shape=wgmma_shape, num_consumer=num_consumer, use_tma_store=use_tma_store, elementwise_lambda_fn=elementwise_lambda_fn, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn, swapAB=swapAB].CTensorType)
- smem_tile (TileTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED])
- warp_group_thread_idx (Int)
- local_warp_group_idx (Int)
- local_thread_idx (Int)
- block_y (Int)
- block_x (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
CTensorType
comptime CTensorType = TileTensor[dtype, tensor_layout, MutAnyOrigin, Storage=tensor_storage, linear_idx_type=linear_idx_type, element_size=tensor_element_size]
frag_size
comptime frag_size = (Int(wgmma_shape[Int(0)] * wgmma_shape[Int(1)]) // _resolve_warpgroup_size())
lambda_type
comptime lambda_type = def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], mut SIMD[dtype, width]) capturing -> None
N
comptime N = tensor_layout.static_shape[Int(1)]
num_consumer_threads
comptime num_consumer_threads = (num_consumer * _resolve_warpgroup_size())
num_m_mmas
comptime num_m_mmas = ((BM // wgmma_shape[Int(0)]) // num_consumer)
num_n_mmas
comptime num_n_mmas = (BN // wgmma_shape[Int(1)])
simd_size
comptime simd_size = simd_width_of[dtype]()
WG_BM
comptime WG_BM = smem_tile_layout.static_shape[Int(0)]
WG_BN
comptime WG_BN = smem_tile_layout.static_shape[Int(1)]
Methods
__init__
def __init__(tensor: TileTensor[dtype, tensor_layout, MutAnyOrigin, Storage=tensor_storage, linear_idx_type=linear_idx_type, element_size=tensor_element_size], smem_tile: TileTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED], warp_group_thread_idx: Int, local_warp_group_idx: Int, local_thread_idx: Int, block_y: Int, block_x: Int) -> Self
write_tile
def write_tile[tma_rank: Int, tma_tile_shape: IndexList[tma_rank], tma_desc_shape: IndexList[tma_rank], accum_type: DType, reg_tile_layout: Layout, //](self, tma_op: TMATensorTile[dtype, tma_rank, tma_tile_shape, tma_desc_shape], reg_tile: LayoutTensor[accum_type, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])
Writes the output tile from registers to global memory.
Selects the optimized st.matrix path for bf16 when its constraints are met; otherwise, uses the general register-to-global path.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!