Mojo struct

GroupedCLCWorkIterator

struct GroupedCLCWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]

Per-warp work iterator for grouped GEMM with CLC barrier support.

This iterator combines grouped GEMM scheduling with CLC-based synchronization for 2SM support. It uses CLC barriers to ensure that both CTAs in a cluster process the same tile at the same time.

Usage:

```mojo
var work_iter = scheduler.clc_work_iterator()
for current in work_iter:
    if current.group_changed:
        update_tensormaps(current.group_idx)
    process_tile(current)
```

Fields

  • work_info (GroupedWorkInfo): Current work item.
  • consumer_state (PipelineState[num_clc_stages]): CLC consumer pipeline state.
  • throttle_pipeline (GroupedCLCWorkIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline): Throttle pipeline for load/scheduler sync.
  • full_mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): CLC full barriers (signaled by scheduler when work is ready).
  • empty_mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): CLC empty barriers (signaled by workers when done).
  • clc_response (UnsafePointer[UInt128, MutAnyOrigin, address_space=AddressSpace.SHARED]): CLC response storage (contains work info).
  • cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group.
  • problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
  • problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
  • problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
  • num_groups (UInt32): Number of active groups.
  • total_tiles (UInt32): Total tiles across all groups.
  • use_clc_fetch (Bool): If True, next waits on CLC barriers (for MMA warp).
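
The full/empty barrier pair above implements a producer-consumer handshake between the scheduler warp and the worker warps. The following is a simplified pseudocode sketch of that protocol; the stage indexing and helper names are illustrative, not the actual implementation:

```mojo
# Pseudocode sketch of the CLC barrier handshake (illustrative only).
#
# Scheduler warp (producer), for each pipeline stage:
#   wait(empty_mbar[stage])          # workers have released this stage
#   clc_response[stage] = next_work  # publish the next work item
#   arrive(full_mbar[stage])         # signal that work is ready
#
# Worker warp (consumer, use_clc_fetch=True), for each pipeline stage:
#   wait(full_mbar[stage])           # block until work is published
#   work = clc_response[stage]       # decode GroupedWorkInfo from response
#   arrive(empty_mbar[stage])        # release the stage for reuse
```

Cycling through `num_clc_stages` stages lets the scheduler publish the next work item while workers are still draining the previous one.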

Implemented traits

AnyType, Copyable, ImplicitlyDestructible, Iterable, Iterator, Movable, RegisterPassable

comptime members

Element

comptime Element = GroupedWorkInfo

IteratorType

comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = GroupedCLCWorkIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group]

Parameters

  • iterable_mut (Bool): The inferred mutability of the iterable.
  • iterable_origin (Origin): The origin of the iterable.

ThrottlePipeline

comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]

Methods

__init__

__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int, full_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], empty_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], clc_response: UnsafePointer[UInt128, MutAnyOrigin, address_space=AddressSpace.SHARED], throttle_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo, use_clc_fetch: Bool = False) -> Self

Initialize CLC work iterator.

Args:

  • problem_sizes (TileTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
  • num_groups (Int): Number of active groups.
  • full_mbar (UnsafePointer): CLC full barrier pointer.
  • empty_mbar (UnsafePointer): CLC empty barrier pointer.
  • clc_response (UnsafePointer): CLC response storage pointer.
  • throttle_ptr (UnsafePointer): Throttle pipeline barrier pointer.
  • initial_work (GroupedWorkInfo): Initial work item (first tile).
  • use_clc_fetch (Bool): If True, next waits on CLC barriers (for MMA warp).
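
A minimal construction sketch, assuming the shared-memory barrier pointers and the first work item have already been prepared elsewhere in the kernel (all `smem_*` names and the parameter values below are hypothetical):

```mojo
# Hypothetical kernel-side construction; the smem_* pointers are assumed
# to have been carved out of shared memory earlier in the kernel.
var work_iter = GroupedCLCWorkIterator[
    tile_m=128, tile_n=128, tile_k=64,
    max_groups=16, num_clc_stages=2,
](
    problem_sizes,       # (num_groups, 4) TileTensor of [M, N, K, L]
    num_groups,
    smem_full_mbar,      # CLC full barriers
    smem_empty_mbar,     # CLC empty barriers
    smem_clc_response,   # CLC response storage
    smem_throttle,       # throttle pipeline barriers
    initial_work,        # GroupedWorkInfo for the first tile
    use_clc_fetch=True,  # MMA warp: wait on CLC barriers in __next__
)
```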

__iter__

__iter__(ref self) -> Self

__next__

__next__(mut self) -> GroupedWorkInfo

Return current work item and advance.

When use_clc_fetch is True (MMA warp), waits on CLC barriers for synchronization. Otherwise uses simple linear advance.

Returns:

GroupedWorkInfo: The current work item.

Raises:

StopIteration: When there is no more work to process.
