Mojo struct
WorkIteratorSplitK
@register_passable(trivial)
struct WorkIteratorSplitK[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]
Per-warp work iterator for split-K that owns its work_info and pipeline state. The throttle pipeline is obtained from the scheduler.
Fields
- scheduler (WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType)
- work_info (WorkInfo)
- consumer_state (PipelineState[num_stages])
- throttle_pipeline (WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
SchedulerType
comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
ThrottlePipeline
comptime ThrottlePipeline = WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType.ThrottlePipeline
Methods
__init__
__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info: WorkInfo) -> Self
Create a work iterator. The throttle pipeline is obtained from the scheduler.
has_work
next
next[state_origin: MutOrigin, //](ref [state_origin] self) -> AdvanceAfterWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
Get next work item (advance AFTER work pattern).
Returns:
AdvanceAfterWorkContextSplitK
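The "advance AFTER work" wording suggests that the returned context hands out the current work item and moves the iterator forward only once the caller's block has finished. The sketch below is a plain-Python model of that reading, not the Mojo API: `ToyWorkIterator`, its `_pending` queue, and `drain` are hypothetical names introduced for illustration, and the real context also tracks `consumer_state`, which is omitted here.

```python
from contextlib import contextmanager


class ToyWorkIterator:
    """Hypothetical stand-in that owns a current work item plus a queue of tiles."""

    def __init__(self, tiles):
        self._pending = list(tiles)
        # The iterator starts out already holding the first work item, if any.
        self.work_info = self._pending.pop(0) if self._pending else None

    def has_work(self):
        return self.work_info is not None

    @contextmanager
    def next(self):
        # Hand out the current item first ...
        yield self.work_info
        # ... and advance only after the caller's block has completed.
        self.work_info = self._pending.pop(0) if self._pending else None


def drain(tiles):
    """Process every tile in order and return the processing order."""
    work_iter = ToyWorkIterator(tiles)
    processed = []
    while work_iter.has_work():
        with work_iter.next() as current:
            processed.append(current)
    return processed
```

In this model the loop body always sees the item that was current when the context was entered, and the next item is fetched on exit.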
next_prefetch
next_prefetch[state_origin: MutOrigin, //](ref [state_origin] self) -> PrefetchBeforeWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info)]
Get next work item with prefetch (advance BEFORE work pattern).
Returns:
PrefetchBeforeWorkContextSplitK
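One way to read "advance BEFORE work" is that the following work item is fetched when the context is entered, so the fetch overlaps with the current item's work, and the prefetched item becomes current on exit. The plain-Python sketch below models that ordering under this assumption. `ToyPrefetchIterator`, `_fetch`, `log`, and `run` are hypothetical and are not part of the Mojo API.

```python
from contextlib import contextmanager


class ToyPrefetchIterator:
    """Hypothetical stand-in that records the order of fetches and work."""

    def __init__(self, tiles):
        self._pending = list(tiles)
        self.work_info = self._pending.pop(0) if self._pending else None
        self.log = []

    def has_work(self):
        return self.work_info is not None

    def _fetch(self):
        # Take the following item (or None when the queue is empty) and log it.
        item = self._pending.pop(0) if self._pending else None
        self.log.append(("fetch", item))
        return item

    @contextmanager
    def next_prefetch(self):
        upcoming = self._fetch()   # the fetch is issued BEFORE the body runs
        yield self.work_info       # the body works on the current item
        self.work_info = upcoming  # the prefetched item becomes current on exit


def run(tiles):
    """Drive the iterator and return the recorded event order."""
    work_iter = ToyPrefetchIterator(tiles)
    while work_iter.has_work():
        with work_iter.next_prefetch() as current:
            work_iter.log.append(("work", current))
    return work_iter.log
```

For two tiles the recorded order is fetch, work, fetch, work: each fetch precedes the work it overlaps with, which is the reverse of the advance-after-work ordering.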
throttle_signal
throttle_signal(mut self, is_first_cta_in_cluster: Bool)
Signal the CLC throttle if this is the first CTA in the cluster.
Args:
- is_first_cta_in_cluster (Bool): Only the first CTA in the cluster signals, which avoids duplicate signals.
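The guard described above can be modeled in a few lines of plain Python. This is a conceptual sketch rather than the Mojo implementation: `ToyThrottle`, the free function `throttle_signal`, and `simulate_cluster` are hypothetical, and treating rank 0 as the first CTA is an assumption of the model.

```python
class ToyThrottle:
    """Hypothetical throttle that only counts how often it was signaled."""

    def __init__(self):
        self.signals = 0

    def signal(self):
        self.signals += 1


def throttle_signal(throttle, is_first_cta_in_cluster):
    # Every CTA makes the call, but only the first CTA in the cluster signals.
    if is_first_cta_in_cluster:
        throttle.signal()


def simulate_cluster(num_ctas):
    """Let each CTA in one cluster call throttle_signal once; return the signal count."""
    throttle = ToyThrottle()
    for cta_rank in range(num_ctas):
        throttle_signal(throttle, is_first_cta_in_cluster=(cta_rank == 0))
    return throttle.signals
```

However many CTAs the cluster contains, the throttle is signaled exactly once per round in this model.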