
Mojo struct

GroupedTensormapManager

@register_passable(trivial) struct GroupedTensormapManager

Manages tensormap SMEM state and updates for grouped GEMM.

Handles the 4-step CuTe DSL update pattern:

  1. tensormap_fence_acquire() - Acquire fence on block's GMEM tensormap
  2. replace_tensormap_global_address_in_shared_mem() - Update SMEM descriptor
  3. tensormap_cp_fence_release() - Copy SMEM -> block's GMEM tensormap
  4. syncwarp() - Sync before using updated tensormap
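
The ordering of the four steps above can be pictured with a small host-side Python model. This is only a sketch of the sequence, not GPU code: the class `TensormapModel` and its fields are hypothetical, the addresses are made up, and the fence and sync steps are recorded as log entries because plain Python has no equivalent. Only the step names come from the list above.

```python
# Host-side model of the 4-step update order. All names here are hypothetical;
# real fences and warp syncs are represented only as log entries.
from dataclasses import dataclass, field


@dataclass
class TensormapModel:
    smem_addr: int  # global address held in the SMEM descriptor
    gmem_addr: int  # global address held in the block's GMEM descriptor
    log: list = field(default_factory=list)

    def update(self, new_addr: int) -> None:
        # Step 1: acquire fence on the block's GMEM tensormap.
        self.log.append("tensormap_fence_acquire")
        # Step 2: patch the global address in the SMEM descriptor.
        self.smem_addr = new_addr
        self.log.append("replace_tensormap_global_address_in_shared_mem")
        # Step 3: copy SMEM descriptor to the block's GMEM descriptor.
        self.gmem_addr = self.smem_addr
        self.log.append("tensormap_cp_fence_release")
        # Step 4: sync the warp before the updated tensormap is used.
        self.log.append("syncwarp")


desc = TensormapModel(smem_addr=0x1000, gmem_addr=0x1000)
desc.update(0x2000)
```

In this model the GMEM copy only changes in step 3, after the SMEM descriptor has been patched, which mirrors why the SMEM edit sits between the acquire and the release.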

TMA descriptor arrays are passed by reference (as UnsafePointer from TMATensorTileArray[blk]) to methods rather than stored by value. This ensures PTX tensormap operations receive valid GMEM addresses with correct address space semantics.

The manager stores only SMEM descriptor pointers, which are shared across all warps within a CTA.
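
To illustrate the by-reference point, the following Python sketch models one GMEM descriptor slot per block and mutates a slot in place. It is a loose analogy under stated assumptions: `gmem_descriptors`, `update_block_descriptor`, and the addresses are hypothetical stand-ins, and Python object references only approximate what an `UnsafePointer` into GMEM provides.

```python
# Hypothetical model: one GMEM descriptor slot per block (CTA), so blocks
# working on different groups never overwrite each other's descriptor.
num_blocks = 4
template_addr = 0xA000  # address baked into the template descriptor (made up)

# Stand-in for a per-block descriptor array indexed by block index.
gmem_descriptors = [{"global_address": template_addr} for _ in range(num_blocks)]


def update_block_descriptor(block_idx: int, new_addr: int) -> None:
    """Mutate the slot owned by `block_idx` in place (no copy is made)."""
    slot = gmem_descriptors[block_idx]  # reference to the slot, not a copy
    slot["global_address"] = new_addr


update_block_descriptor(0, 0xB000)
update_block_descriptor(2, 0xC000)
```

Because the function works on a reference to the slot, the change is visible in the shared array. A by-value copy would have been updated and then discarded.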

Fields

  • smem (GroupedTensormapSmem): SMEM descriptor pointers, shared across all warps within a CTA.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterType, TrivialRegisterType

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

init_ab_tensormaps

init_ab_tensormaps[a_dtype: DType, a_layout: Layout, a_desc: Layout, b_dtype: DType, b_layout: Layout, b_desc: Layout, sfa_dtype: DType, sfa_layout: Layout, sfa_desc: Layout, sfb_dtype: DType, sfb_layout: Layout, sfb_desc: Layout](self, template_a: TMATensorTile[a_dtype, a_layout, a_desc], template_b: TMATensorTile[b_dtype, b_layout, b_desc], template_sfa: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc], template_sfb: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc])

Initialize A/B/SFA/SFB tensormaps in SMEM from grid-constant templates.

Called by lane 0 of the MMA warp. Copies the template descriptors to SMEM. The templates must be kernel parameters with nvvm.grid_constant metadata.

init_c_tensormap

init_c_tensormap[c_dtype: DType, c_layout: Layout, c_desc: Layout](self, template_c: TMATensorTile[c_dtype, c_layout, c_desc])

Initialize C tensormap in SMEM from grid-constant template.

Called by lane 0 of the epilogue warp. Copies the template descriptor to SMEM.

update_ab_for_group

update_ab_for_group[a_dtype: DType, a_layout: Layout, a_desc: Layout, b_dtype: DType, b_layout: Layout, b_desc: Layout, sfa_dtype: DType, sfa_layout: Layout, sfa_desc: Layout, sfb_dtype: DType, sfb_layout: Layout, sfb_desc: Layout, max_groups: Int](self, group_idx: UInt32, group_a_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_b_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_sfa_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_sfb_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], tma_a: UnsafePointer[TMATensorTile[a_dtype, a_layout, a_desc], MutAnyOrigin], tma_b: UnsafePointer[TMATensorTile[b_dtype, b_layout, b_desc], MutAnyOrigin], tma_sfa: UnsafePointer[TMATensorTile[sfa_dtype, sfa_layout, sfa_desc], MutAnyOrigin], tma_sfb: UnsafePointer[TMATensorTile[sfb_dtype, sfb_layout, sfb_desc], MutAnyOrigin])

Update A/B/SFA/SFB tensormaps for the specified group.

Called in the TMA load warp when group_changed=True. The TMA pointers must come from TMATensorTileArray[block_idx.x] (GMEM).
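
A minimal Python sketch of the group-change condition follows. It assumes a load loop that walks tiles in order and tracks the previous tile's group. The loop, `tile_groups`, and the addresses are hypothetical; the table lookup merely stands in for the real update_ab_for_group call.

```python
# Hypothetical model of the load-warp loop: the update runs only when the
# group index of the current tile differs from the previous tile's group.
group_a_ptrs = [0x1000, 0x2000, 0x3000]  # one base address per group (made up)
tile_groups = [0, 0, 1, 1, 1, 2]         # group index of each tile for this block

current_addr = None  # address the descriptor currently points at
prev_group = None    # group of the previously processed tile
updates = []         # group indices for which an update was issued

for group_idx in tile_groups:
    group_changed = group_idx != prev_group
    if group_changed:
        # Stand-in for update_ab_for_group: look up the group's base address.
        current_addr = group_a_ptrs[group_idx]
        updates.append(group_idx)
        prev_group = group_idx
    # A TMA load for this tile would use current_addr at this point.
```

With six tiles spanning three groups, the model issues three updates rather than six, which is the saving the group_changed check is there to provide.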

update_c_for_group

update_c_for_group[c_dtype: DType, c_layout: Layout, c_desc: Layout, max_groups: Int](self, group_idx: UInt32, group_c_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], tma_c: UnsafePointer[TMATensorTile[c_dtype, c_layout, c_desc], MutAnyOrigin])

Update C tensormap for the specified group.

Called in the epilogue warp when group_changed=True. The TMA pointer must come from TMATensorTileArray[block_idx.x] (GMEM).
