Mojo struct
GroupedWorkIterator1D1D
struct GroupedWorkIterator1D1D[static_N: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index(1, 1, 1), cta_group: Int = 1, swizzle: Bool = False]
Work iterator for 1D-1D grouped block-scaled matmul.
Iterates through work tiles using offset-based addressing:
- group_offsets: Prefix sums of token counts per active expert
- expert_ids: Mapping from active expert index to actual expert ID
- expert_scales: Per-expert output scaling factors
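The offset-based addressing above can be illustrated with a small host-side sketch. It assumes, hypothetically, that `group_offsets` stores `num_active_experts + 1` prefix-sum entries starting at 0, and that each expert's rows are split into M tiles of `block_m` rows. `num_m_tiles` is an illustrative helper, not part of this struct's API.

```mojo
from collections import List
from math import ceildiv


# Hypothetical helper (not part of the library): number of M-direction tiles
# owned by active-expert slot `slot`. Assumes `offsets` holds
# num_active_experts + 1 prefix-sum entries with offsets[0] == 0.
fn num_m_tiles(offsets: List[Int], slot: Int, block_m: Int) -> Int:
    var rows = offsets[slot + 1] - offsets[slot]  # tokens routed to this expert
    return ceildiv(rows, block_m)  # a partial last tile still needs a CTA


fn main():
    # Token counts 5, 0, 37 for three active experts -> prefix sums 0, 5, 5, 42.
    var offsets = List[Int]()
    offsets.append(0)
    offsets.append(5)
    offsets.append(5)
    offsets.append(42)

    var total = 0
    for slot in range(len(offsets) - 1):
        var tiles = num_m_tiles(offsets, slot, 16)
        print("slot", slot, "->", tiles, "M tiles")
        total += tiles
    print("total M tiles:", total)
```

With 16-row tiles, the three experts need 1, 0 and 3 M tiles (4 in total), so an expert that received no tokens contributes no work.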
Fields
- num_active_experts (Int)
- group_offsets (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].OffsetsTile)
- expert_ids (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].ExpertIdsTile)
- expert_scales (GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].ExpertScalesTile)
- current_iter (Int32)
- current_group_idx (UInt32)
- current_dynamic_dim_cumsum (UInt32)
- block_idx_start (UInt32)
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
cta_group_tile_shape
comptime cta_group_tile_shape = Index((tile_shape.__getitem__[Int](0) * cta_group), (tile_shape.__getitem__[Int](1) * cta_group))
div_dynamic_block
comptime div_dynamic_block = FastDiv(GroupedWorkIterator1D1D[static_N, tile_shape, cluster, cta_group, swizzle].cta_group_tile_shape.__getitem__[Int](0))
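As a worked example of `cta_group_tile_shape` and `div_dynamic_block`: with `tile_shape[0] = 128` and `cta_group = 2`, a CTA-group tile spans 256 rows. `FastDiv` presumably wraps a precomputed fast integer division by that constant; the sketch below uses ordinary `//` instead. The idea that the divisor maps a row position on the dynamic (token) dimension to a block index is an assumption, and both helpers are hypothetical.

```mojo
# Hypothetical helpers (not part of the library) spelling out the arithmetic.

# Mirrors cta_group_tile_shape[0] = tile_shape[0] * cta_group.
fn cta_group_block_m(tile_m: Int, cta_group: Int) -> Int:
    return tile_m * cta_group


# Plain integer division standing in for the precomputed FastDiv divisor.
# Assumption: the divisor turns a row position along the dynamic (token)
# dimension into a CTA-group block index.
fn row_to_block(row: Int, block_m: Int) -> Int:
    return row // block_m


fn main():
    var block_m = cta_group_block_m(128, 2)  # 256 rows per 2-CTA tile
    print(block_m)
    print(row_to_block(700, block_m))  # 700 // 256 == 2
```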
ExpertIdsTile
comptime ExpertIdsTile = TileTensor[DType.int32, GMEMLayout1D, MutAnyOrigin]
ExpertScalesTile
comptime ExpertScalesTile = TileTensor[DType.float32, GMEMLayout1D, MutAnyOrigin]
kNum1DBlocksPerGroup
comptime kNum1DBlocksPerGroup = 16
num_static_dim_blocks
comptime num_static_dim_blocks = SIMD(ceildiv(static_N, tile_shape.__getitem__[Int](1)))
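`num_static_dim_blocks` is a ceiling division of the static N extent by the N tile width, so a partially filled last tile still gets its own block. A minimal sketch with a hypothetical helper, using plain `Int` where the member itself is a `SIMD` value:

```mojo
from math import ceildiv


# Hypothetical helper mirroring num_static_dim_blocks = ceildiv(static_N, tile_shape[1]).
# Plain Int is used here; the comptime member wraps the result in a SIMD scalar.
fn num_static_dim_blocks(static_n: Int, tile_n: Int) -> Int:
    return ceildiv(static_n, tile_n)  # a partial last N tile still counts


fn main():
    print(num_static_dim_blocks(4096, 256))  # 16
    print(num_static_dim_blocks(4100, 256))  # 17
```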
OffsetsTile
comptime OffsetsTile = TileTensor[DType.uint32, GMEMLayout1D, MutAnyOrigin]
Methods
__init__
__init__(out self, num_active_experts: Int, group_offsets: TileTensor[DType.uint32, GMEMLayout1D, MutAnyOrigin], expert_ids: TileTensor[DType.int32, GMEMLayout1D, MutAnyOrigin], expert_scales: TileTensor[DType.float32, GMEMLayout1D, MutAnyOrigin])
next
next(mut self) -> GroupedWorkContext1D1D
Fetches the next work tile and returns a context holding the tile's work info and its expert scale.
Returns:
current_expert_id
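The fields `current_group_idx` and `block_idx_start` suggest that `next` walks the active experts in order while tracking where each expert's tiles begin. The sketch below shows one plausible version of that bookkeeping; the tile ordering, the `find_slot` helper and the sample `expert_ids` values are all assumptions for illustration, not the actual `next` implementation.

```mojo
from collections import List
from math import ceildiv


# Hypothetical sketch: map a linear tile index back to an active-expert slot.
# Assumptions: `offsets` holds num_active_experts + 1 prefix sums, and tiles
# are laid out expert by expert, each expert owning
# ceildiv(rows, block_m) * n_blocks tiles. Returns -1 once work is exhausted.
fn find_slot(offsets: List[Int], tile_idx: Int, block_m: Int, n_blocks: Int) -> Int:
    var block_idx_start = 0  # first linear tile index of the current slot
    for slot in range(len(offsets) - 1):
        var rows = offsets[slot + 1] - offsets[slot]
        var tiles = ceildiv(rows, block_m) * n_blocks
        if tile_idx < block_idx_start + tiles:
            return slot
        block_idx_start += tiles  # skip past this expert (empty experts add 0)
    return -1


fn main():
    # Token counts 5, 0, 37 -> prefix sums 0, 5, 5, 42.
    var offsets = List[Int]()
    offsets.append(0)
    offsets.append(5)
    offsets.append(5)
    offsets.append(42)
    # Active slot -> actual expert ID (illustrative values).
    var expert_ids = List[Int]()
    expert_ids.append(3)
    expert_ids.append(7)
    expert_ids.append(1)

    for t in range(9):
        var slot = find_slot(offsets, t, 16, 2)
        if slot < 0:
            print("tile", t, "-> no more work")
        else:
            var expert_id = expert_ids[slot]
            print("tile", t, "-> slot", slot, "expert", expert_id)
```

With `block_m = 16` and two N blocks, tiles 0-1 map to slot 0 (expert 3), tiles 2-7 map to slot 2 (expert 1), and tile 8 finds no remaining work; slot 1 is skipped because its expert received no tokens.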