
Mojo struct

GroupedWorkIterator1D1D

struct GroupedWorkIterator1D1D[static_N: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index[Int, Int, Int](1, 1, 1), cta_group: Int = 1, swizzle: Bool = False, AB_swapped: Bool = False]

Work iterator for 1D-1D grouped block-scaled matmul.

Iterates through work tiles using offset-based addressing:

  • a_offsets (stored in the group_offsets field): Prefix sum of token counts per active expert
  • expert_ids: Mapping from active expert index to actual expert ID
  • expert_scales: Per-expert output scaling factors

Yields only valid work tiles, skipping invalid ones internally.

Usage:

    for ctx in work_iter:
        process_tile(ctx)
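The offset-based addressing described above can be modeled in a few lines of Python, used here only for illustration since the struct itself is Mojo. This sketch assumes that the offsets array holds num_active_experts + 1 entries starting at 0, so active expert i owns token rows [offsets[i], offsets[i+1]). That layout, the helper names build_group_offsets and expert_row_ranges, and the sample token counts are all hypothetical.

```python
from itertools import accumulate


def build_group_offsets(token_counts):
    """Prefix sum with a leading 0 (assumed layout: num_active_experts + 1 entries)."""
    return [0] + list(accumulate(token_counts))


def expert_row_ranges(group_offsets, expert_ids):
    """Yield (expert_id, row_start, row_end) for each active expert.

    expert_ids maps the active-expert index to the real expert ID.
    """
    for i in range(len(group_offsets) - 1):
        yield expert_ids[i], group_offsets[i], group_offsets[i + 1]


# Hypothetical input: three active experts with 5, 0 and 12 tokens,
# whose IDs in the full expert table are 7, 2 and 4.
offsets = build_group_offsets([5, 0, 12])
for expert_id, start, end in expert_row_ranges(offsets, [7, 2, 4]):
    print(expert_id, start, end)
```

An expert that received no tokens ends up with an empty row range, which is presumably one kind of work the iterator skips.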

Fields

  • num_active_experts (Int):
  • group_offsets (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].OffsetsTile):
  • expert_ids (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].ExpertIdsTile):
  • expert_scales (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].ExpertScalesTile):
  • current_iter (Int32):
  • current_group_idx (UInt32):
  • current_dynamic_dim_cumsum (UInt32):
  • block_idx_start (UInt32):

Implemented traits

AnyType, Copyable, ImplicitlyDestructible, Iterable, Iterator, Movable

comptime members

cta_group_tile_shape

comptime cta_group_tile_shape = Index[Int, Int]((tile_shape[1] * cta_group), (tile_shape[0] * cta_group)) if AB_swapped else Index[Int, Int](tile_shape[0], (tile_shape[1] * cta_group))
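As a reading aid, here is the same conditional written in Python. The reading of index 0 as the dynamic (token) extent and index 1 as the static N extent is inferred from div_dynamic_block and num_static_dim_blocks; the source does not state it. The example tile_shape values are hypothetical.

```python
def cta_group_tile_shape(tile_shape, cta_group, ab_swapped):
    """Mirror of the comptime expression above.

    Returns (dynamic-dim extent, static-dim extent) for one CTA group,
    under the interpretation described in the lead-in.
    """
    if ab_swapped:
        # Swapped operands: both extents are scaled by cta_group.
        return (tile_shape[1] * cta_group, tile_shape[0] * cta_group)
    # Default: only the static (N) extent is scaled by cta_group.
    return (tile_shape[0], tile_shape[1] * cta_group)


# Hypothetical tile_shape (M, N, K) = (128, 256, 64) with a 2-CTA group.
print(cta_group_tile_shape((128, 256, 64), 2, False))
print(cta_group_tile_shape((128, 256, 64), 2, True))
```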

div_dynamic_block

comptime div_dynamic_block = FastDiv(GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].cta_group_tile_shape[0])
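FastDiv appears to precompute the divisor so that a dynamic-dimension tile index can be derived without a hardware integer divide. The Python sketch below shows the generic multiply-and-shift technique for division by an invariant integer. It is an assumption that FastDiv uses something of this kind, and MagicDivider is a hypothetical name, not FastDiv's API.

```python
class MagicDivider:
    """Divide unsigned 32-bit integers by a fixed divisor d.

    With l = ceil(log2(d)) and m = ceil(2**(32 + l) / d),
    n // d == (n * m) >> (32 + l) holds for every 0 <= n < 2**32.
    """

    def __init__(self, d: int):
        if not 1 <= d < 2**32:
            raise ValueError("divisor must be in [1, 2**32)")
        self.d = d
        self.shift = 32 + (d - 1).bit_length()  # 32 + ceil(log2(d))
        self.mult = -(-(1 << self.shift) // d)   # ceil(2**shift / d)

    def div(self, n: int) -> int:
        if not 0 <= n < 2**32:
            raise ValueError("n must be an unsigned 32-bit value")
        # One multiply and one shift replace the division.
        return (n * self.mult) >> self.shift

    def divmod(self, n: int):
        q = self.div(n)
        return q, n - q * self.d


# Hypothetical use: which 128-row block does token row 300 fall into?
rows_per_block = MagicDivider(128)
print(rows_per_block.divmod(300))
```

Python integers are unbounded, so the possibly 33-bit multiplier is harmless here; a GPU implementation would typically keep a 32-bit multiplier and use a high-half multiply with a correction step instead.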

Element

comptime Element = GroupedWorkContext1D1D

ExpertIdsTile

comptime ExpertIdsTile = TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin]

ExpertScalesTile

comptime ExpertScalesTile = TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin]

IteratorType

comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped]

Parameters

  • iterable_mut (Bool):
  • iterable_origin (Origin):

kNum1DBlocksPerGroup

comptime kNum1DBlocksPerGroup = SIMD(16)

num_static_dim_blocks

comptime num_static_dim_blocks = SIMD(ceildiv(static_N, GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].cta_group_tile_shape[1]))
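num_static_dim_blocks is a ceiling division of static_N by the static-dimension extent of the CTA-group tile. The Python sketch below reproduces it and, as an assumption, multiplies it by the number of dynamic-dimension blocks to count the tiles a single expert contributes. tiles_per_expert is a hypothetical helper, and the sizes are made up.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return -(-a // b)


def num_static_dim_blocks(static_n: int, cta_tile) -> int:
    # cta_tile = (dynamic extent, static extent), as in cta_group_tile_shape.
    return ceildiv(static_n, cta_tile[1])


def tiles_per_expert(num_tokens: int, static_n: int, cta_tile) -> int:
    # Hypothetical: dynamic blocks for this expert's tokens times static blocks.
    return ceildiv(num_tokens, cta_tile[0]) * num_static_dim_blocks(static_n, cta_tile)


# Hypothetical sizes: static_N = 4096, CTA-group tile = (128, 512).
print(num_static_dim_blocks(4096, (128, 512)))
print(tiles_per_expert(300, 4096, (128, 512)))
print(tiles_per_expert(0, 4096, (128, 512)))
```

Under this model an expert with zero tokens contributes no tiles at all, while a partial final block of rows still costs a full row of static blocks.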

OffsetsTile

comptime OffsetsTile = TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin]

Methods

__init__

__init__(out self, num_active_experts: Int, group_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], expert_ids: TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin], expert_scales: TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin])

__iter__

__iter__(ref self) -> Self

__next__

__next__(mut self) -> GroupedWorkContext1D1D

Return the next valid work tile, skipping invalid ones.

Returns:

GroupedWorkContext1D1D

Raises:

StopIteration: When all work is done.
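The skip-and-stop behavior of __next__ can be sketched as a Python iterator. Everything here is an assumed model, not the Mojo implementation. Tiles are numbered linearly, expert by expert, with the static N blocks varying fastest; each CTA starts at block_idx_start and strides by a hypothetical grid_size; and WorkContext is a stand-in for GroupedWorkContext1D1D. The running group index and tile count play roles loosely analogous to current_group_idx and current_dynamic_dim_cumsum.

```python
from dataclasses import dataclass


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


@dataclass
class WorkContext:  # hypothetical stand-in for GroupedWorkContext1D1D
    expert_id: int
    row_start: int  # first global token row covered by this tile
    rows: int       # valid rows in this tile (at most bm)
    n_block: int    # block index along the static N dimension
    scale: float


class GroupedWorkIteratorSketch:
    def __init__(self, group_offsets, expert_ids, expert_scales,
                 bm, bn, static_n, block_idx_start, grid_size):
        self.offsets = group_offsets
        self.expert_ids = expert_ids
        self.scales = expert_scales
        self.bm = bm
        self.n_blocks = ceildiv(static_n, bn)
        self.next_tile = block_idx_start  # linear tile id this CTA handles next
        self.stride = grid_size           # assumed persistent-grid stride
        self.group = 0                    # current active-expert index
        self.tiles_before_group = 0       # tiles owned by earlier experts

    def __iter__(self):
        return self

    def __next__(self):
        num_groups = len(self.offsets) - 1
        while self.group < num_groups:
            tokens = self.offsets[self.group + 1] - self.offsets[self.group]
            group_tiles = ceildiv(tokens, self.bm) * self.n_blocks
            local = self.next_tile - self.tiles_before_group
            if local < group_tiles:
                m_block, n_block = divmod(local, self.n_blocks)
                row_start = self.offsets[self.group] + m_block * self.bm
                ctx = WorkContext(
                    expert_id=self.expert_ids[self.group],
                    row_start=row_start,
                    rows=min(self.bm, self.offsets[self.group + 1] - row_start),
                    n_block=n_block,
                    scale=self.scales[self.group],
                )
                self.next_tile += self.stride
                return ctx
            # Tile id lies past this expert (or the expert is empty): skip it.
            self.tiles_before_group += group_tiles
            self.group += 1
        raise StopIteration  # all work for this CTA is done
```

With group_offsets = [0, 5, 5, 17], bm = bn = 8, static_n = 16 and four CTAs, the six tiles are split among the CTAs and the empty second expert never produces a context.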

next

next(mut self) -> GroupedWorkContext1D1D

Fetch the next work tile and return a context holding its work info and scale.

Returns:

GroupedWorkContext1D1D

current_expert_id

current_expert_id(self) -> Int32

Get the expert ID for the current group.

Returns:

Int32
