Mojo struct
TMAStoreExecutor
struct TMAStoreExecutor[c_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, epc: EpilogueConfig, stage_contiguous_size: Int, c_swizzle: TensorMapSwizzle, batched: Bool = False]
Execute TMA store from SMEM to GMEM with proper tiling.
Handles 3 paths: transpose+cta_group2+MMA128, transpose+other, non-transpose. When batched=True, uses 3D coordinates (M, N, Batch) for TMA stores.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
BM
comptime BM = epc.BM
BN
comptime BN = epc.BN
c_smem_shape0
comptime c_smem_shape0 = c_smem_dim0
CG1_TMA_BM
comptime CG1_TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0
CG2_TMA_BM
comptime CG2_TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0 if (TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].MMA_M == 256) else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].BM
cta_group
comptime cta_group = epc.cta_group
is_lower_frag_required
comptime is_lower_frag_required = epc.is_lower_frag_required
MMA_M
comptime MMA_M = epc.MMA_M
MMA_N
comptime MMA_N = epc.MMA_N
num_c_smem_tiles
comptime num_c_smem_tiles = (TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].BM // TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].swizzle_width)
stageN
comptime stageN = epc.stageN
swizzle_width
comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
TMA_BM
comptime TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0 if (epc.MMA_M == 256) else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].BM if (TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].cta_group == 2) else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].CG1_TMA_BM
transpose_c
comptime transpose_c = epc.transpose_c
Methods
execute
static execute[tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank]](c_smem_tile: TileTensor[c_type, c_smem_tile.LayoutType, c_smem_tile.origin, address_space=AddressSpace.SHARED, linear_idx_type=c_smem_tile.linear_idx_type, element_size=c_smem_tile.element_size], store_coords: TMAStoreCoords[epc, TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0, store_coords.stage, batched], c_tma_op: TMATensorTile[c_type, tma_rank, tile_shape, desc_shape], warp_id: UInt32, lane: UInt32)
Execute TMA store.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!