Mojo struct
GroupedCLCWorkIterator
@register_passable(trivial)
struct GroupedCLCWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]
Per-warp work iterator for grouped GEMM with CLC barrier support.
This iterator combines grouped GEMM features with CLC-based synchronization for 2SM support. It uses CLC barriers to ensure both CTAs in a cluster process the same tile at the same time.
Key features:
- Uses CLC barriers for inter-CTA synchronization (like the working kernel)
- Tracks group_idx, k_tile_count, and group_changed (like the grouped scheduler)
- wait_and_advance() actually waits on CLC barriers
Usage:

```mojo
var work_iter = scheduler.clc_work_iterator()
while work_iter.has_work():
    with work_iter.wait_and_advance() as current:
        if current.group_changed:
            update_tensormaps(current.group_idx)
        process_tile(current)
```
Fields
- work_info (GroupedWorkInfo): Current work item.
- consumer_state (PipelineState[num_clc_stages]): CLC consumer pipeline state.
- throttle_pipeline (GroupedCLCWorkIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline): Throttle pipeline for load/scheduler sync.
- full_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]): CLC full barriers (signaled by scheduler when work is ready).
- empty_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]): CLC empty barriers (signaled by workers when done).
- clc_response (LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]): CLC response storage (contains work info).
- cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group.
- problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
- problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
- problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
- num_groups (UInt32): Number of active groups.
- total_tiles (UInt32): Total tiles across all groups.
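The cumulative_tiles field stores a prefix-sum table: entry g holds the number of tiles that precede group g, so a linear tile index can be mapped back to a (group, local tile) pair. The following Python sketch models that lookup; the prefix-sum layout and the linear scan are assumptions drawn from the field descriptions above, not the Mojo implementation.

```python
# Illustrative model of a cumulative-tile table like `cumulative_tiles`.
# The function names here are hypothetical.

def build_cumulative_tiles(tiles_per_group):
    """Prefix sums: entry g is the tile count before group g."""
    cum = [0]
    for t in tiles_per_group:
        cum.append(cum[-1] + t)
    return cum

def locate_tile(cum, linear_idx):
    """Map a linear tile index to (group_idx, tile within group)."""
    for g in range(len(cum) - 1):
        if linear_idx < cum[g + 1]:
            return g, linear_idx - cum[g]
    raise IndexError("linear_idx beyond total_tiles")

cum = build_cumulative_tiles([4, 2, 3])  # 9 tiles total
print(locate_tile(cum, 0))   # (0, 0)
print(locate_tile(cum, 5))   # (1, 1)
print(locate_tile(cum, 8))   # (2, 2)
```

A group change is detected when the mapped group_idx differs from the previous work item's, which is what the group_changed flag reports.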
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterType,
TrivialRegisterType
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ThrottlePipeline
comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]
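A multi-stage pipeline state such as PipelineState[num_clc_stages] typically tracks a stage index plus a phase bit that flips each time the index wraps, so barrier waits can tell a fresh arrival from a stale one. This Python sketch shows that bookkeeping; the class and method names are illustrative, not the Mojo API.

```python
# Hypothetical model of multi-stage pipeline index/phase tracking.

class PipelineState:
    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.index = 0  # which barrier slot to use next
        self.phase = 0  # expected parity bit for that slot

    def step(self):
        """Advance to the next stage; flip phase on wrap-around."""
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1

s = PipelineState(3)
for _ in range(3):
    s.step()
print(s.index, s.phase)  # 0 1: one full trip flips the phase
```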
Methods
__init__
__init__(problem_sizes: LayoutTensor[DType.int32, Layout.row_major(max_groups, 4), MutAnyOrigin], num_groups: Int, full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_response: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo) -> Self
Initialize CLC work iterator.
Args:
- problem_sizes (LayoutTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
- num_groups (Int): Number of active groups.
- full_mbar (LegacyUnsafePointer): CLC full barrier pointer.
- empty_mbar (LegacyUnsafePointer): CLC empty barrier pointer.
- clc_response (LegacyUnsafePointer): CLC response storage pointer.
- throttle_ptr (LegacyUnsafePointer): Throttle pipeline barrier pointer.
- initial_work (GroupedWorkInfo): Initial work item (first tile).
has_work
Returns whether the iterator still has a work item to process (used as the loop condition in the usage example above).
wait_and_advance
wait_and_advance[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedCLCWaitAndAdvanceContext[origin_of(state_origin._mlir_origin.work_info)]
Wait for next work from CLC and advance iterator.
This method waits on CLC full barriers to synchronize all CTAs in the cluster before advancing to the next work item.
Usage:

```mojo
with work_iter.wait_and_advance() as current:
    # Process current work item
# After exit, work_iter points to next work
```
Returns:
GroupedCLCWaitAndAdvanceContext
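The wait-then-yield-then-advance shape of this context manager can be sketched in Python as follows. The barrier waits are replaced by comments, and the WorkInfo/WorkIterator names are illustrative stand-ins, not the Mojo types.

```python
# Hedged analogue of the wait_and_advance() context-manager pattern.
from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class WorkInfo:
    group_idx: int
    tile: int
    group_changed: bool

class WorkIterator:
    def __init__(self, schedule):
        self.schedule = list(schedule)  # stands in for CLC responses
        self.pos = 0

    def has_work(self):
        return self.pos < len(self.schedule)

    @contextmanager
    def wait_and_advance(self):
        # Real code would wait on the CLC full barrier here.
        current = self.schedule[self.pos]
        try:
            yield current
        finally:
            # Real code would arrive on the CLC empty barrier here,
            # then advance to the next work item.
            self.pos += 1

it = WorkIterator([WorkInfo(0, 0, True), WorkInfo(0, 1, False)])
seen = []
while it.has_work():
    with it.wait_and_advance() as cur:
        seen.append((cur.group_idx, cur.tile, cur.group_changed))
print(seen)  # [(0, 0, True), (0, 1, False)]
```

Advancing in the `finally` clause mirrors the documented behavior: after the `with` block exits, the iterator points at the next work item.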
next
next[state_origin: MutOrigin, //](ref[state_origin] self) -> GroupedAdvanceContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.total_tiles)]
Get context manager for advance-after-work pattern.
Does NOT wait on CLC barriers; use wait_and_advance() for the MMA warp.
Returns:
GroupedAdvanceContext
throttle_signal
throttle_signal(mut self, is_first_cta_in_cluster: Bool)
Signal CLC throttle if this is the first CTA in cluster.
NOTE: For software CLC simulation, this is a no-op. The throttle pattern causes a deadlock because both Scheduler and TMA Load wait on each other's barriers on the first iteration. The CLC full/empty barriers provide sufficient synchronization without the throttle.
Args:
- is_first_cta_in_cluster (
Bool): Only first CTA signals to avoid duplicates.
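The circular-wait hazard described in the note above can be illustrated in Python: if each party's first action is to wait for a signal the other party only sends after its own wait completes, neither side ever proceeds. The scheduler/loader names below are illustrative, not the kernel's symbols, and timeouts stand in for waits that would otherwise hang forever.

```python
# Hedged sketch of the first-iteration deadlock in the throttle pattern.
from threading import Event

scheduler_ready = Event()  # would be signaled by the scheduler
loader_ready = Event()     # would be signaled by the TMA load

def first_iteration_blocks():
    # Scheduler's first step: wait for the loader. Loader's first
    # step: wait for the scheduler. With neither event pre-set, both
    # waits time out, i.e. the pattern deadlocks without an initial
    # signal from one side.
    return (not loader_ready.wait(timeout=0.01)
            and not scheduler_ready.wait(timeout=0.01))

print(first_iteration_blocks())  # True: both sides would block
```

This is why the full/empty CLC barriers alone, which are signaled before the opposite side waits, are sufficient here.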