Mojo struct
TileScheduler
@register_passable(trivial)
struct TileScheduler[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8, num_split_k: Int = 1]
Fields
- locks_ptr (LegacyUnsafePointer[Int32]):
- scheduler (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler):
- total_k_tiles (UInt32):
- k_tiles_per_split (UInt32):
- throttle_pipeline (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
BK
comptime BK = reduction_tile_shape.__getitem__[3, DType.int64, Int](2)
BM
comptime BM = reduction_tile_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = reduction_tile_shape.__getitem__[3, DType.int64, Int](1)
ROW_SIZE
comptime ROW_SIZE = reduction_tile_shape.__getitem__[3, DType.int64, Int](1) if (reduction_tile_shape.__getitem__[3, DType.int64, Int](0) == 128) else (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].MMA_N // 2)
ThrottlePipeline
comptime ThrottlePipeline = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottlePipeline
UnderlyingScheduler
comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]
Methods
__init__
__init__(cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], throttle_storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self
init_throttle_barriers
static init_throttle_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)
Initialize the throttle pipeline barriers. Called once by the elect_one thread.
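For illustration only, a minimal sketch of the intended call pattern, assuming a kernel prologue in which a single thread has been elected; everything other than init_throttle_barriers itself (the election flag, the storage pointer, and the arrival counts) is a hypothetical name, not part of this page:

```mojo
# Hypothetical kernel prologue; only init_throttle_barriers is from this page.
if elected:  # however the kernel elects its single initializer thread
    TileScheduler[num_stages, reduction_tile_shape, cluster_shape].init_throttle_barriers(
        throttle_storage_ptr,  # SharedMemBarrier storage in shared memory
        producer_arv_count,    # threads arriving as throttle producers
        consumer_arv_count,    # threads arriving as throttle consumers
    )
```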
convert_to_splitk_work_info
initial_work_info
advance_to_next_work
advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]
Returns:
PipelineState[num_stages]
fetch_next_work
fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo
Returns:
WorkInfo
advance_after_work
advance_after_work[work_origin: MutOrigin, state_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, ref [state_origin] consumer_state: PipelineState[num_stages]) -> AdvanceAfterWorkContextSplitK[work_origin, state_origin, num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
Context for warps that do work THEN advance (Load/Scheduler/Epilogue).
Usage:

    with scheduler.advance_after_work(work_info, state) as current:
        do_work(current)
        syncwarp()
    # After: work_info updated, state stepped
Returns:
AdvanceAfterWorkContextSplitK
prefetch_before_work
prefetch_before_work[work_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> PrefetchBeforeWorkContextSplitK[work_origin]
Context for MMA warp that prefetches BEFORE work (software pipelining).
Fetches next work and steps state IMMEDIATELY (before the with block).
Usage:

    with scheduler.prefetch_before_work(work_info, state) as current:
        do_mma(current)  # Uses current, not prefetched
    # After: work_info updated to prefetched value
Returns:
PrefetchBeforeWorkContextSplitK
work_iterator
work_iterator(self) -> WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
Create a per-warp work iterator that owns its work_info internally. The throttle pipeline is obtained from the scheduler. A usage sketch follows the Returns entry below.
Returns:
WorkIteratorSplitK
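As a rough illustration, a sketch of how a consumer warp might drive this iterator, assuming the returned WorkIteratorSplitK can be used in a `for` loop (that protocol support, and consume_tile, are assumptions rather than documented behavior):

```mojo
# Given `scheduler`, a TileScheduler constructed as in __init__ above.
# Hypothetical consumer-warp loop over the tiles assigned to this warp.
for work in scheduler.work_iterator():
    consume_tile(work)  # user-supplied per-tile body (illustrative)
```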
scheduler_iterator
scheduler_iterator(self) -> SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
Create an iterator for the Scheduler warp (owns work_info and both pipeline states). The throttle pipeline is obtained from the scheduler. A usage sketch follows the Returns entry below.
Returns:
SchedulerWorkIteratorSplitK
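Similarly, a sketch of the Scheduler-warp side, again assuming the returned iterator supports `for` iteration (an assumption; only scheduler_iterator() itself is documented here):

```mojo
# Given `scheduler`, a TileScheduler constructed as in __init__ above.
# Hypothetical Scheduler-warp loop: this warp only advances scheduling state;
# the tile bodies run on the consumer warps.
for _ in scheduler.scheduler_iterator():
    pass
```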
is_last_split
output_tile_index
store_to_workspace
store_to_workspace[accum_type: DType, workspace_layout: Layout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem_addr: UInt32, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], epilogue_thread_idx: UInt, reduction_tile_idx: UInt32)
reduction
reduction[accum_type: DType, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], tmem_addr: UInt32, epilogue_thread_idx: UInt, work_info: WorkInfo) -> Bool
Returns:
Bool
wait_eq
static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
wait_lt
static wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)
arrive_set
static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
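These three statics are spin-lock helpers over the locks_ptr workspace. The sketch below shows one plausible split-K pairing inferred purely from the names and signatures above; the actual semantics, and the barrier_id/lock_idx/val choices, are assumptions rather than documented behavior:

```mojo
# Hypothetical split-K handoff, inferred from the method names only.
comptime Sched = TileScheduler[num_stages, reduction_tile_shape, cluster_shape]

# A participant that has finished its partial tile signals the lock:
Sched.arrive_set(locks_ptr, barrier_id, barrier_group_thread_idx, lock_idx, val)

# The participant doing the final reduction spins until the lock equals val
# before it reads the shared workspace:
Sched.wait_eq(locks_ptr, barrier_id, barrier_group_thread_idx, lock_idx, val)
```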