Mojo struct
TMAStoreExecutor
@register_passable(trivial)
struct TMAStoreExecutor[c_type: DType, c_smem_layout: Layout, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, stage_contiguous_size: Int, cta_group: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool]
Execute TMA store from shared memory to global memory with proper tiling.
Encapsulates the SMEM tiling and reshaping logic for TMA stores. Handles three distinct paths, selected by transpose_c, cta_group, and MMA_M (see the sketch after the parameter list):
- transpose_c + cta_group==2 + MMA_M==128: Split reshape
- transpose_c + other: Loop over swizzle-width tiles
- non-transpose: Simple tile selection
Parameters:
- c_type: Output data type.
- c_smem_layout: Shared memory layout for the C tile.
- BM: Block M dimension.
- BN: Block N dimension.
- MMA_M: MMA M dimension.
- MMA_N: MMA N dimension.
- stageN: Stage width in elements.
- stage_contiguous_size: Contiguous size in the SMEM layout.
- cta_group: Number of CTAs cooperating (1 or 2).
- c_swizzle: TensorMap swizzle mode.
- transpose_c: Whether the output is transposed.
- is_lower_frag_required: Whether the lower fragment is used.
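The branch structure above can be sketched as follows. This is a minimal, hypothetical illustration rather than the library's implementation: select_store_path and its return strings are invented for this page, and only the branching on transpose_c, cta_group, and MMA_M mirrors the description (the real executor resolves the path at compile time from the struct parameters).

```mojo
# Hypothetical sketch of the path selection; not the library's implementation.
fn select_store_path[transpose_c: Bool, cta_group: Int, mma_m: Int]() -> String:
    if transpose_c and cta_group == 2 and mma_m == 128:
        return "split reshape"
    elif transpose_c:
        return "loop over swizzle-width tiles"
    else:
        return "simple tile selection"


fn main():
    print(select_store_path[True, 2, 128]())   # split reshape
    print(select_store_path[True, 1, 128]())   # loop over swizzle-width tiles
    print(select_store_path[False, 2, 256]())  # simple tile selection
```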
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
c_smem_shape0
comptime c_smem_shape0 = c_smem_layout.shape[0].value()
CG1_TMA_BM
comptime CG1_TMA_BM = Self.c_smem_shape0
CG2_TMA_BM
comptime CG2_TMA_BM = c_smem_layout.shape[0].value() if (MMA_M == 256) else BM
num_c_smem_tiles
comptime num_c_smem_tiles = (128 // Self.swizzle_width) // (1 if is_lower_frag_required else 2)
swizzle_width
comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
TMA_BM
comptime TMA_BM = Self.CG2_TMA_BM if (cta_group == 2) else Self.CG1_TMA_BM
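For intuition, the swizzle_width formula above works out as follows for two assumed element types. The 128-byte swizzle and the element sizes are examples chosen for illustration, not values taken from the library:

```mojo
fn main():
    # swizzle_width = c_swizzle.bytes() // size_of[c_type]()
    var swizzle_bytes = 128  # assume a 128-byte TensorMap swizzle

    # Assume c_type = bfloat16 (2 bytes per element).
    print("bf16 swizzle_width:", swizzle_bytes // 2)  # 64 elements

    # Assume c_type = float32 (4 bytes per element).
    print("f32 swizzle_width:", swizzle_bytes // 4)   # 32 elements
```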
Methods
execute
static execute[c_layout: Layout, c_desc_layout: Layout](c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], store_coords: TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, Self.c_smem_shape0, stage], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], warp_id: UInt32, lane: UInt32)
Execute TMA store with appropriate tiling for the configuration.
Args:
- c_smem_tile (LayoutTensor): Source shared memory tile.
- store_coords (TMAStoreCoords): Precomputed TMA store coordinates.
- c_tma_op (TMATensorTile): TMA tensor tile for async store operations.
- warp_id (UInt32): Current warp ID.
- lane (UInt32): Current lane ID within the warp.