
Mojo struct

GroupedTileScheduler

@register_passable(trivial) struct GroupedTileScheduler[tile_m: Int, tile_n: Int, tile_k: Int, max_groups: Int, num_stages: Int = 0, cta_group: Int = 1]

Tile scheduler for grouped block-scaled GEMM.

Uses linear tile iteration to map tiles across groups. It does not use CLC (Cluster Launch Control), since the work distribution is deterministic and can be computed from the problem sizes alone.
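The mapping can be modeled in plain Python. This is a minimal sketch, not the Mojo implementation. It assumes that each group contributes ceil(M / tile_m) * ceil(N / tile_n) output tiles (K is reduced inside a tile and L is ignored), that groups are laid out back to back in one linear index space through an exclusive prefix sum of tile counts, and that tiles within a group are numbered row-major. The helper names `ceil_div`, `tiles_per_group`, and `map_linear_tile` are hypothetical.

```python
from bisect import bisect_right
from itertools import accumulate


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def tiles_per_group(problem_sizes, tile_m, tile_n):
    # Hypothetical tile count: one tile per (tile_m x tile_n) block of the output.
    # K is reduced inside each tile; L is ignored in this sketch.
    return [ceil_div(m, tile_m) * ceil_div(n, tile_n) for m, n, _k, _l in problem_sizes]


def map_linear_tile(linear_idx, problem_sizes, tile_m, tile_n):
    """Map a global linear tile index to (group, m_tile, n_tile), or None when done."""
    counts = tiles_per_group(problem_sizes, tile_m, tile_n)
    # cumulative[g] = number of tiles in groups 0..g-1 (exclusive prefix sum).
    cumulative = [0] + list(accumulate(counts))
    if not 0 <= linear_idx < cumulative[-1]:
        return None  # past the last tile: no more work
    # bisect_right skips groups that contribute zero tiles.
    group = bisect_right(cumulative, linear_idx) - 1
    local = linear_idx - cumulative[group]
    n_tiles_n = ceil_div(problem_sizes[group][1], tile_n)
    # Assumed row-major order inside a group: consecutive indices walk along N.
    return group, local // n_tiles_n, local % n_tiles_n
```

With `[(256, 256, 64, 1), (128, 384, 64, 1)]` and 128 x 128 tiles, group 0 owns linear indices 0 to 3 and group 1 owns 4 to 6. Because the result depends only on the problem sizes and the index, every CTA can compute it independently.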

Parameters

  • tile_m (Int): M dimension of output tiles.
  • tile_n (Int): N dimension of output tiles.
  • tile_k (Int): K dimension of input tiles.
  • max_groups (Int): Maximum number of groups.
  • num_stages (Int): Pipeline stages (0 = single wave).
  • cta_group (Int): Number of CTAs cooperating on each tile: 1, or 2 for 2SM mode.

Fields

  • num_groups (Int): Number of active groups.
  • problem_sizes (TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin]): Problem-sizes tensor with one row of [M, N, K, L] per group. It is statically sized for max_groups rows; the first num_groups rows are active.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

__init__

__init__(problem_sizes: TileTensor[DType.int32, Layout[ComptimeInt[max_groups], ComptimeInt[4], ComptimeInt[4], ComptimeInt[1]], MutAnyOrigin], num_groups: Int) -> Self

Initialize the scheduler with per-group problem sizes.

Args:

  • problem_sizes (TileTensor): Tensor with one row of [M, N, K, L] per group; the first num_groups rows are read.
  • num_groups (Int): Number of active groups.
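The problem_sizes argument can be prepared on the host as sketched below in Python. This assumes the Layout in the signature describes a (max_groups, 4) int32 tensor with row-major strides (4, 1), and that rows beyond num_groups may be zero-filled. `pack_problem_sizes` is a hypothetical helper, and copying the buffer to the device is out of scope.

```python
from array import array


def pack_problem_sizes(groups, max_groups):
    """Pack [M, N, K, L] per group into a flat (max_groups, 4) row-major buffer.

    Returns the buffer and the number of active groups (the num_groups argument).
    """
    if len(groups) > max_groups:
        raise ValueError("more groups than max_groups")
    # Typecode "i" is a signed C int, which is 32-bit on common platforms.
    buf = array("i", [0] * (max_groups * 4))  # unused rows stay zero
    for g, (m, n, k, l) in enumerate(groups):
        base = g * 4  # assumed stride 4 between rows, stride 1 within a row
        buf[base:base + 4] = array("i", [m, n, k, l])
    return buf, len(groups)
```

For two groups and max_groups = 4, the buffer holds 16 values: rows 0 and 1 carry the problem sizes and rows 2 and 3 are zero.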

work_iterator

work_iterator(self) -> GroupedWorkIterator[tile_m, tile_n, tile_k, max_groups, cta_group]

Create a per-warp work iterator.

Each warp should create its own work iterator; the iterator owns its work_info and the cumulative per-group tile counts internally.

For 2SM (cta_group=2), the iterator uses cluster-based indexing.
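A per-worker iterator can be modeled in Python as a generator that keeps its own cumulative counts and group cursor. This is a hypothetical sketch: it assumes persistent scheduling in which worker w visits linear tile indices w, w + num_workers, w + 2 * num_workers, and so on, and that with cta_group = 2 both CTAs of a pair share one worker id. Neither the striding scheme nor `worker_for_block` and `work_items` is confirmed by this page.

```python
from itertools import accumulate


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def worker_for_block(block_idx, num_blocks, cta_group):
    # Hypothetical cluster-based indexing: the CTAs of one pair get the same id,
    # so they are handed the same tile.
    return block_idx // cta_group, num_blocks // cta_group


def work_items(worker_id, num_workers, problem_sizes, tile_m, tile_n):
    """Yield (group, m_tile, n_tile) for one worker; all state is local to the generator."""
    counts = [ceil_div(m, tile_m) * ceil_div(n, tile_n) for m, n, _k, _l in problem_sizes]
    cumulative = [0] + list(accumulate(counts))  # exclusive prefix sum
    total = cumulative[-1]
    group = 0
    idx = worker_id
    while idx < total:
        # Indices only increase, so the group cursor moves forward without a search.
        while cumulative[group + 1] <= idx:
            group += 1
        local = idx - cumulative[group]
        n_tiles_n = ceil_div(problem_sizes[group][1], tile_n)
        yield group, local // n_tiles_n, local % n_tiles_n
        idx += num_workers
```

Keeping the cursor inside each iterator avoids shared state between warps, at the cost of each iterator recomputing the prefix sum.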

Returns:

GroupedWorkIterator

total_tiles

total_tiles(self) -> Int

Compute the total number of tiles across all active groups.

Returns:

Int
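The count can be modeled in Python as a sum of per-group tile grids. This minimal sketch assumes each active group contributes ceil(M / tile_m) * ceil(N / tile_n) tiles, that K adds no tiles because it is reduced within each tile, and that the batch dimension L is not multiplied in. Whether the Mojo method accounts for L or cta_group is not stated on this page.

```python
def total_tiles(problem_sizes, num_groups, tile_m, tile_n):
    """Sum the output-tile grid of each active group (rows past num_groups are ignored)."""
    total = 0
    for m, n, _k, _l in problem_sizes[:num_groups]:
        # Partial tiles at the M and N edges still count as whole tiles.
        total += ((m + tile_m - 1) // tile_m) * ((n + tile_n - 1) // tile_n)
    return total
```

For example, a 200 x 130 problem with 128 x 128 tiles needs 2 * 2 = 4 tiles, because both edges spill into a partial tile.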
