Mojo struct
TMEMToSMemWriter
@register_passable(trivial)
struct TMEMToSMemWriter[c_type: DType, accum_type: DType, c_smem_layout: Layout, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]
Write TMEM accumulator fragments to shared memory for SM100.
This is the SM100-specific equivalent of SM90's FragmentToSMemWriter. The key difference is that SM100 accumulators live in Tensor Memory (TMEM) rather than in registers, so they must first be loaded into registers with tcgen05_ld before they can be written out.
Handles three tile reshaping cases:
- transpose_c + is_lower_frag_required: 2 warps share swizzle blocks
- transpose_c + !is_lower_frag_required: 4 warps, upper only
- !transpose_c: Simple row-major tiling
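The selection among the three cases above can be pictured as a two-flag decision. The following is a minimal sketch, not the library's implementation: the helper name `reshaping_case` is hypothetical, and `is_lower_frag_required` is passed as a plain `Bool` here, whereas in the struct it is presumably derived from the epilogue configuration.

```mojo
# Hypothetical helper: maps the two flags to the reshaping case
# described in the list above. Not part of the TMEMToSMemWriter API.
fn reshaping_case(transpose_c: Bool, is_lower_frag_required: Bool) -> String:
    if not transpose_c:
        # Non-transposed output ignores the lower-fragment flag.
        return "row-major tiling"
    if is_lower_frag_required:
        return "transposed: 2 warps share swizzle blocks"
    return "transposed: 4 warps, upper fragment only"


fn main():
    print(reshaping_case(False, False))
    print(reshaping_case(True, True))
    print(reshaping_case(True, False))
```

The non-transposed branch is checked first because, per the list, the lower-fragment flag only matters when transpose_c is set.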
Template Parameters:
- c_type: Output data type (e.g., bfloat16).
- accum_type: Accumulator data type (e.g., float32).
- c_smem_layout: Shared memory tile layout.
- BM: Block M dimension.
- BN: Block N dimension.
- MMA_M: MMA M dimension.
- MMA_N: MMA N dimension.
- stageN: Stage N dimension.
- cta_group: Number of CTAs cooperating (1 or 2).
- num_output_warps: Number of warps participating in output.
- c_swizzle: TMA swizzle mode.
- transpose_c: Whether output is transposed.
Fields
- warp_id (UInt32)
- lane_id (UInt32)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Config
comptime Config = EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c]
data_paths
comptime data_paths = 16
stage_contiguous_size
comptime stage_contiguous_size = c_smem_layout.shape[1].value()
swizzle
comptime swizzle = make_swizzle[c_type, c_swizzle]()
swizzle_width
comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
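As a worked example of the swizzle_width formula, the sketch below evaluates it for a bfloat16 output with the default SWIZZLE_128B mode. It assumes that mode's `bytes()` value is 128 and hardcodes that number instead of importing TensorMapSwizzle; the name `swizzle_bytes` is illustrative only, and the `comptime` declaration follows the syntax used on this page.

```mojo
from sys import size_of


fn main():
    # Assumption: TensorMapSwizzle.SWIZZLE_128B.bytes() is 128.
    comptime swizzle_bytes = 128
    # Same formula as the swizzle_width member: bytes per swizzle span
    # divided by bytes per output element.
    comptime swizzle_width = swizzle_bytes // size_of[DType.bfloat16]()
    print("swizzle_width for bfloat16:", swizzle_width)
```

Under that assumption a bfloat16 element occupies 2 bytes, so one swizzle span covers 64 elements; a 4-byte type such as float32 would give 32.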
Methods
__init__
__init__(warp_id: UInt32, lane_id: UInt32) -> Self
Initialize the TMEM to SMEM writer.
Args:
- warp_id (UInt32)
- lane_id (UInt32)
write_stage
write_stage[repeat: Int, bits: Int = 256](self, tmem_addr: UInt32, stage: Int, c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128])
Write a single stage from TMEM to shared memory with tile reshaping.
Automatically handles the correct tile reshaping based on transpose_c and is_lower_frag_required configuration.
Template Parameters:
- repeat: Repeat factor for fragment loading.
- bits: TMEM bits width (default 256).
Args:
- tmem_addr (UInt32): Base tensor memory address.
- stage (Int): Current stage index.
- c_smem_tile (LayoutTensor): Base shared memory tile (will be reshaped internally).
write_fragments
write_fragments[repeat: Int](self, upper_frag: SIMD[c_type, (4 * repeat)], lower_frag: SIMD[c_type, (4 * repeat)], c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128])
Write pre-loaded fragments to shared memory with tile reshaping.
Use this when fragments are loaded separately (e.g., with load_tmem_fragments) and need to be written after a register-based epilogue has been applied.
Template Parameters: repeat: Repeat factor matching the fragment size.
Args:
- upper_frag (SIMD): Upper fragment (already cast to c_type).
- lower_frag (SIMD): Lower fragment (already cast to c_type; ignored if not needed).
- c_smem_tile (LayoutTensor): Base shared memory tile (will be reshaped internally).