struct TmemTensor[dtype: DType, layout: Layout, *, cta_group: Int = 1]
Typed tensor view over Tensor Memory (TMEM) for MMA accumulators.
Provides a typed abstraction for TMEM with:
- Type safety: dtype and layout known at compile time
- Fragment access: upper (rows 0-15) and lower (rows 16-31)
- MMA integration: offset() returns raw address for MMA operations
The layout parameter captures the logical accumulator shape (M × N), enabling future extensions such as custom tiling patterns or multi-tile accumulator management.
Example:
# Create typed TMEM view with (64, 128) accumulator layout
comptime layout = Layout.row_major(64, 128)
var tmem = TmemTensor[DType.float32, layout](col_addr)
# Use with MMA operations (returns raw UInt32 offset)
mma_op.mma(a_tile, b_tile, tmem.offset(), init_c=True)
# Load fragments for epilogue
var upper = tmem.load_upper[repeat=4]()
var lower = tmem.load_lower[repeat=4]()
TmemTensor.wait_load()
Parameters
- dtype (DType): Accumulator data type (typically float32).
- layout (Layout): Logical layout of the accumulator tile (M × N).
- cta_group (Int): CTA cooperation level (1 or 2).
Fields
- col_addr (Int): Column address of the view in TMEM.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
bits
comptime bits = 256
data_paths
comptime data_paths = 16
frag_size
comptime frag_size = 4
Fragments
comptime Fragments = TmemFragments[dtype, 4, is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]
is_lower_required
comptime is_lower_required = not (TmemTensor[dtype, layout, cta_group=cta_group].tile_m == 64) if (cta_group == 1) else (cta_group == 1).__bool__()
tile_m
comptime tile_m = layout.shape[0].value()
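The `is_lower_required` expression is printed in compiler-expanded form and is hard to read. The sketch below transcribes it into plain Python, chosen only because the logic is ordinary boolean arithmetic. Both function names are hypothetical and not part of the Mojo API. The check suggests that the expression reduces to `cta_group == 1 and tile_m != 64`.

```python
def is_lower_required_expanded(tile_m: int, cta_group: int = 1) -> bool:
    """Literal transcription of the documented comptime expression."""
    return (not (tile_m == 64)) if (cta_group == 1) else bool(cta_group == 1)


def is_lower_required(tile_m: int, cta_group: int = 1) -> bool:
    """Simplified form: true only when cta_group == 1 and tile_m != 64."""
    return cta_group == 1 and tile_m != 64


# tile_m is the first dimension of the layout, e.g. 64 for Layout.row_major(64, 128).
for tile_m in (64, 128):
    for cta_group in (1, 2):
        expanded = is_lower_required_expanded(tile_m, cta_group)
        simplified = is_lower_required(tile_m, cta_group)
        assert expanded == simplified
```

Under this reading, a layout with 64 rows and the default `cta_group` of 1 gives `is_lower_required == False`, while a 128-row layout with `cta_group` of 1 gives `True`.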
Methods
__init__
__init__(col_addr: Int) -> Self
Create TMEM tensor view at the given column address.
__init__(addr: TmemAddress) -> Self
Create TMEM tensor view from a TmemAddress.
offset
Returns the raw TMEM address (a UInt32 offset) for use in MMA operations.
address
address(self) -> TmemAddress
Get TmemAddress for low-level fragment operations.
Returns:
TmemAddress: The TmemAddress for this tensor view.
load_upper
load_upper[repeat: Int = 1](self) -> InlineArray[Scalar[dtype], (4 * repeat)]
Load upper accumulator fragment (rows 0-15).
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Returns:
InlineArray[Scalar[dtype], (4 * repeat)]: InlineArray containing the upper fragment data.
load_lower
load_lower[repeat: Int = 1](self) -> InlineArray[Scalar[dtype], (4 * repeat)]
Load lower accumulator fragment (rows 16-31).
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Returns:
InlineArray[Scalar[dtype], (4 * repeat)]: InlineArray containing the lower fragment data.
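The return types of `load_upper` and `load_lower` tie the fragment length to `repeat`: each call yields `frag_size * repeat = 4 * repeat` scalars. The Python sketch below models only that arithmetic. `fragment_len` and `elements_per_lane` are hypothetical helpers. The 32-lane warp and 32-bit element width in the second function are assumptions about how `frag_size` might relate to `data_paths` and `bits`; this page does not state that relation.

```python
FRAG_SIZE = 4      # comptime frag_size
DATA_PATHS = 16    # comptime data_paths
BITS = 256         # comptime bits


def fragment_len(repeat: int = 1) -> int:
    """Length of the InlineArray returned by load_upper / load_lower."""
    return FRAG_SIZE * repeat


def elements_per_lane(lanes: int = 32, element_bits: int = 32) -> int:
    """ASSUMPTION: data_paths * bits is split evenly over 32 lanes of 32-bit values."""
    return (DATA_PATHS * BITS) // (lanes * element_bits)


print(fragment_len(4))       # fragment length for repeat=4
print(elements_per_lane())   # per-lane element count under the stated assumptions
```

With `repeat=4`, as in the example at the top of the page, each fragment holds 16 scalars. Under the stated assumptions the per-lane count is 4, which matches `frag_size`.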
store_upper
store_upper[repeat: Int = 1](self, data: InlineArray[Scalar[dtype], (4 * repeat)])
Store upper accumulator fragment (rows 0-15).
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- data (InlineArray[Scalar[dtype], (4 * repeat)]): InlineArray containing the data to store.
store_lower
store_lower[repeat: Int = 1](self, data: InlineArray[Scalar[dtype], (4 * repeat)])
Store lower accumulator fragment (rows 16-31).
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- data (InlineArray[Scalar[dtype], (4 * repeat)]): InlineArray containing the data to store.
load_fragments
load_fragments[repeat: Int = 1](self) -> TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]
Load both upper and lower fragments in one call.
Handles is_lower_required automatically based on layout.
Parameters:
- repeat (Int): Number of times to repeat the load pattern.
Returns:
TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]: TmemFragments containing upper and (conditionally) lower data.
store_fragments
store_fragments[repeat: Int = 1](self, frags: TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required])
Store both upper and lower fragments in one call.
Handles is_lower_required automatically based on layout.
Parameters:
- repeat (Int): Number of times to repeat the store pattern.
Args:
- frags (TmemFragments[dtype, (4 * repeat), is_lower_required=TmemTensor[dtype, layout, cta_group=cta_group].is_lower_required]): TmemFragments containing upper and (conditionally) lower data.
wait_load
static wait_load()
Wait for TMEM load operations to complete.
wait_store
static wait_store()
Wait for TMEM store operations to complete.