Mojo struct
MatmulTileWriter
@register_passable(trivial)
struct MatmulTileWriter[dtype: DType, layout: Layout, address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, smem_tile_layout: Layout, //, *, BM: Int, BN: Int, swizzle: TensorMapSwizzle, wgmma_shape: IndexList[3], num_consumer: Int = 1, use_tma_store: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None]
Fields
- tensor (`LayoutTensor[dtype, layout, MutableAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]`)
- smem_tile (`LayoutTensor[dtype, smem_tile_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128]`)
- warp_group_thread_idx (`UInt`)
- local_warp_group_idx (`UInt`)
- local_thread_idx (`UInt`)
- block_y (`Int`)
- block_x (`Int`)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
CTensorType
alias CTensorType = LayoutTensor[dtype, layout, MutableAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]
frag_size
alias frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // _resolve_warpgroup_size())
lambda_type
alias lambda_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], mut SIMD[dtype, width]) capturing -> None
N
alias N = layout.shape[1].value()
num_consumer_threads
alias num_consumer_threads = (num_consumer * _resolve_warpgroup_size())
num_m_mmas
alias num_m_mmas = ((BM // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // num_consumer)
num_n_mmas
alias num_n_mmas = (BN // wgmma_shape.__getitem__[3, DType.int64, Int](1))
simd_size
alias simd_size = simd_width_of[dtype]()
WG_BM
alias WG_BM = smem_tile_layout.shape[0].value()
WG_BN
alias WG_BN = smem_tile_layout.shape[1].value()
Methods
__init__
__init__(tensor: LayoutTensor[dtype, layout, MutableAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], smem_tile: LayoutTensor[dtype, smem_tile_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt, local_thread_idx: UInt, block_y: Int, block_x: Int) -> Self
write_tile
write_tile[tma_layout: Layout, desc_layout: Layout, accum_type: DType, reg_tile_layout: Layout, //](self, tma_op: TMATensorTile[dtype, tma_layout, desc_layout], reg_tile: LayoutTensor[accum_type, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5)])
Writes the output tile from registers to global memory.
Selects the optimized `st.matrix` path for bf16 when its constraints are met; otherwise, uses the general register-to-global path.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!