
Mojo struct

WorkIteratorSplitK

@register_passable(trivial) struct WorkIteratorSplitK[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]

Per-warp work iterator for split-K that owns the work_info and pipeline state. The throttle pipeline is obtained from the scheduler.

Fields

  • scheduler (WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType):
  • work_info (WorkInfo):
  • consumer_state (PipelineState[num_stages]):
  • throttle_pipeline (WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

SchedulerType

comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]

ThrottlePipeline

comptime ThrottlePipeline = WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType.ThrottlePipeline

Methods

__init__

__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info: WorkInfo) -> Self

Create a work iterator. The throttle pipeline is taken from the scheduler.

has_work

has_work(self) -> Bool

Check if there is more work to process.

Returns:

Bool: True if there is more work to process.

next

next[state_origin: MutOrigin, //](ref [state_origin] self) -> AdvanceAfterWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]

Get the next work item using the advance-after-work pattern: the iterator advances after the work on the current item is done.

Returns:

AdvanceAfterWorkContextSplitK

next_prefetch

next_prefetch[state_origin: MutOrigin, //](ref [state_origin] self) -> PrefetchBeforeWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info)]

Get the next work item with prefetch using the advance-before-work pattern: the iterator advances before the work on the current item begins.

Returns:

PrefetchBeforeWorkContextSplitK
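The difference between `next` and `next_prefetch` is only in when the iterator advances relative to the work. The following is a minimal Python model of that ordering, not Mojo code and not the library's implementation. `ToyWorkIterator`, its `log` field, and the `run` helper are hypothetical names introduced for illustration, and modelling the returned contexts as `with` blocks is an assumption based on their names.

```python
from contextlib import contextmanager


class ToyWorkIterator:
    """Hypothetical stand-in for a work iterator; not the Mojo API."""

    def __init__(self, tiles):
        self._pending = list(tiles)      # tiles not yet handed out
        self.log = []                    # order of fetch/work events
        self.work_info = self._fetch()   # current work item, or None

    def _fetch(self):
        # Stand-in for asking the scheduler for the next tile.
        return self._pending.pop(0) if self._pending else None

    def has_work(self):
        return self.work_info is not None

    @contextmanager
    def next(self):
        # Advance AFTER work: fetch the next item once the body has finished.
        current = self.work_info
        yield current
        self.work_info = self._fetch()
        self.log.append(("fetch_after", current))

    @contextmanager
    def next_prefetch(self):
        # Advance BEFORE work: fetch the next item first, so the fetch could
        # overlap with the body; it becomes current when the body exits.
        current = self.work_info
        upcoming = self._fetch()
        self.log.append(("fetch_before", current))
        yield current
        self.work_info = upcoming


def run(tiles, prefetch):
    """Drive the toy iterator and return the recorded event order."""
    it = ToyWorkIterator(tiles)
    while it.has_work():
        ctx = it.next_prefetch() if prefetch else it.next()
        with ctx as tile:
            it.log.append(("work", tile))
    return it.log
```

Calling `run(["t0", "t1"], prefetch=False)` records each `work` event before its fetch, while `prefetch=True` records the fetch first. Both variants visit the same tiles in the same order.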

throttle_signal

throttle_signal(mut self, is_first_cta_in_cluster: Bool)

Signal the CLC throttle if this is the first CTA in the cluster.

Args:

  • is_first_cta_in_cluster (Bool): Only the first CTA in the cluster signals, to avoid duplicate signals.
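To see why the flag matters, here is a small Python model of a cluster in which every CTA calls `throttle_signal` but only one call has any effect. It is a sketch under stated assumptions, not the Mojo implementation: `ToyThrottle` and `signal_cluster` are hypothetical, a plain counter stands in for the throttle pipeline, and treating rank 0 as the first CTA is an assumption.

```python
class ToyThrottle:
    """Hypothetical shared throttle; a counter stands in for the pipeline."""

    def __init__(self):
        self.signals = 0

    def throttle_signal(self, is_first_cta_in_cluster):
        # Only the first CTA signals; the others do nothing.
        if is_first_cta_in_cluster:
            self.signals += 1


def signal_cluster(cluster_size):
    """Have every CTA in a toy cluster call throttle_signal once."""
    throttle = ToyThrottle()
    for cta_rank in range(cluster_size):
        # Assumption for this sketch: rank 0 is the first CTA in the cluster.
        throttle.throttle_signal(cta_rank == 0)
    return throttle.signals
```

In this model, `signal_cluster(4)` produces a single signal for the whole cluster rather than one per CTA.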
