Skip to main content

Mojo struct

SplitKTileScheduler

@register_passable(trivial)
struct SplitKTileScheduler[
    problem_shape_nk: IndexList[2],
    tile_shape: IndexList[3],
    splits: SIMD[uint32, 1],
    num_consumer: SIMD[uint32, 1],
    num_pipeline_stages: SIMD[uint32, 1],
    cluster_shape: IndexList[2],
    raster_order: RasterOrder,
    reduction_mode: ReductionMode = ReductionMode(0),
]

Fields

  • prob_shape (IndexList[3])
  • block_id_in_cluster (IndexList[2])
  • blocks_per_problem (SIMD[uint32, 1])
  • current_work_linear_idx (SIMD[uint32, 1])
  • log_cluster_shape_major (SIMD[uint32, 1])
  • log_cluster_shape_minor (SIMD[uint32, 1])
  • cluster_blk_major (SIMD[uint32, 1])
  • locks_ptr (UnsafePointer[SIMD[int32, 1]])

Implemented traits

AnyType, Copyable, Movable, UnknownDestructibility

Aliases

k_tiles_per_output_tile

alias k_tiles_per_output_tile = ceildiv[::CeilDivable](problem_shape_nk.__getitem__[::Indexer](1), tile_shape.__getitem__[::Indexer](2))

k_tiles_per_split

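alias k_tiles_per_split = splits.__rfloordiv__(SIMD(ceildiv[::CeilDivable](problem_shape_nk.__getitem__[::Indexer](1), tile_shape.__getitem__[::Indexer](2))))

Note: `__rfloordiv__` is the reflected floor-division operator, so `splits` is the divisor here. The expression is equivalent to `ceildiv(problem_shape_nk[1], tile_shape[2]) // splits`, i.e. `k_tiles_per_output_tile // splits`.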

log_cluster_size

alias log_cluster_size = log2_floor((cluster_shape.__getitem__[::Indexer](0) * cluster_shape.__getitem__[::Indexer](1)))

Methods

__init__

__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: UnsafePointer[NoneType]) -> Self

get_sm_num

get_sm_num(self) -> SIMD[uint32, 1]

get_problem_blocks_shape

static get_problem_blocks_shape(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]

initial_work_tile_info

initial_work_tile_info(mut self) -> WorkInfo

get_current_work_info

get_current_work_info(mut self) -> WorkInfo

get_worktile_m_n_idx

get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: SIMD[uint32, 1])

assign_work

assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: SIMD[uint32, 1])

get_k_start_and_linear_tile_id

get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: SIMD[uint32, 1]) -> SIMD[uint32, 1]

fetch_next_work

fetch_next_work(mut self, mut work_tile_info: WorkInfo) -> WorkInfo

requires_reduction

requires_reduction(self, work_tile_info: WorkInfo) -> Bool

advance_to_next_work

advance_to_next_work(mut self)

is_last_split

is_last_split(self, work_tile_info: WorkInfo) -> Bool

get_grid_shape

static get_grid_shape(cluster_shape: IndexList[3], raster_order: RasterOrder = RasterOrder(0)) -> IndexList[3]

get_num_tiles

static get_num_tiles(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int

get_required_locks_buffer_size_bytes

static get_required_locks_buffer_size_bytes[accum_type: DType, num_consumer: SIMD[uint32, 1]](problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int

get_linear_idx_from_m_and_n

get_linear_idx_from_m_and_n(self, tile_m: SIMD[uint32, 1], tile_n: SIMD[uint32, 1]) -> SIMD[uint32, 1]

output_tile_index

output_tile_index(self, work_tile_info: WorkInfo) -> SIMD[uint32, 1]

reduction

reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], work_tile_info: WorkInfo, num_barriers: SIMD[uint32, 1], warp_group_local_idx: SIMD[uint32, 1])

wait_eq

static wait_eq(lock_ptr: UnsafePointer[SIMD[int32, 1]], barrier_id: SIMD[int32, 1], barrier_group_thread_idx: Int, lock_idx: SIMD[uint32, 1], val: SIMD[uint32, 1])

wait_lt

static wait_lt(lock_ptr: UnsafePointer[SIMD[int32, 1]], barrier_id: SIMD[int32, 1], barrier_group_thread_idx: Int, lock_idx: SIMD[uint32, 1], count: SIMD[uint32, 1])

arrive_set

static arrive_set(lock_ptr: UnsafePointer[SIMD[int32, 1]], barrier_id: SIMD[int32, 1], barrier_group_thread_idx: Int, lock_idx: SIMD[uint32, 1], increment: SIMD[uint32, 1])

store_accumulator

store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: SIMD[uint32, 1], warp_group_local_idx: SIMD[uint32, 1], warp_group_thread_idx: SIMD[uint32, 1])

reduce_add

reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: SIMD[uint32, 1], warp_group_local_idx: SIMD[uint32, 1], warp_group_thread_idx: SIMD[uint32, 1])

Was this page helpful?