Mojo struct
GroupedCLCSchedulerIterator
@register_passable(trivial)
struct GroupedCLCSchedulerIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]
Scheduler warp iterator for grouped GEMM with CLC.
The scheduler warp produces work items for other warps via CLC. It iterates through all tiles across all groups and signals CLC barriers.
Usage:

    var sched_iter = scheduler.scheduler_iterator()
    while sched_iter.has_work():
        with sched_iter.next():
            sched_iter.signal_and_advance()
    sched_iter.drain()
Fields
- work_info (GroupedWorkInfo): Current work item.
- linear_tile_idx (UInt32): Current linear tile index.
- consumer_state (PipelineState[num_clc_stages]):
- producer_state (PipelineState[num_clc_stages]):
- throttle_pipeline (GroupedCLCSchedulerIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline):
- full_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]):
- empty_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]):
- clc_response (LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]):
- cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]):
- problem_m (StaticTuple[UInt32, max_groups]):
- problem_n (StaticTuple[UInt32, max_groups]):
- problem_k (StaticTuple[UInt32, max_groups]):
- num_groups (UInt32):
- total_tiles (UInt32):
- signal_count (UInt32): Number of signals sent (for pipeline fill tracking).
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ThrottlePipeline
comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]
Methods
__init__
__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int, full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_response: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo) -> Self
Initialize scheduler iterator.
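The constructor takes per-group problem sizes and the iterator tracks a cumulative_tiles prefix sum plus a linear_tile_idx over all groups' tiles. A minimal Python sketch of that bookkeeping, assuming the obvious mapping (function names and the ceiling-division tiling are illustrative, not the actual Mojo implementation):

```python
def build_cumulative_tiles(problem_sizes, tile_m, tile_n):
    """Prefix sum of per-group tile counts, modeling the cumulative_tiles
    field. problem_sizes is a list of (M, N, K) tuples, one per group."""
    cumulative = [0]
    for m, n, _k in problem_sizes:
        # Ceiling division: number of output tiles covering an M x N problem.
        tiles = ((m + tile_m - 1) // tile_m) * ((n + tile_n - 1) // tile_n)
        cumulative.append(cumulative[-1] + tiles)
    return cumulative  # cumulative[-1] plays the role of total_tiles


def tile_for_linear_idx(linear_idx, cumulative, problem_sizes, tile_m, tile_n):
    """Map a linear tile index to (group, m_tile, n_tile), modeling how a
    linear_tile_idx would select a group and a tile within it."""
    group = 0
    while cumulative[group + 1] <= linear_idx:
        group += 1
    local = linear_idx - cumulative[group]
    n_tiles = (problem_sizes[group][1] + tile_n - 1) // tile_n
    return group, local // n_tiles, local % n_tiles
```

For example, with two groups of sizes (256, 256, 64) and (128, 512, 64) and 128x128 tiles, each group contributes 4 tiles, so the prefix sum is [0, 4, 8] and linear index 5 falls in group 1.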
has_work
next
next[state_origin: MutOrigin, //](ref[tile_m] self) -> GroupedAdvanceContext[origin_of(state_origin.work_info), origin_of(state_origin.linear_tile_idx)]
(Note: the signature above is reconstructed from a garbled rendering; the self reference is taken by origin-parameterized ref and the return-type origins refer to self's fields.)
Get context manager for advance-after-work pattern.
Returns:
GroupedAdvanceContext
signal_and_advance
signal_and_advance(mut self)
Signal CLC throttle and produce next work request.
This is called inside the work loop after processing current work. It signals that we've consumed the throttle and produces the next work item for all CTAs.
NOTE: We skip the throttle_pipeline.consumer_signal_and_step() call that the hardware CLC version uses. For software CLC simulation, the clc_full/clc_empty barriers provide sufficient synchronization. The throttle pattern causes a deadlock because:
- Scheduler waits for TMA Load via throttle full barrier
- TMA Load waits for Scheduler via throttle empty barrier
- Both block on first iteration since barriers start at phase 0
drain
drain(mut self)
Drain all pending CLC requests before kernel exit.
Only waits for slots that were actually signaled to avoid deadlock when workload is smaller than pipeline depth.
Note: After signaling, producer_state has stepped to the NEXT stage. We need to wait on stages 0..slots_to_drain-1, not from producer_state.
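The drain rule described above reduces to a small computation, sketched here in Python (a hypothetical helper, not the Mojo code): cap the signaled count at the pipeline depth and wait on stages counted from 0, rather than from the already-stepped producer_state.

```python
def stages_to_drain(signal_count, num_clc_stages):
    """Stages drain() must wait on: only slots that were actually
    signaled, from stage 0 up, so a workload smaller than the pipeline
    depth never waits on a barrier nobody will arrive on."""
    slots_to_drain = min(signal_count, num_clc_stages)
    return list(range(slots_to_drain))
```

For instance, with a 4-stage pipeline and only 2 signals sent, drain waits on stages 0 and 1; once signal_count reaches or exceeds the depth, every stage is drained.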