
Mojo struct

GroupedWorkIterator1D1D

struct GroupedWorkIterator1D1D[static_N: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index(1, 1, 1), cta_group: Int = 1, swizzle: Bool = False]

Work iterator for 1D-1D grouped block-scaled matmul.

Iterates through work tiles using offset-based addressing over three global-memory tensors:

  • a_offsets (passed to __init__ as group_offsets): Prefix sum of token counts per active expert
  • expert_ids: Mapping from active expert index to actual expert ID
  • expert_scales: Per-expert output scaling factors
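A minimal Python model of the offset-based addressing above. It assumes `a_offsets` is an exclusive prefix sum with `num_active_experts + 1` entries starting at 0; the page does not state its exact length. The helper `locate_row` is hypothetical. It finds which active expert owns a global token row, then reads that expert's ID and scale. Experts with no tokens (equal consecutive offsets) are skipped naturally by the binary search.

```python
from bisect import bisect_right


def locate_row(row, a_offsets, expert_ids, expert_scales):
    """Hypothetical helper: map a global token row to (group, expert_id, scale, local_row).

    Assumes a_offsets is an exclusive prefix sum of per-expert token counts,
    with a_offsets[0] == 0 and len(a_offsets) == num_active_experts + 1.
    """
    if not 0 <= row < a_offsets[-1]:
        raise IndexError("row outside the grouped token range")
    # Largest g with a_offsets[g] <= row; that active expert owns the row.
    group = bisect_right(a_offsets, row) - 1
    return group, expert_ids[group], expert_scales[group], row - a_offsets[group]


# Three active experts with 5, 0 and 3 tokens.
a_offsets = [0, 5, 5, 8]
expert_ids = [7, 2, 4]
expert_scales = [0.5, 1.0, 2.0]
print(locate_row(6, a_offsets, expert_ids, expert_scales))  # (2, 4, 2.0, 1)
```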

Fields

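  • num_active_experts (Int): Number of active experts (groups) to iterate over.
  • group_offsets (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].OffsetsTile): Prefix sum of token counts per active expert (the a_offsets input).
  • expert_ids (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].ExpertIdsTile): Mapping from active expert index to actual expert ID.
  • expert_scales (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].ExpertScalesTile): Per-expert output scaling factors.
  • current_iter (Int32)
  • current_group_idx (UInt32): Index of the active expert group currently being processed.
  • current_dynamic_dim_cumsum (UInt32)
  • block_idx_start (UInt32)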

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

cta_group_tile_shape

comptime cta_group_tile_shape = Index(tile_shape[0] * cta_group, tile_shape[1] * cta_group)

div_dynamic_block

comptime div_dynamic_block = FastDiv(Self.cta_group_tile_shape[0])
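For example, with `tile_shape = (128, 256, 64)` and `cta_group = 2`, `cta_group_tile_shape` is `(256, 512)`, so `div_dynamic_block` divides by 256. A fast divider of this kind typically replaces integer division by a fixed divisor with a multiply and a shift. The Python sketch below shows the standard round-up multiplier method for unsigned 32-bit operands. It illustrates the general technique and is not Mojo's `FastDiv` implementation; `fast_div_params` and `fast_div` are hypothetical names.

```python
def fast_div_params(d: int) -> tuple[int, int]:
    """Hypothetical helper: (multiplier, shift) for dividing unsigned 32-bit values by d."""
    if not 1 <= d < 2**32:
        raise ValueError("divisor must be in [1, 2**32)")
    log2_ceil = (d - 1).bit_length()    # ceil(log2(d)); 0 when d == 1
    # With shift = 32 + ceil(log2(d)), the rounding error of the multiplier
    # times any n < 2**32 stays below 2**shift, so the quotient is exact.
    shift = 32 + log2_ceil
    multiplier = -(-(1 << shift) // d)  # ceil(2**shift / d)
    return multiplier, shift


def fast_div(n: int, multiplier: int, shift: int) -> int:
    """floor(n / d) for 0 <= n < 2**32, using one multiply and one shift."""
    return (n * multiplier) >> shift


# Block height 256, e.g. tile_shape[0] = 128 with cta_group = 2.
m, s = fast_div_params(128 * 2)
print(fast_div(1000, m, s))  # 3

# Non-power-of-two divisor, checked against ordinary floor division.
m7, s7 = fast_div_params(7)
assert all(fast_div(n, m7, s7) == n // 7 for n in (0, 6, 7, 12345, 2**32 - 1))
```

The multiplier can need 33 bits. Python's integers absorb this, whereas fixed-width implementations usually add a correction step to keep it within 32 bits.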

ExpertIdsTile

comptime ExpertIdsTile = TileTensor[DType.int32, GMEMLayout1D, MutAnyOrigin]

ExpertScalesTile

comptime ExpertScalesTile = TileTensor[DType.float32, GMEMLayout1D, MutAnyOrigin]

kNum1DBlocksPerGroup

comptime kNum1DBlocksPerGroup = 16

num_static_dim_blocks

comptime num_static_dim_blocks = SIMD(ceildiv(static_N, tile_shape[1]))
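A small worked count of work tiles. It assumes each active expert contributes `ceildiv(tokens, cta_group_tile_shape[0])` dynamic-dim blocks, each paired with all `num_static_dim_blocks` static-dim blocks; this pairing is an assumption, not something this page states. Note that `num_static_dim_blocks` divides `static_N` by `tile_shape[1]`, while the dynamic dimension uses the CTA-group height. `count_work_tiles` is a hypothetical helper.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for non-negative integers."""
    return -(-a // b)


def count_work_tiles(a_offsets, static_n, tile_shape, cta_group):
    """Hypothetical helper: work tiles per active expert and in total."""
    block_m = tile_shape[0] * cta_group                       # cta_group_tile_shape[0]
    num_static_dim_blocks = ceildiv(static_n, tile_shape[1])  # as defined above
    per_expert = [
        ceildiv(a_offsets[g + 1] - a_offsets[g], block_m) * num_static_dim_blocks
        for g in range(len(a_offsets) - 1)
    ]
    return per_expert, sum(per_expert)


# Three active experts with 300, 0 and 512 tokens; static N = 4096.
per_expert, total = count_work_tiles([0, 300, 300, 812], static_n=4096,
                                     tile_shape=(128, 256, 64), cta_group=2)
print(per_expert, total)  # [32, 0, 32] 64
```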

OffsetsTile

comptime OffsetsTile = TileTensor[DType.uint32, GMEMLayout1D, MutAnyOrigin]

Methods

__init__

__init__(out self, num_active_experts: Int, group_offsets: TileTensor[DType.uint32, GMEMLayout1D, MutAnyOrigin], expert_ids: TileTensor[DType.int32, GMEMLayout1D, MutAnyOrigin], expert_scales: TileTensor[DType.float32, GMEMLayout1D, MutAnyOrigin])

next

next(mut self) -> GroupedWorkContext1D1D

Fetch the next work tile and return a context holding its work info and the corresponding expert's output scale.

Returns:

GroupedWorkContext1D1D
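This page does not document how next walks the tiles. As a hedged sketch only, the Python model below assumes a persistent-kernel schedule with `swizzle=False` and no clustering: each CTA starts at `block_idx_start` and strides by a hypothetical `grid_size`, tiles are numbered dynamic-block-major across experts, and `current_dynamic_dim_cumsum` is read as the number of dynamic-dim blocks in groups before `current_group_idx`. `WorkContext`, `GroupedWorkIteratorModel` and `grid_size` are invented for illustration; only the field and method names come from this page.

```python
from dataclasses import dataclass


def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


@dataclass
class WorkContext:  # hypothetical stand-in for GroupedWorkContext1D1D
    group_idx: int
    expert_id: int
    scale: float
    m_start: int  # first token row of the tile, in a_offsets coordinates
    n_block: int  # tile index along the static N dimension
    valid: bool


class GroupedWorkIteratorModel:
    """Python model of a persistent grouped-tile iterator (not the Mojo implementation)."""

    def __init__(self, a_offsets, expert_ids, expert_scales, static_n,
                 block_m, block_n, block_idx_start, grid_size):
        self.a_offsets = a_offsets
        self.expert_ids = expert_ids
        self.expert_scales = expert_scales
        self.num_active_experts = len(expert_ids)
        self.num_static_dim_blocks = ceildiv(static_n, block_n)
        self.block_m = block_m
        self.block_idx_start = block_idx_start
        self.grid_size = grid_size  # hypothetical: number of persistent CTAs
        self.current_iter = 0
        self.current_group_idx = 0
        self.current_dynamic_dim_cumsum = 0  # dynamic blocks in earlier groups (assumed)

    def _dynamic_blocks(self, g):
        return ceildiv(self.a_offsets[g + 1] - self.a_offsets[g], self.block_m)

    def next(self):
        tile = self.block_idx_start + self.current_iter * self.grid_size
        self.current_iter += 1
        dyn_block, n_block = divmod(tile, self.num_static_dim_blocks)
        # Tile indices only increase, so the group search resumes where it stopped.
        while self.current_group_idx < self.num_active_experts:
            blocks = self._dynamic_blocks(self.current_group_idx)
            if dyn_block < self.current_dynamic_dim_cumsum + blocks:
                break
            self.current_dynamic_dim_cumsum += blocks
            self.current_group_idx += 1
        g = self.current_group_idx
        if g >= self.num_active_experts:
            return WorkContext(g, -1, 0.0, 0, 0, False)  # no work left
        local = dyn_block - self.current_dynamic_dim_cumsum
        return WorkContext(g, self.expert_ids[g], self.expert_scales[g],
                           self.a_offsets[g] + local * self.block_m, n_block, True)

    def current_expert_id(self):
        return self.expert_ids[self.current_group_idx]


it = GroupedWorkIteratorModel([0, 300, 300, 812], [7, 2, 4], [0.5, 1.0, 2.0],
                              static_n=512, block_m=256, block_n=256,
                              block_idx_start=1, grid_size=3)
work = []
while (ctx := it.next()).valid:
    work.append((ctx.expert_id, ctx.m_start, ctx.n_block))
print(work)  # [(7, 0, 1), (4, 300, 0), (4, 556, 1)]
```

In this model, caching `current_group_idx` and the running cumulative block count lets each call continue the group search from the previous position instead of rescanning the offsets from the start.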

current_expert_id

current_expert_id(self) -> Int32

Get the expert ID for the current group.

Returns:

Int32
