
Mojo struct

TileScheduler

struct TileScheduler[num_stages: Int, cluster_shape: IndexList[Int(3), element_type=DType.uint32] = Index[Int, Int, Int, dtype=DType.uint32](Int(1), Int(1), Int(1)), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = Int(8)]

Fields

  • cluster_dim (StaticTuple[Int32, Int(3)])
  • log_cluster_dim_m (FastDiv[DType.uint32])
  • log_cluster_dim_n (FastDiv[DType.uint32])
  • log_cluster_dim_k (FastDiv[DType.uint32])
  • clc_response (UnsafePointer[UInt128, MutUntrackedOrigin, address_space=AddressSpace.SHARED])
  • full_mbar (UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED])
  • empty_mbar (UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED])
  • throttle_pipeline (TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

ClcBarrierArray

comptime ClcBarrierArray = SMemArray[SharedMemBarrier, num_stages]

ClcResponseArray

comptime ClcResponseArray = SMemArray[UInt128, num_stages]

cluster_size

comptime cluster_size = (Int((mul cluster_shape[Int(0)], cluster_shape[Int(1)])) * cluster_shape[Int(2)])
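
The generated expression above is the product of the three cluster_shape entries. As a minimal host-side sketch in Python (a model for illustration only, not the Mojo API; the function name is hypothetical):

```python
from math import prod


def cluster_size(cluster_shape):
    """Number of CTAs in one cluster: the product of the three extents."""
    if len(cluster_shape) != 3:
        raise ValueError("cluster_shape must have exactly three entries")
    return prod(cluster_shape)


print(cluster_size((1, 1, 1)))  # default cluster_shape -> 1
print(cluster_size((2, 1, 1)))  # -> 2
```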

log_cluster_k

comptime log_cluster_k = FastDiv(cluster_shape[Int(2)])

log_cluster_m

comptime log_cluster_m = FastDiv(cluster_shape[Int(0)])

log_cluster_n

comptime log_cluster_n = FastDiv(cluster_shape[Int(1)])
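
The three log_cluster_* members wrap a fixed cluster extent in FastDiv so that later divisions by that extent can avoid a hardware divide. This page does not state how FastDiv computes its constants. The Python sketch below assumes the classic multiply-and-shift scheme for 32-bit unsigned division by an invariant divisor; the class name FastDivModel and its fields are hypothetical.

```python
class FastDivModel:
    """Hypothetical model of division by a fixed 32-bit unsigned divisor."""

    def __init__(self, divisor):
        if not 1 <= divisor < 2**32:
            raise ValueError("divisor must fit in an unsigned 32-bit integer")
        self.divisor = divisor
        log2_ceil = (divisor - 1).bit_length()  # ceil(log2(divisor))
        # Magic multiplier, computed once per divisor (assumed scheme).
        self.magic = ((2**32) * ((1 << log2_ceil) - divisor)) // divisor + 1
        self.shift1 = min(log2_ceil, 1)
        self.shift2 = max(log2_ceil - 1, 0)

    def div(self, n):
        """Return n // divisor for 0 <= n < 2**32 using a multiply and shifts."""
        t = (self.magic * n) >> 32  # high 32 bits of the 64-bit product
        return (t + ((n - t) >> self.shift1)) >> self.shift2


print(FastDivModel(3).div(10))  # -> 3
```

The constants depend only on the divisor, which is why they can be derived from cluster_shape at compile time.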

ThrottleBarrierArray

comptime ThrottleBarrierArray = SMemArray[SharedMemBarrier, (num_stages * Int(2))]

ThrottlePipeline

comptime ThrottlePipeline = ProducerConsumerPipeline[num_stages]

Methods

__init__

def __init__(cluster_dim: StaticTuple[Int32, Int(3)], clc_response: SMemArray[UInt128, num_stages], clc_full: SMemArray[SharedMemBarrier, num_stages], clc_empty: SMemArray[SharedMemBarrier, num_stages], clc_throttle: SMemArray[SharedMemBarrier, (num_stages * Int(2))]) -> Self

Initialize from typed barrier arrays.

init_throttle_barriers

static def init_throttle_barriers(storage_ptr: UnsafePointer[SharedMemBarrier, MutUntrackedOrigin, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)

Initialize the throttle pipeline barriers. Called once, by the elect_one thread.
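
ThrottleBarrierArray holds num_stages * 2 barriers, and this method takes separate producer and consumer arrival counts. One plausible reading, assumed here and not confirmed by this page, is that the first num_stages barriers are "full" barriers armed with the producer count and the remaining num_stages are "empty" barriers armed with the consumer count. A host-side Python sketch of that assumed layout, with hypothetical names throughout:

```python
from dataclasses import dataclass


@dataclass
class BarrierModel:
    """Hypothetical stand-in for one shared-memory barrier."""
    expected_arrivals: int


def init_throttle_barriers_model(num_stages, producer_arv_count, consumer_arv_count):
    # Assumed layout: [full_0 .. full_{n-1}, empty_0 .. empty_{n-1}].
    full = [BarrierModel(producer_arv_count) for _ in range(num_stages)]
    empty = [BarrierModel(consumer_arv_count) for _ in range(num_stages)]
    return full + empty


storage = init_throttle_barriers_model(
    num_stages=2, producer_arv_count=1, consumer_arv_count=4
)
print(len(storage))  # -> 4, that is num_stages * 2
```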

work_info_from_clc_response

static def work_info_from_clc_response(result: UnsafePointer[UInt128, MutUntrackedOrigin, address_space=AddressSpace.SHARED]) -> WorkInfo

Returns:

WorkInfo

work_info_from_cluster

static def work_info_from_cluster(work_info: WorkInfo, cluster_dim: StaticTuple[Int32, Int(3)], log_cluster_dim_m: FastDiv[DType.uint32], log_cluster_dim_n: FastDiv[DType.uint32]) -> WorkInfo

Returns:

WorkInfo

initial_work_info

def initial_work_info(self) -> WorkInfo

Returns:

WorkInfo

fetch_next_work

def fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo

Returns:

WorkInfo

throttle_signal

def throttle_signal(mut self, is_first_cta_in_cluster: Bool)

Signal the CLC throttle if this is the first CTA in the cluster.

The Load warp acts as the producer for the CLC throttle, signaling that it has started processing a new work item. This prevents the scheduler from getting too far ahead.

Args:

  • is_first_cta_in_cluster (Bool): Only the first CTA signals, to avoid duplicate signals.
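
The throttle handshake described above can be pictured with a small host-side Python model. It is only a sketch: the class ThrottleModel and the method scheduler_may_fetch are hypothetical, and real warps block on shared-memory barriers rather than returning False or raising.

```python
class ThrottleModel:
    """Hypothetical host-side model of the CLC throttle handshake."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.signals = 0  # posted by the Load warp, not yet consumed

    def throttle_signal(self, is_first_cta_in_cluster):
        """Producer side: only the first CTA in the cluster posts a signal."""
        if not is_first_cta_in_cluster:
            return False
        if self.signals == self.num_stages:
            raise RuntimeError("all stages in flight; a real warp would wait")
        self.signals += 1
        return True

    def scheduler_may_fetch(self):
        """Consumer side: the scheduler asks for new work only after a signal."""
        if self.signals == 0:
            return False
        self.signals -= 1
        return True


throttle = ThrottleModel(num_stages=2)
throttle.throttle_signal(is_first_cta_in_cluster=False)  # ignored, no signal posted
print(throttle.scheduler_may_fetch())  # -> False
throttle.throttle_signal(is_first_cta_in_cluster=True)
print(throttle.scheduler_may_fetch())  # -> True
```

In this model the scheduler can never request more work items than the Load warp has started, which is the "not too far ahead" property.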

wait_and_advance_work

def wait_and_advance_work[work_origin: MutOrigin, //](self, ref[work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> WaitAndAdvanceContext[work_origin]

Wait for next work from CLC and advance.

Encapsulates the CLC barrier wait (called on scheduler directly).

Usage:

    with scheduler.wait_and_advance_work(work_info, state) as current:
        do_mma(current)
    # After: work_info updated to next value

Returns:

WaitAndAdvanceContext[work_origin]
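
The context-manager pattern in the usage note can be modeled on the host in Python. This is a sketch under assumptions: SchedulerModel is hypothetical, a one-element list stands in for the mutable work_info reference, and the point at which the real implementation waits on the CLC barrier is not stated on this page.

```python
from contextlib import contextmanager


class SchedulerModel:
    """Hypothetical stand-in that hands out work items from a fixed list."""

    def __init__(self, items):
        self._items = list(items)
        self._next = 0

    def fetch_next(self):
        """Return the next work item, or None when nothing is left."""
        if self._next >= len(self._items):
            return None
        item = self._items[self._next]
        self._next += 1
        return item

    @contextmanager
    def wait_and_advance_work(self, work_ref):
        """Yield the current item; store the next one in work_ref[0] on exit."""
        current = work_ref[0]
        next_item = self.fetch_next()  # models the CLC barrier wait (assumed timing)
        try:
            yield current
        finally:
            work_ref[0] = next_item


scheduler = SchedulerModel(["tile1", "tile2"])
work_info = ["tile0"]  # one-element list used as a mutable reference
seen = []
while work_info[0] is not None:
    with scheduler.wait_and_advance_work(work_info) as current:
        seen.append(current)
print(seen)  # -> ['tile0', 'tile1', 'tile2']
```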

work_iterator

def work_iterator(self) -> WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]

Create a per-warp work iterator using next-style iteration.

Each warp should create its own work iterator. The iterator owns work_info and pipeline state internally.

Usage:

    var work_iter = scheduler.work_iterator()
    for current in work_iter:
        scheduler.throttle_signal(ctx.is_first_cta_in_cluster)
        do_work(current)

Returns:

WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]
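
To illustrate why each warp creates its own iterator, the Python sketch below models an iterator that owns its work queue and stage index. WorkIteratorModel and its fields are hypothetical, and a plain list replaces the CLC responses the real iterator reads.

```python
class WorkIteratorModel:
    """Hypothetical per-warp iterator that owns its work items and stage index."""

    def __init__(self, work_items, num_stages):
        self._pending = list(work_items)  # private copy, not shared between warps
        self._num_stages = num_stages
        self.stage = 0

    def __iter__(self):
        return self

    def __next__(self):
        if not self._pending:
            raise StopIteration
        current = self._pending.pop(0)
        self.stage = (self.stage + 1) % self._num_stages
        return current


tiles = [(0, 0), (0, 1), (1, 0)]
load_iter = WorkIteratorModel(tiles, num_stages=2)
mma_iter = WorkIteratorModel(tiles, num_stages=2)
load_seen = [tile for tile in load_iter]
# Advancing load_iter leaves mma_iter untouched.
print(load_seen, load_iter.stage, mma_iter.stage)
```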

scheduler_iterator

def scheduler_iterator(self) -> SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]

Create an iterator for the Scheduler warp. The iterator owns work_info and both pipeline states.

The Scheduler warp uniquely needs to both consume work responses and produce new work requests. This iterator owns everything internally.

Usage:

    var sched_iter = scheduler.scheduler_iterator()
    for _ in sched_iter:
        sched_iter.signal_and_advance()
    sched_iter.drain()

Returns:

SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]
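
The loop-then-drain shape of the usage note can be modeled loosely in Python. Everything here is an assumption made for illustration: the class SchedulerIteratorModel, its counters, and the idea that drain retires requests still outstanding when the loop ends are not taken from this page.

```python
class SchedulerIteratorModel:
    """Hypothetical model of a warp that consumes responses and issues requests."""

    def __init__(self, total_tiles, num_stages):
        self.num_stages = num_stages
        self.remaining = total_tiles  # work items still to be handed out
        self.in_flight = 0            # requests issued, responses not yet consumed
        self.consumer_stage = 0
        self.producer_stage = 0
        self.requests = []            # producer stage used by each request

    def __iter__(self):
        return self

    def __next__(self):
        if self.remaining == 0:
            raise StopIteration
        if self.in_flight > 0:
            # Consume the response to an earlier request.
            self.in_flight -= 1
            self.consumer_stage = (self.consumer_stage + 1) % self.num_stages
        self.remaining -= 1
        return None

    def signal_and_advance(self):
        """Issue one new work request and move to the next producer stage."""
        if self.in_flight == self.num_stages:
            raise RuntimeError("no free stage; a real warp would wait here")
        self.requests.append(self.producer_stage)
        self.in_flight += 1
        self.producer_stage = (self.producer_stage + 1) % self.num_stages

    def drain(self):
        """Retire requests that are still outstanding; return how many there were."""
        drained = self.in_flight
        self.in_flight = 0
        return drained


sched_iter = SchedulerIteratorModel(total_tiles=3, num_stages=2)
for _ in sched_iter:
    sched_iter.signal_and_advance()
drained = sched_iter.drain()
print(sched_iter.requests, drained)
```

The model keeps separate consumer and producer stage counters because the Scheduler warp plays both roles.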

advance_to_next_work

def advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]

Returns:

PipelineState[num_stages]
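
This page gives only the signature of advance_to_next_work. A common convention for multi-stage pipeline state, assumed here and not confirmed for PipelineState, is a stage index that wraps modulo num_stages together with a phase bit that flips on each wrap. A Python sketch with hypothetical names:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PipelineStateModel:
    """Hypothetical model: stage index plus a phase bit that flips on wrap-around."""
    num_stages: int
    index: int = 0
    phase: int = 0

    def step(self):
        """Return the state for the next stage."""
        next_index = self.index + 1
        if next_index == self.num_stages:
            return PipelineStateModel(self.num_stages, 0, self.phase ^ 1)
        return PipelineStateModel(self.num_stages, next_index, self.phase)


state = PipelineStateModel(num_stages=2)
history = []
for _ in range(5):
    history.append((state.index, state.phase))
    state = state.step()
print(history)  # -> [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0)]
```

The phase bit lets a waiter tell a barrier completion in the current pass over the stages from one left over from the previous pass.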