Mojo struct
GroupedCLCWorkIterator
struct GroupedCLCWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_clc_stages: Int, cta_group: Int = 2]
Per-warp work iterator for grouped GEMM with CLC barrier support.
This iterator combines grouped GEMM features with CLC-based synchronization for 2SM support. It uses CLC barriers to ensure both CTAs in a cluster process the same tile at the same time.
Usage:

```mojo
var work_iter = scheduler.clc_work_iterator()
for current in work_iter:
    if current.group_changed:
        update_tensormaps(current.group_idx)
    process_tile(current)
```
Fields
- work_info (GroupedWorkInfo): Current work item.
- consumer_state (PipelineState[num_clc_stages]): CLC consumer pipeline state.
- throttle_pipeline (GroupedCLCWorkIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group].ThrottlePipeline): Throttle pipeline for load/scheduler sync.
- full_mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): CLC full barriers (signaled by the scheduler when work is ready).
- empty_mbar (UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]): CLC empty barriers (signaled by workers when done).
- clc_response (UnsafePointer[UInt128, MutAnyOrigin, address_space=AddressSpace.SHARED]): CLC response storage (contains work info).
- cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group.
- problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
- problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
- problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
- num_groups (UInt32): Number of active groups.
- total_tiles (UInt32): Total tiles across all groups.
- use_clc_fetch (Bool): If True, __next__ waits on CLC barriers (for the MMA warp).
Implemented traits
AnyType,
Copyable,
ImplicitlyDestructible,
Iterable,
Iterator,
Movable,
RegisterPassable
comptime members
Element
comptime Element = GroupedWorkInfo
IteratorType
comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = GroupedCLCWorkIterator[tile_m, tile_n, tile_k, max_groups, num_clc_stages, cta_group]
Parameters
ThrottlePipeline
comptime ThrottlePipeline = ProducerConsumerPipeline[num_clc_stages]
Methods
__init__
__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int, full_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], empty_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], clc_response: UnsafePointer[UInt128, MutAnyOrigin, address_space=AddressSpace.SHARED], throttle_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], initial_work: GroupedWorkInfo, use_clc_fetch: Bool = False) -> Self
Initialize CLC work iterator.
Args:
- problem_sizes (TileTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
- num_groups (Int): Number of active groups.
- full_mbar (UnsafePointer): CLC full barrier pointer.
- empty_mbar (UnsafePointer): CLC empty barrier pointer.
- clc_response (UnsafePointer): CLC response storage pointer.
- throttle_ptr (UnsafePointer): Throttle pipeline barrier pointer.
- initial_work (GroupedWorkInfo): Initial work item (first tile).
- use_clc_fetch (Bool): If True, __next__ waits on CLC barriers (for the MMA warp).
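A hedged sketch of wiring up the constructor inside a kernel. The shared-memory pointer names (`full_barriers`, `empty_barriers`, `responses`, `throttle_barriers`) and `first_work` are illustrative placeholders supplied by the caller, not part of this API, and the parameter values are arbitrary:

```mojo
# All barrier/response pointers are assumed to be carved out of the
# kernel's shared memory by the caller; names here are hypothetical.
var work_iter = GroupedCLCWorkIterator[
    tile_m=128, tile_n=256, tile_k=64, max_groups=32, num_clc_stages=2
](
    problem_sizes,      # (max_groups, 4) TileTensor of [M, N, K, L]
    num_groups,
    full_barriers,      # CLC full barriers (scheduler signals work ready)
    empty_barriers,     # CLC empty barriers (workers signal done)
    responses,          # UInt128 CLC response storage
    throttle_barriers,  # throttle pipeline barriers
    first_work,         # initial GroupedWorkInfo (first tile)
    use_clc_fetch=True, # MMA warp: wait on CLC barriers in __next__
)
```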
__iter__
__iter__(ref self) -> Self
__next__
__next__(mut self) -> GroupedWorkInfo
Return the current work item and advance to the next one.

When use_clc_fetch is True (MMA warp), waits on CLC barriers for synchronization. Otherwise uses a simple linear advance.
Returns:
GroupedWorkInfo: The current work item, before advancing.
Raises:
StopIteration: When there is no more work to process.
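Because __next__ raises when work is exhausted, a Mojo `for` loop over the iterator terminates automatically (as in the usage example above). A hedged manual-drive sketch of the same pattern, with `update_tensormaps` and `process_tile` as hypothetical caller-provided helpers:

```mojo
# Manual iteration, equivalent to a `for` loop over the iterator.
var it = work_iter.__iter__()
while True:
    try:
        var current = it.__next__()  # may wait on CLC barriers (MMA warp)
        if current.group_changed:
            update_tensormaps(current.group_idx)
        process_tile(current)
    except e:
        break  # no more work to process
```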