
Mojo struct

GroupedWorkIterator

struct GroupedWorkIterator[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, cta_group: Int = 1]

Per-warp work iterator for grouped GEMM using next-style iteration.

This iterator traverses tiles across all groups, tracking when groups change to trigger tensormap updates. It uses linear iteration instead of CLC.

For 2SM (cta_group=2), both CTAs in a cluster work on the same logical tile. The cluster index (block_idx.x // cta_group) is used for tile assignment, and the advance step is grid_dim.x // cta_group (the number of clusters).
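
As a standalone illustration (not taken from the struct itself), the tile assignment and advance arithmetic described above amounts to the following; block_idx_x and grid_dim_x stand in for the GPU builtins block_idx.x and grid_dim.x, and the concrete values are made up:

    fn main():
        # Hypothetical values standing in for block_idx.x / grid_dim.x.
        var cta_group = 2     # 2SM: two CTAs per cluster
        var block_idx_x = 5   # this CTA's index within the grid
        var grid_dim_x = 8    # grid size in CTAs
        var total_tiles = 11  # total tiles across all groups (example value)

        # Both CTAs of a cluster share the same cluster index, so they work on the same tile.
        var cluster_idx = block_idx_x // cta_group
        # The advance step is the number of clusters in the grid.
        var num_clusters = grid_dim_x // cta_group

        # Linear iteration: the cluster owns tiles cluster_idx, cluster_idx + num_clusters, ...
        var tile = cluster_idx
        while tile < total_tiles:
            print("cluster", cluster_idx, "processes linear tile", tile)
            tile += num_clusters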

Usage:

    var work_iter = scheduler.work_iterator()
    for current in work_iter:
        if current.group_changed:
            update_tensormaps(current.group_idx)
        process_tile(current)

Fields

  • work_info (GroupedWorkInfo): Current work item.
  • linear_tile_idx (UInt32): Current linear tile index (across all groups).
  • total_tiles (UInt32): Total number of tiles across all groups.
  • prev_group_idx (UInt32): Previous group index for detecting group changes.
  • cumulative_tiles (StaticTuple[UInt32, (max_groups + 1)]): Cumulative tile count at the start of each group (see the sketch after this list).
  • problem_m (StaticTuple[UInt32, max_groups]): M dimension for each group.
  • problem_n (StaticTuple[UInt32, max_groups]): N dimension for each group.
  • problem_k (StaticTuple[UInt32, max_groups]): K dimension for each group.
  • num_groups (UInt32): Number of active groups.
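
Together, cumulative_tiles and the problem_* fields determine which group and which tile within that group a linear tile index falls in. A minimal sketch of that lookup, using plain Lists in place of the StaticTuple fields and assuming a row-major ordering of each group's tile grid (hypothetical, not the struct's actual method):

    from math import ceildiv

    fn resolve_tile(
        linear_tile_idx: Int,
        cumulative_tiles: List[Int],  # start tile index per group, length num_groups + 1
        problem_n: List[Int],         # N dimension per group
        tile_n: Int,
        num_groups: Int,
    ) -> Tuple[Int, Int, Int]:
        # Find the group whose tile range contains linear_tile_idx.
        var group = 0
        while group + 1 < num_groups and linear_tile_idx >= cumulative_tiles[group + 1]:
            group += 1
        # Offset of the tile within its group.
        var local = linear_tile_idx - cumulative_tiles[group]
        # Assumed row-major layout of the group's (M tiles x N tiles) grid.
        var n_tiles = ceildiv(problem_n[group], tile_n)
        return (group, local // n_tiles, local % n_tiles)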

Implemented traits

AnyType, Copyable, ImplicitlyDestructible, Iterable, Iterator, Movable, RegisterPassable

comptime members

Element

comptime Element = GroupedWorkInfo

IteratorType

comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = GroupedWorkIterator[tile_m, tile_n, tile_k, max_groups, cta_group]

Parameters

  • iterable_mut (Bool): Whether the iterable is mutable.
  • iterable_origin (Origin): The origin of the iterable.

Methods

__init__

__init__(problem_sizes: TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin], num_groups: Int, grid_size: UInt32) -> Self

Initialize work iterator with problem sizes.

Args:

  • problem_sizes (TileTensor): (num_groups, 4) tensor with [M, N, K, L] per group.
  • num_groups (Int): Number of active groups.
  • grid_size (UInt32): Number of blocks in the grid.
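
For intuition only, here is a hedged sketch of the kind of per-group tile bookkeeping such an initializer performs, using plain Lists in place of the TileTensor argument; the per-group tile count formula ceildiv(M, tile_m) * ceildiv(N, tile_n) is an assumption based on the field descriptions above:

    from math import ceildiv

    fn cumulative_tile_counts(
        problem_m: List[Int], problem_n: List[Int], tile_m: Int, tile_n: Int
    ) -> List[Int]:
        # cumulative[g] is the linear tile index at which group g starts;
        # cumulative[num_groups] is the total tile count across all groups.
        var cumulative = List[Int](capacity=len(problem_m) + 1)
        cumulative.append(0)
        for g in range(len(problem_m)):
            var tiles = ceildiv(problem_m[g], tile_m) * ceildiv(problem_n[g], tile_n)
            cumulative.append(cumulative[g] + tiles)
        return cumulative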

__iter__

__iter__(ref self) -> Self

__next__

__next__(mut self) -> GroupedWorkInfo

Return current work item, deferring advance to next call.

Returns:

GroupedWorkInfo: The current work item.

Raises:

StopIteration: When there is no more work to process.
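
For comparison with the for-loop form shown in the usage example above, a hypothetical manual-iteration sketch (assuming the same scheduler, update_tensormaps, and process_tile helpers) drives __next__ directly and stops when it raises:

    var work_iter = scheduler.work_iterator()
    while True:
        try:
            var current = work_iter.__next__()
            if current.group_changed:
                update_tensormaps(current.group_idx)
            process_tile(current)
        except e:
            break  # StopIteration: no more work to process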
