Mojo struct
GroupedWorkIterator1D1D
struct GroupedWorkIterator1D1D[static_N: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index[Int, Int, Int](1, 1, 1), cta_group: Int = 1, swizzle: Bool = False, AB_swapped: Bool = False]
Work iterator for 1D-1D grouped block-scaled matmul.
Iterates through work tiles using offset-based addressing:
- group_offsets: Prefix sum of token counts per active expert
- expert_ids: Mapping from active expert index to actual expert ID
- expert_scales: Per-expert output scaling factors
Yields only valid work tiles, skipping invalid ones internally.
Usage: for ctx in work_iter: process_tile(ctx)
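To make the offset-based addressing concrete, here is a hypothetical Python model (not the Mojo API) of which tiles a full pass would cover. It assumes group_offsets holds num_active_experts + 1 entries starting at 0, so rows group_offsets[i] up to group_offsets[i + 1] of A belong to active expert i. WorkTile, work_tiles, tile_m and tile_n are illustrative names. The model ignores clusters, CTA groups, swizzling and per-CTA scheduling.

```python
from typing import Iterator, List, NamedTuple


class WorkTile(NamedTuple):
    """Hypothetical stand-in for the context a real iterator would yield."""
    expert_id: int  # actual expert ID, looked up through expert_ids
    row_start: int  # first token row of the tile in A
    row_end: int    # one past the last token row covered by the tile
    n_block: int    # tile index along the static N dimension


def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def work_tiles(group_offsets: List[int], expert_ids: List[int],
               static_n: int, tile_m: int, tile_n: int) -> Iterator[WorkTile]:
    """Yield every non-empty (expert, row block, column block) tile."""
    num_n_blocks = ceildiv(static_n, tile_n)
    for i, expert_id in enumerate(expert_ids):
        start, end = group_offsets[i], group_offsets[i + 1]
        # An expert with zero tokens has start == end, so the range below is
        # empty and that expert contributes no tiles.
        for row in range(start, end, tile_m):
            for n_block in range(num_n_blocks):
                yield WorkTile(expert_id, row, min(row + tile_m, end), n_block)


# Active experts 5, 2 and 7 own 3, 0 and 7 tokens respectively.
tiles = list(work_tiles([0, 3, 3, 10], [5, 2, 7], static_n=256, tile_m=4, tile_n=128))
print(len(tiles))  # 6
```

The empty expert never produces a tile, which is the behavior the description promises for invalid work.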
Fields
- num_active_experts (Int)
- group_offsets (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].OffsetsTile)
- expert_ids (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].ExpertIdsTile)
- expert_scales (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].ExpertScalesTile)
- current_iter (Int32)
- current_group_idx (UInt32)
- current_dynamic_dim_cumsum (UInt32)
- block_idx_start (UInt32)
Implemented traits
AnyType, Copyable, ImplicitlyDestructible, Iterable, Iterator, Movable
comptime members
cta_group_tile_shape
comptime cta_group_tile_shape = Index[Int, Int]((tile_shape[1] * cta_group), (tile_shape[0] * cta_group)) if AB_swapped else Index[Int, Int](tile_shape[0], (tile_shape[1] * cta_group))
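A direct Python transcription of the comptime expression above can help when checking tile sizes by hand. It only assumes that tile_shape can be indexed like a tuple. The function name is hypothetical.

```python
def cta_group_tile_shape(tile_shape, cta_group=1, ab_swapped=False):
    """Transcription of the cta_group_tile_shape comptime expression."""
    if ab_swapped:
        # Swapped: entries 1 and 0 are exchanged and both are scaled.
        return (tile_shape[1] * cta_group, tile_shape[0] * cta_group)
    # Default: entry 0 is kept and only entry 1 is scaled by cta_group.
    return (tile_shape[0], tile_shape[1] * cta_group)


print(cta_group_tile_shape((128, 64, 128)))                                # (128, 64)
print(cta_group_tile_shape((128, 64, 128), cta_group=2))                   # (128, 128)
print(cta_group_tile_shape((128, 64, 128), cta_group=2, ab_swapped=True))  # (128, 256)
```

As written, the AB_swapped branch scales both extents by cta_group, while the default branch scales only entry 1.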
div_dynamic_block
comptime div_dynamic_block = FastDiv(GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].cta_group_tile_shape[0])
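FastDiv is constructed once from cta_group_tile_shape[0]. That pattern matches the usual trick for dividing many values by the same divisor: precompute a magic multiplier and a shift so that each division becomes a multiply and a shift. The Python sketch below shows that general technique (Granlund and Montgomery's round-up method for 32-bit unsigned dividends). It is not a transcription of FastDiv, whose internals are not described here, and InvariantDivisor is a hypothetical name.

```python
class InvariantDivisor:
    """Divide unsigned 32-bit integers by a fixed divisor with multiply + shift."""

    N = 32  # bit width of the dividends

    def __init__(self, d: int):
        if d <= 0:
            raise ValueError("divisor must be positive")
        self.d = d
        # shift = N + ceil(log2(d)); magic = ceil(2**shift / d).
        self.shift = self.N + (d - 1).bit_length()
        self.magic = -(-(1 << self.shift) // d)

    def div(self, n: int) -> int:
        """Return n // d for 0 <= n < 2**N using only a multiply and a shift."""
        if not 0 <= n < (1 << self.N):
            raise ValueError("dividend out of range")
        # Python ints are arbitrary precision; GPU code typically builds this
        # product from a high-word multiply instead.
        return (n * self.magic) >> self.shift


block_div = InvariantDivisor(128)
print(block_div.div(1000))  # 7
```

The precomputation cost is paid once per divisor, which is why this pays off when the divisor is fixed at construction time, as it is here.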
Element
comptime Element = GroupedWorkContext1D1D
ExpertIdsTile
comptime ExpertIdsTile = TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin]
ExpertScalesTile
comptime ExpertScalesTile = TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin]
IteratorType
comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[mut=iterable_mut]] = GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped]
kNum1DBlocksPerGroup
comptime kNum1DBlocksPerGroup = SIMD(16)
num_static_dim_blocks
comptime num_static_dim_blocks = SIMD(ceildiv(static_N, GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle, AB_swapped].cta_group_tile_shape[1]))
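As a worked example of num_static_dim_blocks: the static extent static_N is split into tiles of width cta_group_tile_shape[1], rounding up, so the last block may be partial. The sizes below are illustrative and not taken from any particular configuration.

```python
def ceildiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def num_static_dim_blocks(static_n: int, tile_extent: int) -> int:
    """Blocks along the static dimension; tile_extent plays cta_group_tile_shape[1]."""
    return ceildiv(static_n, tile_extent)


print(num_static_dim_blocks(7168, 256))  # 28: divides evenly
print(num_static_dim_blocks(4096, 192))  # 22: 21 full blocks plus a 64-wide tail
```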
OffsetsTile
comptime OffsetsTile = TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin]
Methods
__init__
__init__(out self, num_active_experts: Int, group_offsets: TileTensor[DType.uint32, Layout[*?, *?], MutAnyOrigin], expert_ids: TileTensor[DType.int32, Layout[*?, *?], MutAnyOrigin], expert_scales: TileTensor[DType.float32, Layout[*?, *?], MutAnyOrigin])
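The constructor takes the per-expert metadata as tensors. The Python sketch below shows one plausible way to derive those values on the host from a hypothetical token-count table. It assumes that experts with zero tokens are excluded from the active set, that group_offsets starts at 0 and has num_active_experts + 1 entries, and that expert_scales is ordered by active-expert position. None of these conventions is stated above, so treat them as assumptions to check against the calling kernel.

```python
from itertools import accumulate
from typing import Dict, List, Tuple


def build_grouped_inputs(
    tokens_per_expert: Dict[int, int], scale_of: Dict[int, float]
) -> Tuple[int, List[int], List[int], List[float]]:
    """Return (num_active_experts, group_offsets, expert_ids, expert_scales)."""
    # Assumption: an expert is "active" only if it received at least one token.
    active = [(e, n) for e, n in sorted(tokens_per_expert.items()) if n > 0]
    expert_ids = [e for e, _ in active]
    # Prefix sum with a leading 0: rows group_offsets[i] up to
    # group_offsets[i + 1] of A belong to active expert i.
    group_offsets = [0, *accumulate(n for _, n in active)]
    # Assumption: scales are stored in active-expert order.
    expert_scales = [scale_of[e] for e in expert_ids]
    return len(active), group_offsets, expert_ids, expert_scales


print(build_grouped_inputs({0: 3, 1: 0, 4: 5, 2: 2}, {0: 1.0, 1: 0.5, 2: 2.0, 4: 0.25}))
# (3, [0, 3, 5, 10], [0, 2, 4], [1.0, 2.0, 0.25])
```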
__iter__
__iter__(ref self) -> Self
__next__
__next__(mut self) -> GroupedWorkContext1D1D
Return next valid work tile, skipping invalid ones.
Returns:
GroupedWorkContext1D1D
Raises:
StopIteration: When all work is done.
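The following hypothetical Python model illustrates the protocol that __next__ follows: keep advancing over candidate slots, return the first one that holds real work, and raise StopIteration once none remain. The slot layout (a fixed max_m_blocks row blocks per expert, shared by `stride` workers starting at `start`) is an assumption made for illustration and is not the struct's actual scheduling. Only the name current_iter is borrowed from the field list.

```python
from typing import List, Tuple


class SkippingWorkIterator:
    """Hypothetical model: yields (group, row_start, row_end) for valid slots only."""

    def __init__(self, group_offsets: List[int], tile_m: int,
                 max_m_blocks: int, start: int, stride: int):
        self.group_offsets = group_offsets
        self.tile_m = tile_m
        self.max_m_blocks = max_m_blocks          # worst-case row blocks per expert
        self.num_groups = len(group_offsets) - 1
        self.current_iter = start                 # next candidate slot to examine
        self.stride = stride                      # number of workers sharing slots

    def __iter__(self) -> "SkippingWorkIterator":
        return self

    def __next__(self) -> Tuple[int, int, int]:
        total_slots = self.num_groups * self.max_m_blocks
        while self.current_iter < total_slots:
            group, m_block = divmod(self.current_iter, self.max_m_blocks)
            self.current_iter += self.stride
            row_start = self.group_offsets[group] + m_block * self.tile_m
            row_end = self.group_offsets[group + 1]
            if row_start < row_end:
                # The slot overlaps this expert's tokens: real work.
                return group, row_start, min(row_start + self.tile_m, row_end)
            # Otherwise the slot is padding beyond the expert's tokens; skip it.
        raise StopIteration


offsets = [0, 5, 5, 9]  # groups own 5, 0 and 4 tokens
print(list(SkippingWorkIterator(offsets, tile_m=4, max_m_blocks=2, start=0, stride=1)))
# [(0, 0, 4), (0, 4, 5), (2, 5, 9)]
```

Because the skipping happens inside __next__, a caller's `for ctx in work_iter` loop only ever sees valid tiles.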
next
next(mut self) -> GroupedWorkContext1D1D
Fetch next work tile and return context with work info and scale.
Returns:
current_expert_id