Mojo struct

TMAStoreExecutor

struct TMAStoreExecutor[c_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, epc: EpilogueConfig, stage_contiguous_size: Int, c_swizzle: TensorMapSwizzle, batched: Bool = False]

Execute a TMA store from shared memory (SMEM) to global memory (GMEM) with proper tiling.

Handles three store paths: transpose with cta_group 2 and MMA_M 128, transpose with any other configuration, and non-transpose. When `batched=True`, uses 3D coordinates (M, N, Batch) for TMA stores.
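The dispatch among the three paths can be sketched as follows (an illustrative Python sketch, not the Mojo implementation; the function and path names are hypothetical and derived only from the three cases listed above):

```python
def select_store_path(transpose_c: bool, cta_group: int, mma_m: int) -> str:
    """Illustrative selection among the three TMA store paths (hypothetical helper).

    Mirrors the cases named above: transpose + cta_group 2 + MMA_M 128,
    transpose with any other configuration, and non-transpose.
    """
    if transpose_c:
        if cta_group == 2 and mma_m == 128:
            return "transpose_cg2_mma128"
        return "transpose_other"
    return "non_transpose"
```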

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

BM

comptime BM = epc.BM

BN

comptime BN = epc.BN

c_smem_shape0

comptime c_smem_shape0 = c_smem_dim0

CG1_TMA_BM

comptime CG1_TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0

CG2_TMA_BM

comptime CG2_TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0 if (TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].MMA_M == 256) else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].BM

cta_group

comptime cta_group = epc.cta_group

is_lower_frag_required

comptime is_lower_frag_required = epc.is_lower_frag_required

MMA_M

comptime MMA_M = epc.MMA_M

MMA_N

comptime MMA_N = epc.MMA_N

num_c_smem_tiles

comptime num_c_smem_tiles = (TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].BM // TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].swizzle_width)

stageN

comptime stageN = epc.stageN

swizzle_width

comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
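As a worked example of the `swizzle_width` and `num_c_smem_tiles` formulas above (a Python sketch; the 128-byte swizzle, 2-byte element size, and BM = 128 are assumed values, not taken from this page):

```python
SWIZZLE_BYTES = 128   # assumed: a 128-byte TensorMapSwizzle mode
ELEMENT_BYTES = 2     # assumed: size_of[c_type]() for a 16-bit dtype
BM = 128              # assumed block-M tile size from the epilogue config

# swizzle_width = c_swizzle.bytes() // size_of[c_type]()
swizzle_width = SWIZZLE_BYTES // ELEMENT_BYTES   # 64 elements per swizzle row

# num_c_smem_tiles = BM // swizzle_width
num_c_smem_tiles = BM // swizzle_width           # 2 shared-memory tiles
```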

TMA_BM

comptime TMA_BM = TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0 if (epc.MMA_M == 256) else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].BM if (TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].cta_group == 2) else TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].CG1_TMA_BM
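Unrolled into plain conditionals, the `TMA_BM` selection (together with the `CG1_TMA_BM` and `CG2_TMA_BM` aliases above) reduces to the following sketch (Python for illustration; argument names follow the comptime members on this page):

```python
def tma_bm(c_smem_shape0: int, bm: int, mma_m: int, cta_group: int) -> int:
    """Illustrative mirror of the comptime TMA_BM selection above."""
    cg1_tma_bm = c_smem_shape0                            # CG1_TMA_BM
    cg2_tma_bm = c_smem_shape0 if mma_m == 256 else bm    # CG2_TMA_BM
    # TMA_BM picks the cta_group-2 value when cta_group == 2,
    # and the cta_group-1 value otherwise.
    return cg2_tma_bm if cta_group == 2 else cg1_tma_bm
```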

transpose_c

comptime transpose_c = epc.transpose_c

Methods

execute

static execute[tma_rank: Int, tile_shape: IndexList[tma_rank], desc_shape: IndexList[tma_rank]](c_smem_tile: TileTensor[c_type, c_smem_tile.LayoutType, c_smem_tile.origin, address_space=AddressSpace.SHARED, linear_idx_type=c_smem_tile.linear_idx_type, element_size=c_smem_tile.element_size], store_coords: TMAStoreCoords[epc, TMAStoreExecutor[c_type, c_smem_dim0, c_smem_dim1, epc, stage_contiguous_size, c_swizzle, batched].c_smem_shape0, store_coords.stage, batched], c_tma_op: TMATensorTile[c_type, tma_rank, tile_shape, desc_shape], warp_id: UInt32, lane: UInt32)

Execute TMA store.
