
Mojo struct

TileScheduler

@register_passable(trivial)
struct TileScheduler[
    num_stages: Int,
    reduction_tile_shape: IndexList[3],
    cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1),
    rasterize_order: RasterOrder = RasterOrder.AlongM,
    block_swizzle_size: Int = 8,
    num_split_k: Int = 1,
]
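
As a sketch of how these parameters compose, consider the hypothetical alias below. The tile shape (128, 256, 64) and the two K splits are assumed values for illustration, and the import path for TileScheduler itself is not shown on this page:

```mojo
from utils.index import Index, IndexList

# Hypothetical parameterization; not values prescribed by this page.
comptime Scheduler = TileScheduler[
    num_stages=2,
    reduction_tile_shape=Index(128, 256, 64),  # (BM, MMA_N, BK), assumed
    num_split_k=2,  # split the K reduction across two blocks
    # cluster_shape, rasterize_order, and block_swizzle_size keep their
    # defaults: (1, 1, 1), RasterOrder.AlongM, and 8.
]
```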

Fields

  • locks_ptr (LegacyUnsafePointer[Int32]): Pointer to the per-tile lock values used by the split-K synchronization helpers (wait_eq, wait_lt, arrive_set).
  • scheduler (TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]): The underlying non-split-K scheduler (see the UnderlyingScheduler alias).
  • total_k_tiles (UInt32): Total number of K tiles in the reduction dimension.
  • k_tiles_per_split (UInt32): Number of K tiles assigned to each split.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

BK

comptime BK = reduction_tile_shape[2]

BM

comptime BM = reduction_tile_shape[0]

MMA_N

comptime MMA_N = reduction_tile_shape[1]

ROW_SIZE

comptime ROW_SIZE = MMA_N if (BM == 128) else (MMA_N // 2)
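
For concreteness, a worked reading of these aliases under two assumed tile shapes (S_a and S_b are hypothetical instantiations, not values from this page):

```mojo
# Assumed shapes, for illustration only.
comptime S_a = TileScheduler[2, Index(128, 256, 64)]
# S_a: BM = 128, MMA_N = 256, BK = 64; BM == 128, so ROW_SIZE = MMA_N = 256.
comptime S_b = TileScheduler[2, Index(64, 256, 64)]
# S_b: BM = 64, so ROW_SIZE = MMA_N // 2 = 128.
```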

UnderlyingScheduler

comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]

Methods

__init__

__init__(cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self
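
A hedged construction sketch that simply forwards the documented arguments; `Scheduler` is the hypothetical alias from the first sketch above, and the surrounding setup (shared-memory buffers, lock allocation, and the module paths for StaticTuple, SharedMemBarrier, AddressSpace, and LegacyUnsafePointer) is assumed and elided:

```mojo
fn make_scheduler(
    cluster_dim: StaticTuple[Int32, 3],
    mnk: StaticTuple[UInt32, 3],
    clc_response_ptr: LegacyUnsafePointer[
        UInt128, address_space=AddressSpace.SHARED
    ],
    full_mbar_ptr: LegacyUnsafePointer[
        SharedMemBarrier, address_space=AddressSpace.SHARED
    ],
    empty_mbar_ptr: LegacyUnsafePointer[
        SharedMemBarrier, address_space=AddressSpace.SHARED
    ],
    locks_ptr: LegacyUnsafePointer[UInt8],
) -> Scheduler:
    # Mirrors the documented __init__ signature one-to-one.
    return Scheduler(
        cluster_dim,
        mnk,
        clc_response_ptr,
        full_mbar_ptr,
        empty_mbar_ptr,
        locks_ptr,
    )
```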

convert_to_splitk_work_info

convert_to_splitk_work_info(self, work_info: WorkInfo) -> WorkInfo

Returns:

WorkInfo

initial_work_info

initial_work_info(self) -> WorkInfo

Returns:

WorkInfo

advance_to_next_work

advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]

Returns:

PipelineState

fetch_next_work

fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo

Returns:

WorkInfo
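
Taken together, initial_work_info, advance_to_next_work, and fetch_next_work suggest the usual persistent-kernel loop. A minimal sketch, assuming WorkInfo exposes an is_valid() predicate (common in MAX tile schedulers, but not shown on this page):

```mojo
fn persistent_loop(scheduler: Scheduler):
    var clc_state = PipelineState[2]()  # num_stages = 2 in the alias above
    var consumer_state = PipelineState[2]()
    var work_info = scheduler.initial_work_info()
    while work_info.is_valid():  # assumed WorkInfo validity check
        # ... compute and store the output tile described by work_info ...
        clc_state = scheduler.advance_to_next_work(clc_state)
        work_info = scheduler.fetch_next_work(work_info, consumer_state)
```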

is_last_split

is_last_split(self, work_tile_info: WorkInfo) -> Bool

Returns:

Bool

output_tile_index

output_tile_index(self, work_info: WorkInfo) -> UInt32

Returns:

UInt32

store_to_workspace

store_to_workspace[accum_type: DType, workspace_layout: Layout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem_addr: UInt32, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], epilogue_thread_idx: UInt, reduction_tile_idx: UInt32)

reduction

reduction[accum_type: DType, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], tmem_addr: UInt32, epilogue_thread_idx: UInt, work_info: WorkInfo) -> Bool

Returns:

Bool
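
A hedged sketch of how reduction, is_last_split, and output_tile_index plausibly compose in a split-K epilogue. The interpretation of reduction's Bool result (that it marks the split which should publish the finished tile) is an assumption, and splitk_epilogue is a placeholder name:

```mojo
from layout import Layout, LayoutTensor

fn splitk_epilogue[
    accum_type: DType, workspace_layout: Layout
](
    scheduler: Scheduler,
    reduction_workspace: LayoutTensor[
        accum_type, workspace_layout, MutableAnyOrigin
    ],
    tmem_addr: UInt32,
    epilogue_thread_idx: UInt,
    work_info: WorkInfo,
):
    # Fold this split's partial accumulators into the shared workspace.
    var owns_final = scheduler.reduction(
        reduction_workspace, tmem_addr, epilogue_thread_idx, work_info
    )
    if owns_final and scheduler.is_last_split(work_info):
        # Only the last split publishes the tile; output_tile_index gives
        # its linear slot among the output tiles.
        var tile_idx = scheduler.output_tile_index(work_info)
        _ = tile_idx  # final global-memory store elided
```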

wait_eq

static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)

wait_lt

static wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)

arrive_set

static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
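
These static helpers read as a spin-lock protocol over per-tile counters in global memory. A minimal ordering sketch in which each split waits for its turn on a lock slot and then passes it on; the function name, barrier id, and turn-taking scheme are illustrative assumptions:

```mojo
fn take_splitk_turn(
    lock_ptr: LegacyUnsafePointer[Int32],
    barrier_group_thread_idx: Int,
    lock_idx: UInt32,
    my_split: UInt32,
):
    comptime barrier_id: Int32 = 1  # hypothetical named-barrier id
    # Spin until the lock for this output tile equals this split's ordinal...
    Scheduler.wait_eq(
        lock_ptr, barrier_id, barrier_group_thread_idx, lock_idx, my_split
    )
    # ... accumulate this split's partial tile into the workspace ...
    # ...then pass the lock to the next split in line.
    Scheduler.arrive_set(
        lock_ptr, barrier_id, barrier_group_thread_idx, lock_idx, my_split + 1
    )
```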
