Mojo struct
GroupedWorkIterator
```mojo
@register_passable(trivial)
struct GroupedWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, cta_group: Int = 1]
```
Per-warp work iterator for grouped GEMM.
This iterator traverses tiles across all groups, tracking when groups change to trigger tensormap updates. It uses linear iteration instead of CLC (cluster launch control).
For 2SM (cta_group=2), both CTAs in a cluster work on the same logical tile. The cluster index (block_idx.x // cta_group) is used for tile assignment, and advance step is grid_dim.x // cta_group (number of clusters).
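The cluster arithmetic above can be sketched as follows (`block_idx` and `grid_dim` as exposed by Mojo's `gpu.id` module; this mirrors the description, not the struct's internals):

```mojo
from gpu.id import block_idx, grid_dim

# Sketch: tile assignment for cta_group = 2 (2SM).
alias cta_group = 2
# Both CTAs of a cluster compute the same cluster index, so they are
# assigned the same logical tile.
var cluster_idx = block_idx.x // cta_group
# Advancing by the cluster count gives a grid-stride loop over clusters.
var num_clusters = grid_dim.x // cta_group
var linear_tile_idx = cluster_idx  # first tile; next is += num_clusters
```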
Usage:

```mojo
var work_iter = scheduler.work_iterator()
while work_iter.has_work():
    var current = work_iter.current()
    if current.group_changed:
        update_tensormaps(current.group_idx)
    process_tile(current)
    work_iter.advance()
```
Fields
- work_info (GroupedWorkInfo): Current work item.
- linear_tile_idx (UInt32): Current linear tile index (across all groups).
- total_tiles (UInt32): Total number of tiles across all groups.
- prev_group_idx (UInt32): Previous group index for detecting group changes.
- cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group.
- problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
- problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
- problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
- num_groups (UInt32): Number of active groups.
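The `cumulative_tiles` field is what lets a linear tile index be resolved to its owning group: group g owns indices in [cumulative_tiles[g], cumulative_tiles[g + 1]). A minimal sketch of that lookup (the helper name and linear scan are illustrative, not the struct's actual code):

```mojo
from utils import StaticTuple

# Sketch: resolve a linear tile index to its group (illustrative helper).
# cumulative_tiles[g] is the tile count of groups 0..g-1, so group g owns
# indices in [cumulative_tiles[g], cumulative_tiles[g + 1]).
fn find_group[max_groups: Int](
    cumulative_tiles: StaticTuple[UInt32, max_groups + 1],
    num_groups: Int,
    linear_tile_idx: UInt32,
) -> Int:
    for g in range(num_groups):
        if linear_tile_idx < cumulative_tiles[g + 1]:
            return g
    return num_groups - 1  # caller ensures linear_tile_idx < total_tiles
```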
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
__init__
```mojo
__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int, grid_size: UInt32) -> Self
```
Initialize work iterator with problem sizes.
Args:
- problem_sizes (TileTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
- num_groups (Int): Number of active groups.
- grid_size (UInt32): Number of blocks in the grid.
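The tile counts behind total_tiles and cumulative_tiles follow directly from the tile shape. A hedged sketch of that accumulation (assuming `math.ceildiv` and 2D indexing into problem_sizes; not the constructor's actual body):

```mojo
from math import ceildiv

# Sketch: accumulate per-group tile counts into cumulative_tiles.
# Group g contributes ceildiv(M, tile_m) * ceildiv(N, tile_n) output tiles.
cumulative_tiles[0] = 0
for g in range(num_groups):
    var m = Int(problem_sizes[g, 0])  # M for group g
    var n = Int(problem_sizes[g, 1])  # N for group g
    var tiles = ceildiv(m, tile_m) * ceildiv(n, tile_n)
    cumulative_tiles[g + 1] = cumulative_tiles[g] + UInt32(tiles)
# total_tiles = cumulative_tiles[num_groups]
```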
has_work
Reports whether any tiles remain to process.
current
Returns the current work item.
advance
advance(mut self)
Advance to next tile.
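The linear-iteration scheme makes advancing a grid-stride step. A hedged sketch of what one step entails (illustrative field updates, not the actual method body):

```mojo
# Sketch: one advance step of the linear iteration (not the actual body).
# num_clusters = grid_dim.x // cta_group is the grid-stride step.
self.linear_tile_idx += num_clusters
if self.linear_tile_idx < self.total_tiles:
    # Resolve the owning group from cumulative_tiles, then flag
    # group_changed so callers know to refresh tensormaps.
    var group_idx: UInt32 = 0
    for g in range(Int(self.num_groups)):
        if self.linear_tile_idx < self.cumulative_tiles[g + 1]:
            group_idx = UInt32(g)
            break
    self.work_info.group_changed = group_idx != self.prev_group_idx
    self.prev_group_idx = group_idx
# Otherwise has_work() reports False and the persistent loop exits.
```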
next
```mojo
next[state_origin: MutOrigin, //](ref [state_origin] self) -> GroupedAdvanceContext[origin_of(self.work_info), origin_of(self.linear_tile_idx)]
```
Get context manager that returns current work and advances on exit.
Compatible with the working kernel's pattern:

```mojo
with work_iter.next() as current:
    process_tile(current)
# After: work_iter.work_info updated to next work
```
Pre-computes the next state, then on exit updates work_info and linear_tile_idx.
Returns:
GroupedAdvanceContext
wait_and_advance
```mojo
wait_and_advance[state_origin: MutOrigin, //](ref [state_origin] self) -> GroupedAdvanceContext[origin_of(self.work_info), origin_of(self.linear_tile_idx)]
```
Same as next(); there is no CLC waiting for grouped GEMM.
Provided for compatibility with the MMA warp pattern. Since grouped GEMM does not use CLC, this behaves identically to next().
Returns:
GroupedAdvanceContext