Mojo struct
GroupedWorkIterator
struct GroupedWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, cta_group: Int = 1]
Per-warp work iterator for grouped GEMM using next-style iteration.
This iterator traverses tiles across all groups, tracking when groups change to trigger tensormap updates. It uses linear iteration instead of CLC.
For 2SM (cta_group=2), both CTAs in a cluster work on the same logical tile. Tile assignment uses the cluster index (block_idx.x // cta_group), and the advance step is the number of clusters (grid_dim.x // cta_group).
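The cluster arithmetic above can be modeled on the host. The following is a minimal Python sketch, not the Mojo implementation: the helper name `cluster_tile_schedule` is hypothetical, and it assumes each cluster starts at the linear tile whose index equals its cluster index.

```python
def cluster_tile_schedule(block_idx_x, grid_dim_x, cta_group, total_tiles):
    """Linear tile indices visited by the cluster that contains this CTA.

    Assumption: a cluster starts at the tile equal to its cluster index
    and then strides by the number of clusters.
    """
    cluster_idx = block_idx_x // cta_group   # same value for every CTA in a cluster
    num_clusters = grid_dim_x // cta_group   # advance step
    return list(range(cluster_idx, total_tiles, num_clusters))


# Blocks 2 and 3 form one 2SM cluster in an 8-block grid (4 clusters).
schedule_block2 = cluster_tile_schedule(2, 8, 2, 10)
schedule_block3 = cluster_tile_schedule(3, 8, 2, 10)
```

Under these assumptions both CTAs of the cluster compute the same schedule, which is what lets them cooperate on one logical tile.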
Usage:
    var work_iter = scheduler.work_iterator()
    for current in work_iter:
        if current.group_changed:
            update_tensormaps(current.group_idx)
        process_tile(current)
Fields
- work_info (GroupedWorkInfo): Current work item.
- linear_tile_idx (UInt32): Current linear tile index (across all groups).
- total_tiles (UInt32): Total number of tiles across all groups.
- prev_group_idx (UInt32): Previous group index for detecting group changes.
- cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group.
- problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
- problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
- problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
- num_groups (UInt32): Number of active groups.
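To illustrate how cumulative_tiles can map a linear tile index back to a group, here is a small Python model. The function `locate_tile` is a hypothetical name, and the linear scan is an assumption about the lookup strategy; the actual struct may resolve the group differently.

```python
def locate_tile(linear_tile_idx, cumulative_tiles, num_groups):
    """Return (group_idx, tile index within that group) for a linear tile index.

    cumulative_tiles[g] is the tile count before group g starts, so group g
    owns linear indices in [cumulative_tiles[g], cumulative_tiles[g + 1]).
    """
    for group_idx in range(num_groups):
        if linear_tile_idx < cumulative_tiles[group_idx + 1]:
            return group_idx, linear_tile_idx - cumulative_tiles[group_idx]
    raise IndexError("linear_tile_idx is past total_tiles")


# Three groups holding 2, 2, and 1 tiles respectively.
example_cumulative = [0, 2, 4, 5]
first = locate_tile(0, example_cumulative, 3)
middle = locate_tile(3, example_cumulative, 3)
last = locate_tile(4, example_cumulative, 3)
```

Comparing the returned group index with prev_group_idx is one way to detect a group change between consecutive tiles.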
Implemented traits
AnyType,
Copyable,
ImplicitlyDestructible,
Iterable,
Iterator,
Movable,
RegisterPassable
comptime members
Element
comptime Element = GroupedWorkInfo
IteratorType
comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = GroupedWorkIterator[tile_m, tile_n, tile_k, max_groups, cta_group]
Parameters
Methods
__init__
__init__(problem_sizes: TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin], num_groups: Int, grid_size: UInt32) -> Self
Initialize work iterator with problem sizes.
Args:
- problem_sizes (TileTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
- num_groups (Int): Number of active groups.
- grid_size (UInt32): Number of blocks in the grid.
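As a rough illustration of what initialization might derive from problem_sizes, the Python sketch below builds a cumulative tile table from [M, N, K, L] rows. It assumes each group contributes ceil(M / tile_m) * ceil(N / tile_n) output tiles and ignores K and L; this page does not state the formula, so treat it as an assumption. The helper names are hypothetical.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive b."""
    return (a + b - 1) // b


def build_cumulative_tiles(problem_sizes, tile_m, tile_n):
    """Prefix sums of per-group tile counts, with a leading zero.

    Assumption: tiles per group = ceil(M / tile_m) * ceil(N / tile_n).
    """
    cumulative = [0]
    for m, n, _k, _l in problem_sizes:
        tiles = ceil_div(m, tile_m) * ceil_div(n, tile_n)
        cumulative.append(cumulative[-1] + tiles)
    return cumulative


# Rows are [M, N, K, L] per group.
problem_sizes = [[256, 128, 64, 1], [100, 256, 64, 1], [128, 128, 32, 1]]
cumulative_tiles = build_cumulative_tiles(problem_sizes, tile_m=128, tile_n=128)
total_tiles = cumulative_tiles[-1]
```

The resulting list has one more entry than there are groups, matching the (max_groups + 1) length of the cumulative_tiles field.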
__iter__
__iter__(ref self) -> Self
__next__
__next__(mut self) -> GroupedWorkInfo
Return current work item, deferring advance to next call.
Returns:
GroupedWorkInfo
Raises:
StopIteration: When there is no more work to process.
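The next-style protocol with a deferred advance can be sketched in Python. This is one interpretation, not the Mojo source: the class `LinearWorkIterator` and its `started` flag are hypothetical, each call advances past the item returned by the previous call before producing the current one, and the first tile is assumed to report a group change.

```python
class LinearWorkIterator:
    """Host-side model of linear tile iteration with a deferred advance."""

    def __init__(self, start, step, cumulative_tiles):
        self.linear_tile_idx = start
        self.step = step
        self.cumulative_tiles = cumulative_tiles
        self.total_tiles = cumulative_tiles[-1]
        self.prev_group_idx = None   # assumption: first tile counts as a change
        self.started = False

    def __iter__(self):
        return self

    def __next__(self):
        # Deferred advance: move past the previously returned item first.
        if self.started:
            self.linear_tile_idx += self.step
        self.started = True
        if self.linear_tile_idx >= self.total_tiles:
            raise StopIteration
        # Find the group whose tile range contains the current index.
        group_idx = 0
        while self.linear_tile_idx >= self.cumulative_tiles[group_idx + 1]:
            group_idx += 1
        group_changed = group_idx != self.prev_group_idx
        self.prev_group_idx = group_idx
        return (self.linear_tile_idx, group_idx, group_changed)


items = list(LinearWorkIterator(start=0, step=1, cumulative_tiles=[0, 2, 4, 5]))
```

Each tuple holds the linear tile index, the group index, and whether the group changed since the previous item, mirroring how a caller would decide when to update tensormaps.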