Mojo struct
TileScheduler
struct TileScheduler[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32] = Index[Int, Int, Int, dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8, num_split_k: Int = 1]
Fields

- locks_ptr (UnsafePointer[Int32, MutAnyOrigin])
- scheduler (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler)
- total_k_tiles (UInt32)
- k_tiles_per_split (UInt32)
- throttle_pipeline (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
BK
comptime BK = reduction_tile_shape[2]
BM
comptime BM = reduction_tile_shape[0]
ClcBarrierArray
comptime ClcBarrierArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ClcBarrierArray
ClcResponseArray
comptime ClcResponseArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ClcResponseArray
MMA_N
comptime MMA_N = reduction_tile_shape[1]
ROW_SIZE
comptime ROW_SIZE = reduction_tile_shape[1] if (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].BM == 128) else (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].MMA_N // 2)
ThrottleBarrierArray
comptime ThrottleBarrierArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottleBarrierArray
ThrottlePipeline
comptime ThrottlePipeline = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottlePipeline
UnderlyingScheduler
comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]
WorkspaceTileLayout
comptime WorkspaceTileLayout = Layout[*?, *?]
Methods
__init__
__init__(cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], clc_response: SMemArray[UInt128, num_stages], clc_full: SMemArray[SharedMemBarrier, num_stages], clc_empty: SMemArray[SharedMemBarrier, num_stages], clc_throttle: SMemArray[SharedMemBarrier, (num_stages * 2)], locks_ptr: UnsafePointer[UInt8, MutAnyOrigin]) -> Self
Initialize from typed barrier arrays.
init_throttle_barriers
static init_throttle_barriers(storage_ptr: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)
Initialize throttle pipeline barriers. Called once by elect_one thread.
convert_to_splitk_work_info
initial_work_info
advance_to_next_work
advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]
Returns:
fetch_next_work
fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo
Returns:
throttle_signal
throttle_signal(mut self, is_first_cta_in_cluster: Bool)
Signal CLC throttle if this is the first CTA in cluster.
Args:
- is_first_cta_in_cluster (Bool): Only first CTA signals to avoid duplicates.
wait_and_advance_work
wait_and_advance_work[work_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> WaitAndAdvanceContextSplitK[work_origin]
Wait for next work from CLC and advance (Split-K).
Encapsulates the CLC barrier wait (called on scheduler directly).
Usage: with scheduler.wait_and_advance_work(work_info, state) as current: do_mma(current) # After: work_info updated to next value
Returns:
WaitAndAdvanceContextSplitK[work_origin]
work_iterator
work_iterator(self) -> WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
Create a per-warp work iterator that owns work_info internally. Throttle pipeline is obtained from the scheduler.
Returns:
WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
scheduler_iterator
scheduler_iterator(self) -> SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
Create iterator for Scheduler warp (owns work_info and both states). Throttle pipeline is obtained from the scheduler.
Returns:
SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]
is_last_split
output_tile_index
store_to_workspace
store_to_workspace[accum_type: DType, workspace_layout: TensorLayout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem: TmemAddress, reduction_workspace: TileTensor[accum_type, workspace_layout, MutAnyOrigin], epilogue_thread_idx: Int, reduction_tile_idx: UInt32)
reduction
reduction[accum_type: DType, workspace_layout: TensorLayout](self, reduction_workspace: TileTensor[accum_type, workspace_layout, MutAnyOrigin], tmem: TmemAddress, epilogue_thread_idx: Int, work_info: WorkInfo) -> Bool
Returns:
wait_eq
static wait_eq(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
wait_lt
static wait_lt(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)
arrive_set
static arrive_set(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!