
struct TmemFragments[dtype: DType, frag_size: Int, *, is_lower_required: Bool = True, data_paths: Int = 16, bits: Int = 256]

Paired upper/lower accumulator fragments from TMEM.

Encapsulates the SM100 TMEM row-split hardware detail:

  • Upper fragment: rows 0-15 (always present)
  • Lower fragment: rows 16-31 (only when is_lower_required=True)

The is_lower_required flag is set as follows:

  • False when cta_group=1 and MMA_M=64 (fits in 16 rows)
  • True otherwise (needs both halves)
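The rule above can be restated as a small predicate. This is a minimal sketch, not the library's code: the helper name `lower_fragment_required` and its runtime `Int` arguments are assumptions made for illustration, whereas `TmemFragments` takes `is_lower_required` as a compile-time parameter.

```mojo
# Hypothetical helper, not part of the library: restates the rule above.
def lower_fragment_required(cta_group: Int, mma_m: Int) -> Bool:
    # Only the cta_group=1, MMA_M=64 case fits in the upper 16 rows.
    return not (cta_group == 1 and mma_m == 64)


def main():
    # cta_group=1, MMA_M=64: the upper fragment alone suffices.
    print(lower_fragment_required(1, 64))
    # Any other combination needs both halves.
    print(lower_fragment_required(1, 128))
    print(lower_fragment_required(2, 128))
```

In real kernel code the result would be passed as the `is_lower_required` parameter when naming the `TmemFragments` type, so that the lower-half load and store are skipped at compile time.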

Example:

Load both fragments in one call

var frags = TmemFragments[DType.float32, 16].load(tmem_addr)

Work with fragments

frags.upper = process(frags.upper)
frags.lower = process(frags.lower)

Store both fragments

frags.store(tmem_addr)
TmemFragments.wait_store()

Parameters​

  • ​dtype (DType): Fragment data type (typically float32).
  • ​frag_size (Int): Elements per fragment (derived from data_paths and bits).
  • ​is_lower_required (Bool): Whether lower fragment is needed.
  • ​data_paths (Int): SM100 data paths (typically 16).
  • ​bits (Int): Bits per fragment load (typically 256).

Fields​

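  • upper (InlineArray[Scalar[dtype], frag_size]): Upper fragment (rows 0-15, always present).
  • lower (InlineArray[Scalar[dtype], frag_size]): Lower fragment (rows 16-31, used only when is_lower_required=True).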

Implemented traits​

AnyType, Copyable, ImplicitlyDestructible, Movable

Methods​

__init__​

__init__(out self)

Initialize with zero fragments.

__init__(out self, upper: InlineArray[Scalar[dtype], frag_size], lower: InlineArray[Scalar[dtype], frag_size])

Initialize with provided fragments.

load​

static load[repeat: Int = 1](tmem: TmemAddress) -> TmemFragments[dtype, (frag_size * repeat), is_lower_required=is_lower_required]

Load fragments from TMEM address.

Loads upper fragment always; loads lower only if required.

Parameters:

  • ​repeat (Int): Number of times to repeat the load pattern.

Args:

  • tmem (TmemAddress): TMEM address to load from.

Returns:

TmemFragments[dtype, (frag_size * repeat), is_lower_required=is_lower_required]: TmemFragments containing upper and (optionally) lower data.
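The `frag_size * repeat` term in the return type means that the loaded fragments grow with `repeat`. The arithmetic can be sketched as follows; `loaded_frag_size` is a hypothetical helper used only for illustration and is not part of the library.

```mojo
# Hypothetical helper: mirrors the `frag_size * repeat` term in the return type.
def loaded_frag_size(frag_size: Int, repeat: Int) -> Int:
    return frag_size * repeat


def main():
    # With frag_size=16, the default repeat=1 keeps 16 elements per fragment.
    print(loaded_frag_size(16, 1))
    # With repeat=4, each of upper and lower would hold 16 * 4 elements.
    print(loaded_frag_size(16, 4))
```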

store​

store[repeat: Int = 1](self, tmem: TmemAddress)

Store fragments to TMEM address.

Stores upper fragment always; stores lower only if required.

Parameters:

  • ​repeat (Int): Number of times to repeat the store pattern.

Args:

  • tmem (TmemAddress): TMEM address to store to.

cast​

cast[target_dtype: DType](self) -> TmemFragments[target_dtype, frag_size, is_lower_required=is_lower_required]

Cast fragments to a different dtype.

Returns:

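TmemFragments[target_dtype, frag_size, is_lower_required=is_lower_required]: Fragments with the same frag_size and is_lower_required setting, with elements cast to target_dtype.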

wait_load​

static wait_load()

Wait for TMEM load operations to complete.

wait_store​

static wait_store()

Wait for TMEM store operations to complete.