Mojo struct
SplitKTileScheduler
@register_passable(trivial)
struct SplitKTileScheduler[problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: SIMD[uint32, 1], num_consumer: SIMD[uint32, 1], num_pipeline_stages: SIMD[uint32, 1], cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode(0)]
Fields
- prob_shape (IndexList[3]):
- block_id_in_cluster (IndexList[2]):
- blocks_per_problem (SIMD[uint32, 1]):
- current_work_linear_idx (SIMD[uint32, 1]):
- log_cluster_shape_major (SIMD[uint32, 1]):
- log_cluster_shape_minor (SIMD[uint32, 1]):
- cluster_blk_major (SIMD[uint32, 1]):
- locks_ptr (UnsafePointer[SIMD[int32, 1]]):
Implemented traits
AnyType, Copyable, Movable, UnknownDestructibility
Aliases
k_tiles_per_output_tile
alias k_tiles_per_output_tile = ceildiv(problem_shape_nk[1], tile_shape[2])
k_tiles_per_split
alias k_tiles_per_split = ceildiv(problem_shape_nk[1], tile_shape[2]) // splits
log_cluster_size
alias log_cluster_size = log2_floor(cluster_shape[0] * cluster_shape[1])
Methods
__init__
__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: UnsafePointer[NoneType]) -> Self
get_sm_num
get_sm_num(self) -> SIMD[uint32, 1]
get_problem_blocks_shape
static get_problem_blocks_shape(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]
initial_work_tile_info
initial_work_tile_info(mut self) -> WorkInfo
get_current_work_info
get_current_work_info(mut self) -> WorkInfo
get_worktile_m_n_idx
get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: SIMD[uint32, 1])
assign_work
assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: SIMD[uint32, 1])
get_k_start_and_linear_tile_id
get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: SIMD[uint32, 1]) -> SIMD[uint32, 1]
fetch_next_work
fetch_next_work(mut self, mut work_tile_info: WorkInfo) -> WorkInfo
requires_reduction
requires_reduction(self, work_tile_info: WorkInfo) -> Bool
advance_to_next_work
advance_to_next_work(mut self)
is_last_split
is_last_split(self, work_tile_info: WorkInfo) -> Bool
get_grid_shape
static get_grid_shape(cluster_shape: IndexList[3], raster_order: RasterOrder = RasterOrder(0)) -> IndexList[3]
get_num_tiles
static get_num_tiles(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int
get_required_locks_buffer_size_bytes
static get_required_locks_buffer_size_bytes[accum_type: DType, num_consumer: SIMD[uint32, 1]](problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int
get_linear_idx_from_m_and_n
get_linear_idx_from_m_and_n(self, tile_m: SIMD[uint32, 1], tile_n: SIMD[uint32, 1]) -> SIMD[uint32, 1]
output_tile_index
output_tile_index(self, work_tile_info: WorkInfo) -> SIMD[uint32, 1]
reduction
reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], work_tile_info: WorkInfo, num_barriers: SIMD[uint32, 1], warp_group_local_idx: SIMD[uint32, 1])
wait_eq
static wait_eq(lock_ptr: UnsafePointer[SIMD[int32, 1]], barrier_id: SIMD[int32, 1], barrier_group_thread_idx: Int, lock_idx: SIMD[uint32, 1], val: SIMD[uint32, 1])
wait_lt
static wait_lt(lock_ptr: UnsafePointer[SIMD[int32, 1]], barrier_id: SIMD[int32, 1], barrier_group_thread_idx: Int, lock_idx: SIMD[uint32, 1], count: SIMD[uint32, 1])
arrive_set
static arrive_set(lock_ptr: UnsafePointer[SIMD[int32, 1]], barrier_id: SIMD[int32, 1], barrier_group_thread_idx: Int, lock_idx: SIMD[uint32, 1], increment: SIMD[uint32, 1])
store_accumulator
store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: SIMD[uint32, 1], warp_group_local_idx: SIMD[uint32, 1], warp_group_thread_idx: SIMD[uint32, 1])
reduce_add
reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: SIMD[uint32, 1], warp_group_local_idx: SIMD[uint32, 1], warp_group_thread_idx: SIMD[uint32, 1])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!