Mojo struct
GroupedTensormapManager
@register_passable(trivial)
struct GroupedTensormapManager
Manages tensormap SMEM state and updates for grouped GEMM.
Handles the 4-step CuTe DSL update pattern:
- tensormap_fence_acquire() - Acquire fence on block's GMEM tensormap
- replace_tensormap_global_address_in_shared_mem() - Update SMEM descriptor
- tensormap_cp_fence_release() - Copy SMEM -> block's GMEM tensormap
- syncwarp() - Sync before using updated tensormap
TMA descriptor arrays are passed by reference (as UnsafePointer from TMATensorTileArray[blk]) to methods rather than stored by value. This ensures PTX tensormap operations receive valid GMEM addresses with correct address space semantics.
The manager stores only SMEM descriptor pointers, which are shared across all warps within a CTA.
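The update sequence above can be modelled on the host to make the data flow explicit. The following is a minimal Python sketch, not the Mojo implementation: `Descriptor`, `update_for_group`, and the trace strings are hypothetical names introduced for illustration, and the fences and warp sync are represented only as log entries because they constrain memory ordering rather than move data.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Descriptor:
    """Toy stand-in for a TMA descriptor (hypothetical, not the real layout)."""
    global_address: int
    tile_shape: tuple


def update_for_group(smem_desc, gmem_descs, block_idx, group_ptrs, group_idx):
    """Model the 4-step update for one block; returns (new SMEM copy, trace)."""
    trace = []
    # Step 1: acquire fence on this block's GMEM descriptor (ordering only).
    trace.append("fence_acquire")
    # Step 2: patch only the base address in the SMEM working copy.
    smem_desc = replace(smem_desc, global_address=group_ptrs[group_idx])
    trace.append("replace_address_in_smem")
    # Step 3: publish the SMEM copy to this block's own GMEM slot.
    gmem_descs[block_idx] = smem_desc
    trace.append("cp_fence_release")
    # Step 4: warp-level sync before any lane uses the updated descriptor.
    trace.append("syncwarp")
    return smem_desc, trace


# Usage: two blocks start from the same template; block 1 switches to group 2.
template = Descriptor(global_address=0x1000, tile_shape=(128, 64))
gmem_descs = [template, template]          # one GMEM slot per block
group_ptrs = [0x1000, 0x2000, 0x3000]      # per-group base addresses
smem_desc, trace = update_for_group(template, gmem_descs, 1, group_ptrs, 2)
```

In this model only the base address changes between groups, and each block writes only its own slot, which mirrors why the methods take a pointer into the per-block descriptor array instead of a descriptor stored by value.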
Fields
- smem (GroupedTensormapSmem):
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterType, TrivialRegisterType
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
init_ab_tensormaps
init_ab_tensormaps[a_dtype: DType, a_layout: Layout, a_desc: Layout, b_dtype: DType, b_layout: Layout, b_desc: Layout, sfa_dtype: DType, sfa_layout: Layout, sfa_desc: Layout, sfb_dtype: DType, sfb_layout: Layout, sfb_desc: Layout](self, template_a: TMATensorTile[a_dtype, a_layout, a_desc], template_b: TMATensorTile[b_dtype, b_layout, b_desc], template_sfa: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc], template_sfb: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc])
Initialize A/B/SFA/SFB tensormaps in SMEM from grid-constant templates.
Called by lane 0 of the MMA warp. Copies the template descriptors to SMEM. The templates must be kernel parameters with nvvm.grid_constant metadata.
init_c_tensormap
init_c_tensormap[c_dtype: DType, c_layout: Layout, c_desc: Layout](self, template_c: TMATensorTile[c_dtype, c_layout, c_desc])
Initialize C tensormap in SMEM from grid-constant template.
Called by lane 0 of the epilogue warp. Copies the template descriptor to SMEM.
update_ab_for_group
update_ab_for_group[a_dtype: DType, a_layout: Layout, a_desc: Layout, b_dtype: DType, b_layout: Layout, b_desc: Layout, sfa_dtype: DType, sfa_layout: Layout, sfa_desc: Layout, sfb_dtype: DType, sfb_layout: Layout, sfb_desc: Layout, max_groups: Int](self, group_idx: UInt32, group_a_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_b_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_sfa_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], group_sfb_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], tma_a: UnsafePointer[TMATensorTile[a_dtype, a_layout, a_desc], MutAnyOrigin], tma_b: UnsafePointer[TMATensorTile[b_dtype, b_layout, b_desc], MutAnyOrigin], tma_sfa: UnsafePointer[TMATensorTile[sfa_dtype, sfa_layout, sfa_desc], MutAnyOrigin], tma_sfb: UnsafePointer[TMATensorTile[sfb_dtype, sfb_layout, sfb_desc], MutAnyOrigin])
Update A/B/SFA/SFB tensormaps for the specified group.
Called in the TMA load warp when group_changed=True. The TMA pointers must come from TMATensorTileArray[block_idx.x] (GMEM).
update_c_for_group
update_c_for_group[c_dtype: DType, c_layout: Layout, c_desc: Layout, max_groups: Int](self, group_idx: UInt32, group_c_ptrs: LayoutTensor[DType.uint64, Layout.row_major(max_groups, 1), MutAnyOrigin], tma_c: UnsafePointer[TMATensorTile[c_dtype, c_layout, c_desc], MutAnyOrigin])
Update C tensormap for the specified group.
Called in the epilogue warp when group_changed=True. The TMA pointer must come from TMATensorTileArray[block_idx.x] (GMEM).
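Both update methods are gated on a group_changed condition. The sketch below shows one plausible way such a flag could be derived while a warp walks its tile schedule. It is a host-side Python model under stated assumptions: `plan_updates` and `tile_groups` are hypothetical names, and treating the first tile as a group change is an assumption of this sketch, not something the page states.

```python
def plan_updates(tile_groups):
    """Return (tile_idx, group_idx) pairs where a descriptor update would run.

    tile_groups[i] is the group that the i-th scheduled tile belongs to.
    """
    updates = []
    last_group = None  # assumption: no group is loaded before the first tile
    for tile_idx, group_idx in enumerate(tile_groups):
        group_changed = group_idx != last_group
        if group_changed:
            # This is where update_ab_for_group / update_c_for_group would run.
            updates.append((tile_idx, group_idx))
            last_group = group_idx
    return updates


# Six tiles spanning three groups need only three descriptor updates.
schedule = plan_updates([0, 0, 1, 1, 1, 2])
```

Consecutive tiles from the same group reuse the descriptor already published to the block's GMEM slot, so the cost of the update is paid once per group transition and not once per tile.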