
Mojo struct

TMEMToSMemWriter

@register_passable(trivial) struct TMEMToSMemWriter[c_type: DType, accum_type: DType, c_smem_layout: Layout, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]

Write TMEM accumulator fragments to shared memory for SM100.

This is the SM100-specific equivalent of SM90's FragmentToSMemWriter. The key difference is that SM100 accumulators live in Tensor Memory (TMEM) rather than in registers, so they must first be loaded into registers with tcgen05_ld before they can be written to shared memory.

Handles three tile reshaping cases:

  1. transpose_c with is_lower_frag_required: two warps share each swizzle block.
  2. transpose_c without is_lower_frag_required: four warps, upper fragments only.
  3. Not transpose_c: simple row-major tiling.
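The case selection above depends only on two booleans, so it can be modeled as a small dispatch function. This is a Python sketch, not the Mojo implementation: `ReshapeCase` and `select_reshape_case` are hypothetical names, and it assumes (as the list suggests) that is_lower_frag_required only matters when transpose_c is set.

```python
from enum import Enum


class ReshapeCase(Enum):
    # Case 1: transpose_c with lower fragments; two warps share each swizzle block
    TRANSPOSED_SHARED = 1
    # Case 2: transpose_c without lower fragments; four warps, upper fragments only
    TRANSPOSED_UPPER_ONLY = 2
    # Case 3: no transpose; simple row-major tiling
    ROW_MAJOR = 3


def select_reshape_case(transpose_c: bool, is_lower_frag_required: bool) -> ReshapeCase:
    """Hypothetical helper mirroring the three documented reshaping cases."""
    if not transpose_c:
        # The documented list has a single non-transposed case,
        # so is_lower_frag_required is not consulted here.
        return ReshapeCase.ROW_MAJOR
    if is_lower_frag_required:
        return ReshapeCase.TRANSPOSED_SHARED
    return ReshapeCase.TRANSPOSED_UPPER_ONLY
```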

Template Parameters:

  • c_type: Output data type (e.g., bfloat16).
  • accum_type: Accumulator data type (e.g., float32).
  • c_smem_layout: Shared memory tile layout.
  • BM: Block M dimension.
  • BN: Block N dimension.
  • MMA_M: MMA M dimension.
  • MMA_N: MMA N dimension.
  • stageN: Stage N dimension.
  • cta_group: Number of cooperating CTAs (1 or 2).
  • num_output_warps: Number of warps participating in the output.
  • c_swizzle: TMA swizzle mode.
  • transpose_c: Whether the output is transposed.

Fields

  • warp_id (UInt32): Warp ID within the CTA.
  • lane_id (UInt32): Lane ID within the warp.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Config

comptime Config = EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c]

data_paths

comptime data_paths = 16

stage_contiguous_size

comptime stage_contiguous_size = c_smem_layout.shape[1].value()

swizzle

comptime swizzle = make_swizzle[c_type, c_swizzle]()
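For the default c_swizzle of SWIZZLE_128B, the swizzle permutes 16-byte chunks within each 128-byte row. The Python sketch below models that pattern on byte offsets; it assumes the commonly used 128-byte TMA layout (the 16-byte chunk index, bits 4 to 6, is XORed with the row index modulo 8, bits 7 to 9) and is not the output of make_swizzle itself. `swizzle_128b` is a hypothetical name.

```python
def swizzle_128b(byte_offset: int) -> int:
    """Model of a 128-byte swizzle applied to a shared-memory byte offset.

    Assumption: the 16-byte chunk index (bits [4, 7)) is XORed with the
    128-byte row index modulo 8 (bits [7, 10)); all other bits are unchanged.
    """
    row_mod_8 = (byte_offset >> 7) & 0b111
    return byte_offset ^ (row_mod_8 << 4)
```

Because each row's chunk indices are XORed with a constant, every 128-byte row is permuted within itself, which is what lets consecutive rows land in different shared-memory banks.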

swizzle_width

comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
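swizzle_width converts the swizzle span from bytes to elements of c_type. A Python sketch of that arithmetic follows; the lookup tables are illustrative stand-ins for c_swizzle.bytes() and size_of[c_type](), covering only a few common cases.

```python
# Illustrative stand-ins for c_swizzle.bytes() and size_of[c_type]().
SWIZZLE_BYTES = {"SWIZZLE_32B": 32, "SWIZZLE_64B": 64, "SWIZZLE_128B": 128}
DTYPE_BYTES = {"bfloat16": 2, "float16": 2, "float32": 4}


def swizzle_width(c_swizzle: str, c_type: str) -> int:
    """Number of c_type elements spanned by one swizzle row."""
    return SWIZZLE_BYTES[c_swizzle] // DTYPE_BYTES[c_type]
```

With the default SWIZZLE_128B and a bfloat16 output, one swizzle row holds 128 // 2 = 64 elements.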

Methods

__init__

__init__(warp_id: UInt32, lane_id: UInt32) -> Self

Initialize the TMEM to SMEM writer.

Args:

  • warp_id (UInt32): Warp ID within the CTA.
  • lane_id (UInt32): Lane ID within the warp.

write_stage

write_stage[repeat: Int, bits: Int = 256](self, tmem_addr: UInt32, stage: Int, c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128])

Write a single stage from TMEM to shared memory with tile reshaping.

Selects the appropriate tile reshaping automatically, based on the transpose_c and is_lower_frag_required settings.

Template Parameters:

  • repeat: Repeat factor for fragment loading.
  • bits: TMEM load width in bits (default 256).

Args:

  • tmem_addr (UInt32): Base tensor memory address.
  • stage (Int): Current stage index.
  • c_smem_tile (LayoutTensor): Base shared memory tile (will be reshaped internally).
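The default bits = 256 together with data_paths = 16 is consistent with the 4 * repeat fragment width in write_fragments. The Python sketch below shows that arithmetic under two assumptions: each repeat moves data_paths × bits bits per warp, and those bits are spread evenly as 32-bit values across the 32 lanes of a warp. `values_per_thread` is a hypothetical name.

```python
WARP_SIZE = 32      # lanes per warp
REGISTER_BITS = 32  # bits per accumulator value held by one lane


def values_per_thread(data_paths: int, bits: int, repeat: int) -> int:
    """32-bit values each lane receives from one TMEM load (assumed model)."""
    total_bits = data_paths * bits * repeat
    return total_bits // (REGISTER_BITS * WARP_SIZE)
```

For example, 16 × 256 = 4096 bits per repeat, and 4096 / (32 × 32) = 4 values per lane, so repeat = 8 yields 32 values per lane.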

write_fragments

write_fragments[repeat: Int](self, upper_frag: SIMD[c_type, (4 * repeat)], lower_frag: SIMD[c_type, (4 * repeat)], c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128])

Write pre-loaded fragments to shared memory with tile reshaping.

Use this when fragments are loaded separately (e.g., with load_tmem_fragments) and must be written after a register-based epilogue has been applied.

Template Parameters:

  • repeat: Repeat factor matching the fragment size (each fragment holds 4 * repeat elements).

Args:

  • upper_frag (SIMD): Upper fragment (already cast to c_type).
  • lower_frag (SIMD): Lower fragment (already cast to c_type; ignored if not needed).
  • c_smem_tile (LayoutTensor): Base shared memory tile (will be reshaped internally).
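The split workflow that write_fragments supports (load accumulators, apply an epilogue in registers, cast, then write) can be modeled in Python. In this sketch, `to_bfloat16`, `epilogue`, and `prepare_fragments` are hypothetical names; the bias-plus-ReLU epilogue and bfloat16 output type are illustrative choices, not taken from the library. The cast rounds float32 to bfloat16 with round-to-nearest-even and omits NaN handling.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round a float32 value to the nearest bfloat16 (ties to even).

    NaN handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept mantissa bit, then drop the low 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def epilogue(frag, bias):
    """Hypothetical register epilogue: add bias, apply ReLU, cast to bfloat16."""
    return [to_bfloat16(max(v + b, 0.0)) for v, b in zip(frag, bias)]


def prepare_fragments(upper_acc, lower_acc, bias, repeat):
    """Produce upper/lower fragments of 4 * repeat values, ready to write."""
    width = 4 * repeat
    assert len(upper_acc) == len(lower_acc) == len(bias) == width
    return epilogue(upper_acc, bias), epilogue(lower_acc, bias)
```

In the Mojo API, the two resulting fragments would correspond to upper_frag and lower_frag; when the configuration does not need the lower fragment, write_fragments ignores it.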
