Mojo struct

TileScheduler

@register_passable(trivial) struct TileScheduler[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8]
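
As a rough orientation, a kernel would typically bind these parameters once at compile time. The sketch below is illustrative only: the alias name Scheduler and the parameter values are assumptions, not defaults taken from any real kernel, and it uses the comptime spelling this page uses for aliases.

    comptime Scheduler = TileScheduler[
        num_stages=2,
        cluster_shape=Index[dtype=DType.uint32](2, 1, 1),
        rasterize_order=RasterOrder.AlongM,
        block_swizzle_size=8,
    ]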

Fields

  • cluster_dim (StaticTuple[Int32, 3]):
  • log_cluster_dim_m (FastDiv[DType.uint32]):
  • log_cluster_dim_n (FastDiv[DType.uint32]):
  • log_cluster_dim_k (FastDiv[DType.uint32]):
  • clc_response (LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]):
  • full_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]):
  • empty_mbar (LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]):
  • throttle_pipeline (TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

cluster_size

comptime cluster_size = cluster_shape[0] * cluster_shape[1] * cluster_shape[2]

log_cluster_k

comptime log_cluster_k = FastDiv[DType.uint32](cluster_shape[2])

log_cluster_m

comptime log_cluster_m = FastDiv[DType.uint32](cluster_shape[0])

log_cluster_n

comptime log_cluster_n = FastDiv[DType.uint32](cluster_shape[1])

ThrottlePipeline

comptime ThrottlePipeline = ProducerConsumerPipeline[num_stages]

Methods

__init__

__init__(cluster_dim: StaticTuple[Int32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], throttle_storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self
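
A minimal construction sketch, assuming the kernel has already carved the required slots out of shared memory; the Scheduler alias comes from the sketch above and every pointer variable here is hypothetical:

    var scheduler = Scheduler(
        cluster_dim,           # StaticTuple[Int32, 3]: grid extent in clusters
        clc_response_ptr,      # shared-memory UInt128 response slots
        full_mbar_ptr,         # "full" (response ready) barriers
        empty_mbar_ptr,        # "empty" (slot reusable) barriers
        throttle_storage_ptr,  # storage backing the ThrottlePipeline barriers
    )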

init_throttle_barriers

static init_throttle_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)

Initialize the throttle pipeline barriers. Called once by the elect_one thread.

Args:

  • storage_ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage.
  • producer_arv_count (Int32): Expected arrival count for producer barriers.
  • consumer_arv_count (Int32): Expected arrival count for consumer barriers.
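
A hedged sketch of the call site; the election predicate and the arrival counts of 1 are placeholders, not values prescribed by this API:

    if is_elected_thread:  # hypothetical predicate: true on exactly one thread
        Scheduler.init_throttle_barriers(
            throttle_storage_ptr,  # presumably the same storage later passed to __init__
            producer_arv_count=1,
            consumer_arv_count=1,
        )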

work_info_from_clc_response

static work_info_from_clc_response(result: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]) -> WorkInfo

Returns:

WorkInfo

work_info_from_cluster

static work_info_from_cluster(work_info: WorkInfo, cluster_dim: StaticTuple[Int32, 3], log_cluster_dim_m: FastDiv[DType.uint32], log_cluster_dim_n: FastDiv[DType.uint32]) -> WorkInfo

Returns:

WorkInfo

initial_work_info

initial_work_info(self) -> WorkInfo

Returns:

WorkInfo

fetch_next_work

fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo

Returns:

WorkInfo
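
Together with initial_work_info, a consumer warp could in principle drive the scheduler by hand. The loop below is an illustrative sketch only: is_valid() and step() are assumed helpers on WorkInfo and PipelineState, and do_work stands in for the kernel's per-tile work. The iterators further down package this pattern more safely.

    var consumer_state = PipelineState[num_stages]()    # assumed default init
    var work_info = scheduler.initial_work_info()
    while work_info.is_valid():                         # is_valid(): assumed validity check
        do_work(work_info)                              # hypothetical per-tile work
        work_info = scheduler.fetch_next_work(work_info, consumer_state)
        consumer_state.step()                           # step(): assumed state advance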

advance_after_work

advance_after_work[work_origin: MutOrigin, state_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, ref [state_origin] consumer_state: PipelineState[num_stages]) -> AdvanceAfterWorkContext[work_origin, state_origin, num_stages, cluster_shape, rasterize_order, block_swizzle_size]

Context for warps that do work THEN advance (Load/Scheduler/Epilogue).

Usage:

    with scheduler.advance_after_work(work_info, state) as current:
        do_work(current)
        syncwarp()
    # After: work_info updated, state stepped

Returns:

AdvanceAfterWorkContext

prefetch_before_work

prefetch_before_work[work_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> PrefetchBeforeWorkContext[work_origin]

Context for MMA warp that prefetches BEFORE work (software pipelining).

Fetches next work and steps state IMMEDIATELY (before the with block).

Usage:

    with scheduler.prefetch_before_work(work_info, state) as current:
        do_mma(current)  # Uses current, not prefetched
    # After: work_info updated to prefetched value

Returns:

PrefetchBeforeWorkContext

work_iterator

work_iterator(self) -> WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]

Create a per-warp work iterator with internally managed state.

Each warp should create its own work iterator. The iterator owns work_info, pipeline state, and throttle internally.

Usage:

    var work_iter = scheduler.work_iterator()
    while work_iter.has_work():
        with work_iter.next() as current:
            work_iter.throttle_signal(ctx.is_first_cta_in_cluster)
            do_work(current)

Returns:

WorkIterator

scheduler_iterator

scheduler_iterator(self) -> SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]

Create iterator for Scheduler warp (owns work_info and both pipeline states).

The Scheduler warp uniquely needs to both consume work responses and produce new work requests. This iterator owns everything internally.

Usage:

    var sched_iter = scheduler.scheduler_iterator()
    while sched_iter.has_work():
        with sched_iter.next():
            sched_iter.signal_and_advance()
    sched_iter.drain()

Returns:

SchedulerWorkIterator

advance_to_next_work

advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]

Returns:

PipelineState
