Mojo struct

GroupedCLCSchedulerIterator

@register_passable(trivial) struct GroupedCLCSchedulerIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]

Scheduler warp iterator for grouped GEMM with CLC.

The scheduler warp produces work items for other warps via CLC. It iterates through all tiles across all groups and signals CLC barriers.

Usage:

```mojo
var sched_iter = scheduler.scheduler_iterator()
while sched_iter.has_work():
    with sched_iter.next():
        sched_iter.signal_and_advance()
sched_iter.drain()
```

Fields

  • work_info (GroupedWorkInfo): Current work item.
  • linear_tile_idx (UInt32): Current linear tile index.
  • consumer_state (PipelineState[num_clc_stages]): Consumer-side pipeline state across the CLC stages.
  • producer_state (PipelineState[num_clc_stages]): Producer-side pipeline state across the CLC stages.
  • throttle_pipeline (GroupedCLCSchedulerIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline): Producer/consumer pipeline used to throttle CLC requests.
  • full_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]): Shared-memory barriers signaled when a CLC response slot is full.
  • empty_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]): Shared-memory barriers signaled when a CLC response slot is empty.
  • clc_response (LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]): Shared-memory buffer holding CLC responses.
  • cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile counts per group (prefix sums), sized max_groups + 1.
  • problem_m (StaticTuple[UInt32, max_groups]): M dimension of each group's problem.
  • problem_n (StaticTuple[UInt32, max_groups]): N dimension of each group's problem.
  • problem_k (StaticTuple[UInt32, max_groups]): K dimension of each group's problem.
  • num_groups (UInt32): Number of active groups.
  • total_tiles (UInt32): Total number of tiles across all groups.
  • signal_count (UInt32): Number of signals sent (for pipeline fill tracking).
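As a hedged illustration of how these fields relate: if cumulative_tiles holds prefix sums of per-group tile counts, a linear tile index can be mapped to its group with a simple scan. The function and its signature below are hypothetical, not the struct's actual internals.

```mojo
# Hypothetical sketch: map a linear tile index to its group, assuming
# cumulative_tiles[g] holds the number of tiles in groups 0..g-1
# (so cumulative_tiles[num_groups] == total_tiles).
fn group_for_tile[
    max_groups: Int
](
    cumulative_tiles: StaticTuple[UInt32, max_groups + 1],
    num_groups: UInt32,
    linear_tile_idx: UInt32,
) -> UInt32:
    var group: UInt32 = 0
    # Advance while the next group's starting offset is still <= idx.
    while group + 1 < num_groups and cumulative_tiles[Int(group + 1)] <= linear_tile_idx:
        group += 1
    return group
```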

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

ThrottlePipeline

comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]

Methods

__init__

__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int, full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_response: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo) -> Self

Initialize scheduler iterator.

has_work

has_work(self) -> Bool

Check if there is more work to process.

Returns:

Bool: True if more work items remain across all groups.

next

next[state_origin: MutOrigin, //](ref [state_origin] self) -> GroupedAdvanceContext[origin_of(self.work_info), origin_of(self.linear_tile_idx)]

Get context manager for advance-after-work pattern.

Returns:

GroupedAdvanceContext

signal_and_advance

signal_and_advance(mut self)

Signal CLC throttle and produce next work request.

This is called inside the work loop after processing current work. It signals that we've consumed the throttle and produces the next work item for all CTAs.

NOTE: We skip the throttle_pipeline.consumer_signal_and_step() call that the hardware CLC version uses. For software CLC simulation, the clc_full/clc_empty barriers provide sufficient synchronization. The throttle pattern causes a deadlock because:

  • Scheduler waits for TMA Load via throttle full barrier
  • TMA Load waits for Scheduler via throttle empty barrier
  • Both block on first iteration since barriers start at phase 0
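The software-CLC producer step described above can be sketched roughly as follows. This is an illustrative outline, not the struct's actual implementation; the barrier and pipeline-state method names are assumptions.

```mojo
# Illustrative sketch only: produce the next work request without the
# hardware-CLC throttle consumer step described above.
fn signal_and_advance_sketch(mut self):
    # Arrive on the "full" barrier for the current producer slot so
    # consumer warps can read the next CLC response.
    self.full_mbar[self.producer_state.index()].arrive()
    # Track how many slots were signaled, for drain() later.
    self.signal_count += 1
    # Step the producer pipeline state to the next CLC stage.
    self.producer_state.step()
    # Deliberately no throttle_pipeline.consumer_signal_and_step():
    # the software-CLC full/empty barriers already synchronize.
```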

drain

drain(mut self)

Drain all pending CLC requests before kernel exit.

Only waits for slots that were actually signaled to avoid deadlock when workload is smaller than pipeline depth.

Note: After signaling, producer_state has stepped to the NEXT stage. We need to wait on stages 0..slots_to_drain-1, not from producer_state.
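The drain logic described above can be sketched as follows. This is a hedged illustration under the note's assumptions; the actual barrier-wait API and phase handling may differ.

```mojo
# Illustrative sketch: wait only on slots that were actually signaled,
# so a workload smaller than the pipeline depth cannot deadlock.
fn drain_sketch(mut self):
    # Never drain more slots than the pipeline has stages.
    var slots_to_drain = min(self.signal_count, UInt32(num_clc_stages))
    # producer_state has already stepped past the last signaled slot,
    # so wait on stages 0..slots_to_drain-1 directly rather than
    # starting from producer_state.
    for stage in range(Int(slots_to_drain)):
        self.empty_mbar[stage].wait()
```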