@register_passable(trivial)
struct TmemTensor[dtype: DType, layout: Layout, *, cta_group: Int = 1]
Typed tensor view over Tensor Memory (TMEM) for MMA accumulators.
Provides a LayoutTensor-like abstraction for TMEM with:
- Type safety: dtype and layout known at compile time
- Fragment access: upper (rows 0-15) and lower (rows 16-31)
- MMA integration: offset() returns raw address for MMA operations
The layout parameter captures the logical accumulator shape (M × N), enabling future extensions like custom tiling patterns or multi-tile accumulator management.
Example:
Create typed TMEM view with (64, 128) accumulator layout
comptime layout = Layout.row_major(64, 128)
var tmem = TmemTensor[DType.float32, layout](col_addr)
Use with MMA operations (returns raw UInt32 offset)
mma_op.mma(a_tile, b_tile, tmem.offset(), init_c=True)
Load fragments for epilogue
var upper = tmem.load_upper[repeat=4]()
var lower = tmem.load_lower[repeat=4]()
TmemTensor.wait_load()
Parameters
- dtype (DType): Accumulator data type (typically float32).
- layout (Layout): Logical layout of the accumulator tile (M × N).
- cta_group (Int): CTA cooperation level (1 or 2).
Fields
- col_addr (Int): TMEM column address of the view.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
bits
comptime bits = 256
data_paths
comptime data_paths = 16
frag_size
comptime frag_size = 4
Fragments
comptime Fragments = TmemFragments[dtype, 4, is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]
is_lower_required
comptime is_lower_required = (TmemTensor[dtype, layout, cta_group=cta_group].tile_m == 64) if (eq cta_group._mlir_value, 1) else (cta_group == 1).__bool__().__invert__()
tile_m
comptime tile_m = layout.shape[0].value()
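The values of bits, data_paths, and frag_size listed above are consistent with a simple element count. The sketch below is plain Python arithmetic, not the Mojo API, and it rests on two assumptions that this page does not state: accumulator elements are 32 bits wide (float32), and one instance of the load pattern is shared by a 32-thread warp. The helper names elements_per_thread and simd_width are hypothetical.

```python
# Host-side arithmetic only; this does not call the Mojo API.
BITS = 256          # `bits` as listed on this page
DATA_PATHS = 16     # `data_paths` as listed on this page
ELEMENT_BITS = 32   # assumption: 32-bit accumulator elements (float32)
WARP_SIZE = 32      # assumption: one pattern is shared by a 32-thread warp


def elements_per_thread(bits: int, data_paths: int,
                        element_bits: int, warp_size: int) -> int:
    """Elements each thread holds for one instance of the load pattern."""
    total_elements = data_paths * bits // element_bits
    return total_elements // warp_size


def simd_width(repeat: int, frag_size: int) -> int:
    """Width of the SIMD vector for a given `repeat` (4 * repeat when frag_size is 4)."""
    return frag_size * repeat


frag_size = elements_per_thread(BITS, DATA_PATHS, ELEMENT_BITS, WARP_SIZE)
print(frag_size)                  # 4
print(simd_width(4, frag_size))   # 16
```

Under these assumptions, a call with repeat=4 yields a 16-element vector per thread, which matches the SIMD[dtype, (4 * repeat)] types in the method signatures.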
Methods
__init__
__init__(col_addr: Int) -> Self
Create TMEM tensor view at the given column address.
__init__(addr: TmemAddress) -> Self
Create TMEM tensor view from a TmemAddress.
offset
offset(self) -> UInt32
Get the raw TMEM offset for MMA operations.
address
address(self) -> TmemAddress
Get TmemAddress for low-level fragment operations.
Returns:
TmemAddress: The TmemAddress for this view.
load_upper
load_upper[repeat: Int = 1](self) -> SIMD[dtype, (4 * repeat)]
Load upper accumulator fragment (rows 0-15).
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Returns:
SIMD: SIMD vector containing the upper fragment data.
load_lower
load_lower[repeat: Int = 1](self) -> SIMD[dtype, (4 * repeat)]
Load lower accumulator fragment (rows 16-31).
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Returns:
SIMD: SIMD vector containing the lower fragment data.
store_upper
store_upper[repeat: Int = 1](self, data: SIMD[dtype, (4 * repeat)])
Store upper accumulator fragment (rows 0-15).
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- data (SIMD): SIMD vector containing the data to store.
store_lower
store_lower[repeat: Int = 1](self, data: SIMD[dtype, (4 * repeat)])
Store lower accumulator fragment (rows 16-31).
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- data (SIMD): SIMD vector containing the data to store.
load_fragments
load_fragments[repeat: Int = 1](self) -> TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]
Load both upper and lower fragments in one call.
Handles is_lower_required automatically based on layout.
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Returns:
TmemFragments: TmemFragments containing upper and (conditionally) lower data.
store_fragments
store_fragments[repeat: Int = 1](self, frags: TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required])
Store both upper and lower fragments in one call.
Handles is_lower_required automatically based on layout.
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- frags (TmemFragments): Fragments containing upper and (conditionally) lower data.
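The conditional handling of the lower fragment in load_fragments and store_fragments can be pictured with a small host-side model. This is a plain Python sketch, not the Mojo API: the Fragments class and both functions are hypothetical stand-ins, a 32-row list stands in for a block of TMEM rows, and the real per-thread distribution of elements is not modeled. It only illustrates the upper (rows 0-15) and lower (rows 16-31) split, with the lower half skipped when it is not required.

```python
from dataclasses import dataclass
from typing import List, Optional

Row = List[float]


@dataclass
class Fragments:
    """Hypothetical stand-in for TmemFragments: upper plus optional lower rows."""
    upper: List[Row]
    lower: Optional[List[Row]]


def load_fragments(block: List[Row], is_lower_required: bool) -> Fragments:
    """Copy rows 0-15 as upper; copy rows 16-31 as lower only when required."""
    if len(block) != 32:
        raise ValueError("this model expects a 32-row block")
    upper = [row[:] for row in block[0:16]]
    lower = [row[:] for row in block[16:32]] if is_lower_required else None
    return Fragments(upper, lower)


def store_fragments(block: List[Row], frags: Fragments) -> None:
    """Write upper back to rows 0-15; write lower to rows 16-31 if present."""
    block[0:16] = [row[:] for row in frags.upper]
    if frags.lower is not None:
        block[16:32] = [row[:] for row in frags.lower]


# Each row holds its own row index so the split is easy to inspect.
block = [[float(r)] * 4 for r in range(32)]

both = load_fragments(block, is_lower_required=True)
upper_only = load_fragments(block, is_lower_required=False)

# Scale the loaded values and write them back, as an epilogue might.
both.upper = [[2.0 * x for x in row] for row in both.upper]
both.lower = [[2.0 * x for x in row] for row in both.lower]
store_fragments(block, both)
```

In the model, a caller never branches on the flag; the load and store functions handle it, which is the convenience the combined methods are described as providing.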
wait_load
static wait_load()
Wait for TMEM load operations to complete.
wait_store
static wait_store()
Wait for TMEM store operations to complete.