Mojo struct
TMAStoreExecutor
@register_passable(trivial)
struct TMAStoreExecutor[c_type: DType, c_smem_layout: Layout, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, stage_contiguous_size: Int, cta_group: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool, batched: Bool = False]
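Executes a TMA store from shared memory (SMEM) to global memory (GMEM) with the appropriate tiling.
Handles three paths: transposed output with `cta_group=2` and `MMA_M=128`; transposed output in all other configurations; and non-transposed output. When `batched=True`, TMA stores use 3D coordinates (M, N, batch).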
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
c_smem_shape0
comptime c_smem_shape0 = c_smem_layout.shape[0].value()
CG1_TMA_BM
comptime CG1_TMA_BM = TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_shape0
CG2_TMA_BM
comptime CG2_TMA_BM = c_smem_layout.shape[0].value() if MMA_M == 256 else BM
num_c_smem_tiles
comptime num_c_smem_tiles = ((128 // TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].swizzle_width) // 1 if is_lower_frag_required else 2)
swizzle_width
comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
TMA_BM
comptime TMA_BM = c_smem_layout.shape[0].value() if MMA_M == 256 else BM if cta_group == 2 else TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].CG1_TMA_BM
Methods
execute
static execute[c_layout: Layout, c_desc_layout: Layout](c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], store_coords: TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required, batched].c_smem_shape0, stage, batched], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], warp_id: UInt32, lane: UInt32)
Executes the TMA store using the elected warp and lane 0.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!