@register_passable(trivial)
struct TmemFragments[dtype: DType, frag_size: Int, *, is_lower_required: Bool = True, data_paths: Int = 16, bits: Int = 256]
Paired upper/lower accumulator fragments loaded from Tensor Memory (TMEM).
Encapsulates the SM100 TMEM row-split hardware detail:
- Upper fragment: rows 0-15 (always present)
- Lower fragment: rows 16-31 (only when is_lower_required=True)
The is_lower_required flag is set as follows:
- False when cta_group=1 and MMA_M=64 (fits in 16 rows)
- True otherwise (needs both halves)
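The selection rule above is small enough to state as code. The sketch below is a hypothetical Python model of the documented rule, not the Mojo API. The names is_lower_required and rows_touched are illustrative only, and the row ranges simply restate the upper/lower split described above.

```python
# Hypothetical Python model of the documented is_lower_required rule.
# It is not the Mojo API; names here are illustrative only.

UPPER_ROWS = range(0, 16)   # upper fragment: rows 0-15 (always present)
LOWER_ROWS = range(16, 32)  # lower fragment: rows 16-31 (only when required)


def is_lower_required(cta_group: int, mma_m: int) -> bool:
    """False only for cta_group=1 with MMA_M=64; True otherwise."""
    return not (cta_group == 1 and mma_m == 64)


def rows_touched(cta_group: int, mma_m: int) -> list:
    """Rows a load or store would access under the upper/lower split."""
    rows = list(UPPER_ROWS)
    if is_lower_required(cta_group, mma_m):
        rows.extend(LOWER_ROWS)
    return rows


for cta_group, mma_m in [(1, 64), (1, 128), (2, 256)]:
    print(cta_group, mma_m,
          is_lower_required(cta_group, mma_m),
          len(rows_touched(cta_group, mma_m)))
```

Only the single configuration cta_group=1, MMA_M=64 skips the lower half; every other combination touches all 32 rows.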
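Example:
# Load both fragments in one call, then wait for the load to complete.
var frags = TmemFragments[DType.float32, 16].load(tmem_addr)
TmemFragments[DType.float32, 16].wait_load()
# Work with the fragments (process is a placeholder for your own code).
frags.upper = process(frags.upper)
frags.lower = process(frags.lower)
# Store both fragments, then wait for the store to complete.
frags.store(tmem_addr)
TmemFragments[DType.float32, 16].wait_store()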
Parameters
- dtype (DType): Fragment data type (typically float32).
- frag_size (Int): Elements per fragment (derived from data_paths and bits).
- is_lower_required (Bool): Whether the lower fragment is needed.
- data_paths (Int): SM100 data paths (typically 16).
- bits (Int): Bits per fragment load (typically 256).
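frag_size is described only as derived from data_paths and bits. One plausible derivation, sketched below in Python, divides the bits a warp moves per repetition (data_paths × bits) by the 32 threads of a warp and by the element width (32 bits for float32). This formula is an assumption, not the library's definition, and frag_size and loaded_size are hypothetical helper names.

```python
# Hypothetical model: per-thread element count for one TMEM access of
# shape data_paths x bits. Assumes the access is issued by one 32-thread
# warp and that every element is elem_bits wide.

WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def frag_size(data_paths: int = 16, bits: int = 256, elem_bits: int = 32) -> int:
    """Elements held by each thread for one repetition of the access."""
    total_bits = data_paths * bits              # bits moved per warp per repetition
    per_thread_bits = total_bits // WARP_SIZE   # spread evenly across the warp
    return per_thread_bits // elem_bits


def loaded_size(frag: int, repeat: int = 1) -> int:
    """Width of each fragment returned by load[repeat] (frag_size * repeat)."""
    return frag * repeat


print(frag_size(), loaded_size(frag_size(), repeat=4))
```

With the defaults (16 data paths, 256 bits, float32) this model gives 16 × 256 / 32 / 32 = 4 elements per thread per repetition.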
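Fields
- upper (SIMD[dtype, frag_size]): Upper fragment (rows 0-15); always loaded and stored.
- lower (SIMD[dtype, frag_size]): Lower fragment (rows 16-31); loaded and stored only when is_lower_required is True.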
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
__init__
__init__() -> Self
Initialize both fragments to zero.
__init__(upper: SIMD[dtype, frag_size], lower: SIMD[dtype, frag_size]) -> Self
Initialize with the provided upper and lower fragments.
load
static load[repeat: Int = 1](tmem: TmemAddress) -> TmemFragments[dtype, (frag_size * repeat), is_lower_required=is_lower_required]
Load fragments from a TMEM address.
Always loads the upper fragment; loads the lower fragment only if is_lower_required is True.
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Args:
- tmem (TmemAddress): TMEM address to load from.
Returns:
TmemFragments: A TmemFragments with frag_size * repeat elements per fragment, containing the upper and (if required) lower data.
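The load contract (upper always, lower only when required, result width frag_size * repeat) can be modeled in a few lines of Python. FakeTmem, Fragments, and load are hypothetical names. TMEM is reduced to one flat list per row half, so the real per-thread layout is not represented. Filling a skipped lower fragment with zeros is also an assumption.

```python
from dataclasses import dataclass


@dataclass
class FakeTmem:
    """Hypothetical stand-in for TMEM: one flat buffer per row half."""
    upper_half: list  # models rows 0-15
    lower_half: list  # models rows 16-31


@dataclass
class Fragments:
    upper: list
    lower: list


def load(tmem: FakeTmem, addr: int, frag_size: int, repeat: int = 1,
         is_lower_required: bool = True) -> Fragments:
    """Model of load: upper always, lower only when required."""
    n = frag_size * repeat                  # result width is frag_size * repeat
    upper = tmem.upper_half[addr:addr + n]  # upper half is always read
    # Assumption: when the lower half is skipped, the lower field keeps the
    # zero value that the default constructor would give it.
    lower = tmem.lower_half[addr:addr + n] if is_lower_required else [0.0] * n
    return Fragments(upper, lower)


tmem = FakeTmem(upper_half=[float(i) for i in range(8)],
                lower_half=[float(100 + i) for i in range(8)])
both = load(tmem, addr=0, frag_size=2, repeat=2)
upper_only = load(tmem, addr=0, frag_size=2, repeat=2, is_lower_required=False)
print(both)
print(upper_only)
```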
store
store[repeat: Int = 1](self, tmem: TmemAddress)
Store fragments to a TMEM address.
Always stores the upper fragment; stores the lower fragment only if is_lower_required is True.
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- tmem (TmemAddress): TMEM address to store to.
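The store side mirrors load. In the hypothetical Python model below, TMEM is again one flat list per row half. The upper fragment is always written, and the lower half of TMEM is left untouched when is_lower_required is False. FakeTmem, Fragments, and store are illustrative names, not the Mojo API.

```python
from dataclasses import dataclass


@dataclass
class FakeTmem:
    """Hypothetical stand-in for TMEM: one flat buffer per row half."""
    upper_half: list  # models rows 0-15
    lower_half: list  # models rows 16-31


@dataclass
class Fragments:
    upper: list
    lower: list


def store(frags: Fragments, tmem: FakeTmem, addr: int,
          is_lower_required: bool = True) -> None:
    """Model of store: upper always written, lower only when required."""
    n = len(frags.upper)
    tmem.upper_half[addr:addr + n] = frags.upper
    if is_lower_required:
        tmem.lower_half[addr:addr + n] = frags.lower


tmem = FakeTmem(upper_half=[0.0] * 4, lower_half=[0.0] * 4)
store(Fragments(upper=[1.0, 2.0], lower=[3.0, 4.0]), tmem, addr=0,
      is_lower_required=False)
print(tmem)  # lower half untouched because the lower fragment was skipped
```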
cast
cast[target_dtype: DType](self) -> TmemFragments[target_dtype, frag_size, is_lower_required=is_lower_required]
Cast fragments to a different dtype.
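Returns:
TmemFragments: A new TmemFragments whose upper and lower fragments are converted to target_dtype, with the same frag_size and is_lower_required.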
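cast converts element types while keeping the fragment shape and the is_lower_required flag. The Python model below illustrates this with a float32-to-float16 style narrowing, using the struct module's half-precision format "e" to round values. Fragments, cast, and to_float16 are hypothetical names. This model converts both fields unconditionally, which is an assumption about the lower fragment.

```python
import struct
from dataclasses import dataclass


def to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision and widen it back to a Python float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


@dataclass(frozen=True)
class Fragments:
    upper: tuple
    lower: tuple
    is_lower_required: bool = True


def cast(frags: Fragments, convert) -> Fragments:
    """Apply convert element-wise to both fragments; keep the lower flag."""
    return Fragments(
        upper=tuple(convert(v) for v in frags.upper),
        lower=tuple(convert(v) for v in frags.lower),
        is_lower_required=frags.is_lower_required,
    )


f32 = Fragments(upper=(1.5, 1 / 3), lower=(2.0, -0.25))
f16 = cast(f32, to_float16)
print(f16)
```

Values that are exactly representable in half precision (1.5, 2.0, -0.25) pass through unchanged, while 1/3 is rounded to the nearest float16 value.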
wait_load
static wait_load()
Wait for TMEM load operations to complete.
wait_store
static wait_store()
Wait for TMEM store operations to complete.
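wait_load and wait_store separate issuing a TMEM operation from relying on its result. The deliberately pessimistic Python model below is hypothetical and does not reflect real hardware timing. It makes every issued operation take effect only at its matching wait, so registers read before wait_load still hold stale values. PessimisticTmem and its methods are illustrative names only.

```python
class PessimisticTmem:
    """Hypothetical model: issued TMEM ops take effect only at the matching wait."""

    def __init__(self, size: int):
        self.memory = [0.0] * size
        self._pending_stores = []
        self._pending_loads = []

    def issue_store(self, addr: int, values) -> None:
        self._pending_stores.append((addr, list(values)))

    def wait_store(self) -> None:
        # All issued stores become visible in memory here.
        for addr, values in self._pending_stores:
            self.memory[addr:addr + len(values)] = values
        self._pending_stores.clear()

    def issue_load(self, addr: int, dest: list) -> None:
        self._pending_loads.append((addr, dest))

    def wait_load(self) -> None:
        # All issued loads fill their destination registers here.
        for addr, dest in self._pending_loads:
            dest[:] = self.memory[addr:addr + len(dest)]
        self._pending_loads.clear()


tmem = PessimisticTmem(size=4)
tmem.issue_store(0, [1.0, 2.0])
tmem.wait_store()                 # store is now visible
regs = [0.0, 0.0]
tmem.issue_load(0, regs)
before = list(regs)               # still stale: the load has not completed
tmem.wait_load()
print(before, regs)
```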