
Mojo struct

TMAStoreExecutor

@register_passable(trivial)
struct TMAStoreExecutor[
    c_type: DType,
    c_smem_layout: Layout,
    BM: Int,
    BN: Int,
    MMA_M: Int,
    MMA_N: Int,
    stageN: Int,
    stage_contiguous_size: Int,
    cta_group: Int,
    c_swizzle: TensorMapSwizzle,
    transpose_c: Bool,
    is_lower_frag_required: Bool,
]

Execute TMA store from shared memory to global memory with proper tiling.

Encapsulates the SMEM tiling and reshaping logic for TMA stores. Handles three distinct paths, selected at compile time from transpose_c, cta_group, and MMA_M (see the sketch after this list):

  1. transpose_c and cta_group == 2 and MMA_M == 128: split reshape of the SMEM tile.
  2. transpose_c with any other configuration: loop over swizzle-width tiles.
  3. non-transpose: simple tile selection.
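A minimal, self-contained sketch of this three-way dispatch, assuming @parameter compile-time branching; the print statements stand in for the three store paths and are not the library's implementation:

fn store_dispatch[transpose_c: Bool, cta_group: Int, MMA_M: Int]():
    @parameter
    if transpose_c and cta_group == 2 and MMA_M == 128:
        print("path 1: split reshape")
    elif transpose_c:
        print("path 2: loop over swizzle-width tiles")
    else:
        print("path 3: simple tile selection")

fn main():
    store_dispatch[True, 2, 128]()   # path 1
    store_dispatch[True, 1, 128]()   # path 2
    store_dispatch[False, 2, 256]()  # path 3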

Parameters:

  • c_type (DType): Output data type.
  • c_smem_layout (Layout): Shared memory layout for the C tile.
  • BM (Int): Block M dimension.
  • BN (Int): Block N dimension.
  • MMA_M (Int): MMA M dimension.
  • MMA_N (Int): MMA N dimension.
  • stageN (Int): Stage width in elements.
  • stage_contiguous_size (Int): Contiguous size in the SMEM layout.
  • cta_group (Int): Number of CTAs cooperating (1 or 2).
  • c_swizzle (TensorMapSwizzle): TensorMap swizzle mode.
  • transpose_c (Bool): Whether the output is transposed.
  • is_lower_frag_required (Bool): Whether the lower fragment is used.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

c_smem_shape0

comptime c_smem_shape0 = c_smem_layout.shape[0].value()

CG1_TMA_BM

comptime CG1_TMA_BM = Self.c_smem_shape0

CG2_TMA_BM

comptime CG2_TMA_BM = c_smem_layout.shape[0].value() if (MMA_M == 256) else BM

num_c_smem_tiles

comptime num_c_smem_tiles = ((128 // Self.swizzle_width) // 1 if is_lower_frag_required else 2)

swizzle_width

comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())

TMA_BM

comptime TMA_BM = Self.CG2_TMA_BM if (cta_group == 2) else Self.CG1_TMA_BM
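For concreteness, a worked derivation of swizzle_width under an assumed configuration (bfloat16 output with a 128-byte swizzle; these values are illustrative, not from the source):

fn main():
    var swizzle_bytes = 128    # c_swizzle.bytes() for a 128B swizzle (assumed)
    var elem_bytes = 2         # size_of[DType.bfloat16]()
    print(swizzle_bytes // elem_bytes)           # swizzle_width = 64 elements
    print(128 // (swizzle_bytes // elem_bytes))  # 128 // swizzle_width = 2 swizzle tiles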

Methods

execute

static execute[c_layout: Layout, c_desc_layout: Layout](
    c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128],
    store_coords: TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, Self.c_smem_shape0, stage],
    c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout],
    warp_id: UInt32,
    lane: UInt32,
)

Execute TMA store with appropriate tiling for the configuration.

Args:

  • c_smem_tile (LayoutTensor): Source shared memory tile.
  • store_coords (TMAStoreCoords): Precomputed TMA store coordinates.
  • c_tma_op (TMATensorTile): TMA tensor tile for async store operations.
  • warp_id (UInt32): Current warp ID.
  • lane (UInt32): Current lane ID within warp.
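A hedged call-site sketch, assuming c_smem_tile, store_coords, and c_tma_op have already been constructed by the surrounding kernel (their construction is not documented on this page); the warp/lane derivation from thread_idx is likewise an assumption for 32-thread warps:

# assumes: from gpu import thread_idx
comptime Executor = TMAStoreExecutor[
    c_type, c_smem_layout, BM, BN, MMA_M, MMA_N,
    stageN, stage_contiguous_size, cta_group,
    c_swizzle, transpose_c, is_lower_frag_required,
]
Executor.execute[c_layout, c_desc_layout](
    c_smem_tile,                     # shared-memory C tile (built elsewhere)
    store_coords,                    # precomputed TMAStoreCoords (built elsewhere)
    c_tma_op,                        # TMA descriptor tile (built elsewhere)
    UInt32(thread_idx.x // 32),      # warp_id (assumed derivation)
    UInt32(thread_idx.x % 32),       # lane (assumed derivation)
)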
