Mojo struct

GroupedWorkIterator

@register_passable(trivial) struct GroupedWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, cta_group: Int = 1]

Per-warp work iterator for grouped GEMM.

This iterator traverses tiles across all groups, tracking when groups change to trigger tensormap updates. It uses linear iteration instead of CLC.

For 2SM (cta_group=2), both CTAs in a cluster work on the same logical tile. The cluster index (block_idx.x // cta_group) is used for tile assignment, and the advance step is grid_dim.x // cta_group (the number of clusters).
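The cluster-index arithmetic above can be sketched in Python; the function name and the `block_idx_x`/`grid_dim_x` parameters are hypothetical stand-ins for the GPU built-ins, used only to illustrate the division described:

```python
def cluster_assignment(block_idx_x: int, grid_dim_x: int, cta_group: int):
    """Return (tile-assignment index, advance step) for this CTA's cluster.

    Illustrative sketch only: with cta_group=2 adjacent CTAs share one
    cluster index, so both work on the same logical tile, and the iterator
    advances by the number of clusters rather than the number of blocks.
    """
    cluster_idx = block_idx_x // cta_group   # shared by CTAs in one cluster
    num_clusters = grid_dim_x // cta_group   # advance step per iteration
    return cluster_idx, num_clusters

# With cta_group=2 and 8 blocks: blocks 2 and 3 form cluster 1 of 4 clusters.
print(cluster_assignment(3, 8, 2))  # (1, 4)
```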

Usage:

    var work_iter = scheduler.work_iterator()
    while work_iter.has_work():
        var current = work_iter.current()
        if current.group_changed:
            update_tensormaps(current.group_idx)
        process_tile(current)
        work_iter.advance()

Fields

  • work_info (GroupedWorkInfo): Current work item.
  • linear_tile_idx (UInt32): Current linear tile index (across all groups).
  • total_tiles (UInt32): Total number of tiles across all groups.
  • prev_group_idx (UInt32): Previous group index for detecting group changes.
  • cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group.
  • problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
  • problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
  • problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
  • num_groups (UInt32): Number of active groups.
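The cumulative_tiles table is what lets a single linear tile index be resolved back to a group and a tile within it. A minimal Python sketch of that lookup, assuming cumulative_tiles[g] holds the tile count at the start of group g with a final total-tiles sentinel (the helper name `locate_tile` is hypothetical, not part of the API):

```python
import bisect

def locate_tile(linear_tile_idx: int, cumulative: list[int]) -> tuple[int, int]:
    """Resolve a linear tile index to (group_idx, local_tile_idx).

    Illustrative sketch: find the last group whose starting offset is
    <= linear_tile_idx, then subtract that offset to get the local index.
    """
    group_idx = bisect.bisect_right(cumulative, linear_tile_idx) - 1
    local_tile_idx = linear_tile_idx - cumulative[group_idx]
    return group_idx, local_tile_idx

# Three groups with 6, 4, and 10 tiles -> cumulative table [0, 6, 10, 20].
print(locate_tile(7, [0, 6, 10, 20]))  # (1, 1): group 1, second tile
```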

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

__init__

__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int, grid_size: UInt32) -> Self

Initialize work iterator with problem sizes.

Args:

  • problem_sizes (TileTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
  • num_groups (Int): Number of active groups.
  • grid_size (UInt32): Number of blocks in the grid.
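A sketch of the per-group tile accounting this constructor implies: each group's M and N are tiled with ceiling division by tile_m and tile_n, and a running sum yields the cumulative table and total tile count. The `build_cumulative` helper is an illustrative assumption, not the Mojo source:

```python
def build_cumulative(problem_sizes, tile_m: int, tile_n: int) -> list[int]:
    """problem_sizes: list of (M, N, K, L) rows, one per group.

    Illustrative sketch: returns the cumulative tile count at the start of
    each group; the final entry is the total tile count across all groups.
    """
    cumulative = [0]
    for m, n, _k, _l in problem_sizes:
        tiles_m = (m + tile_m - 1) // tile_m   # ceil(M / tile_m)
        tiles_n = (n + tile_n - 1) // tile_n   # ceil(N / tile_n)
        cumulative.append(cumulative[-1] + tiles_m * tiles_n)
    return cumulative

# Two groups, 128x128 tiles: group 0 has 2x1 tiles, group 1 has 1x1.
print(build_cumulative([(256, 128, 64, 1), (100, 100, 64, 1)], 128, 128))
# [0, 2, 3]
```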

has_work

has_work(self) -> Bool

Check if there is more work to process.

Returns:

Bool

current

current(self) -> GroupedWorkInfo

Get current work item.

Returns:

GroupedWorkInfo

advance

advance(mut self)

Advance to next tile.
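The linear advance-and-detect scheme described in the overview can be sketched as follows; the function and its return tuple are hypothetical, shown only to make the stepping and group-change detection concrete:

```python
import bisect

def advance(linear_tile_idx: int, step: int, cumulative: list[int],
            total_tiles: int, prev_group_idx: int):
    """Illustrative sketch: step the linear index forward by the number of
    workers (clusters), then check whether a group boundary was crossed.

    Returns (new_linear_idx, group_idx, group_changed, has_work).
    """
    linear_tile_idx += step
    if linear_tile_idx >= total_tiles:
        return linear_tile_idx, prev_group_idx, False, False
    group_idx = bisect.bisect_right(cumulative, linear_tile_idx) - 1
    return linear_tile_idx, group_idx, group_idx != prev_group_idx, True

# 12 tiles in two groups [0, 6, 12], 4 workers; a worker at tile 3 in group 0
# steps to tile 7, which lies in group 1 -> group change triggers a
# tensormap update in the caller.
print(advance(3, 4, [0, 6, 12], 12, 0))  # (7, 1, True, True)
```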

next

next[state_origin: MutOrigin, //](ref[tile_m] self) -> GroupedAdvanceContext[origin_of(state_origin.work_info), origin_of(state_origin.linear_tile_idx)]

Get context manager that returns current work and advances on exit.

Compatible with the working kernel's pattern:

    with work_iter.next() as current:
        process_tile(current)
    # After: work_iter.work_info updated to next work

Pre-computes next state, then on exit updates work_info and linear_idx.

Returns:

GroupedAdvanceContext

wait_and_advance

wait_and_advance[state_origin: MutOrigin, //](ref[tile_m] self) -> GroupedAdvanceContext[origin_of(state_origin.work_info), origin_of(state_origin.linear_tile_idx)]

Same as next() - no CLC waiting for grouped GEMM.

For compatibility with MMA warp pattern. Since we don't use CLC, this behaves identically to next().

Returns:

GroupedAdvanceContext

Was this page helpful?