IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

TileScheduler

struct TileScheduler[num_stages: Int, reduction_tile_shape: IndexList[Int(3)], cluster_shape: IndexList[Int(3), element_type=DType.uint32] = Index[Int, Int, Int, dtype=DType.uint32](Int(1), Int(1), Int(1)), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = Int(8), num_split_k: Int = Int(1)]

Fields

  • locks_ptr (UnsafePointer[Int32, MutAnyOrigin]):
  • scheduler (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler):
  • total_k_tiles (UInt32):
  • k_tiles_per_split (UInt32):
  • throttle_pipeline (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

BK​

comptime BK = reduction_tile_shape[Int(2)]

BM​

comptime BM = reduction_tile_shape[Int(0)]

ClcBarrierArray​

comptime ClcBarrierArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ClcBarrierArray

ClcResponseArray​

comptime ClcResponseArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ClcResponseArray

MMA_N​

comptime MMA_N = reduction_tile_shape[Int(1)]

ROW_SIZE​

comptime ROW_SIZE = reduction_tile_shape[Int(1)] if (reduction_tile_shape[Int(0)] == Int(128)) else (reduction_tile_shape[Int(1)] // Int(2))

ThrottleBarrierArray​

comptime ThrottleBarrierArray = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottleBarrierArray

ThrottlePipeline​

comptime ThrottlePipeline = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler.ThrottlePipeline

UnderlyingScheduler​

comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]

WorkspaceTileLayout​

comptime WorkspaceTileLayout = Layout[*?, *?]

Methods

__init__​

def __init__(cluster_dim: StaticTuple[Int32, Int(3)], mnk: StaticTuple[UInt32, Int(3)], clc_response: SMemArray[UInt128, num_stages], clc_full: SMemArray[SharedMemBarrier, num_stages], clc_empty: SMemArray[SharedMemBarrier, num_stages], clc_throttle: SMemArray[SharedMemBarrier, (num_stages * Int(2))], locks_ptr: UnsafePointer[UInt8, MutAnyOrigin]) -> Self

Initialize from typed barrier arrays.

init_throttle_barriers​

static def init_throttle_barriers(storage_ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)

Initialize the throttle pipeline barriers. Called once, by the elect_one thread.

convert_to_splitk_work_info​

def convert_to_splitk_work_info(self, work_info: WorkInfo) -> WorkInfo

Returns:

WorkInfo

initial_work_info​

def initial_work_info(self) -> WorkInfo

Returns:

WorkInfo

advance_to_next_work​

def advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]

Returns:

PipelineState[num_stages]

fetch_next_work​

def fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo

Returns:

WorkInfo

throttle_signal​

def throttle_signal(mut self, is_first_cta_in_cluster: Bool)

Signal the CLC throttle if this is the first CTA in the cluster.

Args:

  • is_first_cta_in_cluster (Bool): Only the first CTA in the cluster signals, to avoid duplicate signals.

wait_and_advance_work​

def wait_and_advance_work[work_origin: MutOrigin, //](self, ref[work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> WaitAndAdvanceContextSplitK[work_origin]

Wait for the next work item from CLC and advance (split-K).

Encapsulates the CLC barrier wait (called on scheduler directly).

Usage:

    with scheduler.wait_and_advance_work(work_info, state) as current:
        do_mma(current)
    # After the block exits, work_info has been updated to the next value.

Returns:

WaitAndAdvanceContextSplitK[work_origin]

work_iterator​

def work_iterator(self) -> WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]

Create a per-warp work iterator that owns work_info internally. The throttle pipeline is obtained from the scheduler.

Returns:

WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]

scheduler_iterator​

def scheduler_iterator(self) -> SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]

Create an iterator for the scheduler warp; the iterator owns work_info and both states. The throttle pipeline is obtained from the scheduler.

Returns:

SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]

is_last_split​

def is_last_split(self, work_tile_info: WorkInfo) -> Bool

Returns:

Bool

output_tile_index​

def output_tile_index(self, work_info: WorkInfo) -> UInt32

Returns:

UInt32

store_to_workspace​

def store_to_workspace[accum_type: DType, workspace_layout: TensorLayout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem: TmemAddress, reduction_workspace: TileTensor[accum_type, workspace_layout, MutAnyOrigin], epilogue_thread_idx: Int, reduction_tile_idx: UInt32)

reduction​

def reduction[accum_type: DType, workspace_layout: TensorLayout](self, reduction_workspace: TileTensor[accum_type, workspace_layout, MutAnyOrigin], tmem: TmemAddress, epilogue_thread_idx: Int, work_info: WorkInfo) -> Bool

Returns:

Bool

wait_eq​

static def wait_eq(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)

wait_lt​

static def wait_lt(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)

arrive_set​

static def arrive_set(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)