Mojo struct
SplitKTileScheduler
struct SplitKTileScheduler[problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode.Deterministic]
Fields
- prob_shape (IndexList[3])
- block_id_in_cluster (IndexList[2])
- blocks_per_problem (UInt32)
- current_work_linear_idx (UInt32)
- log_cluster_shape_major (UInt32)
- log_cluster_shape_minor (UInt32)
- cluster_blk_major (UInt32)
- locks_ptr (UnsafePointer[Int32, MutAnyOrigin])
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
k_tiles_per_output_tile
comptime k_tiles_per_output_tile = SIMD(ceildiv(problem_shape_nk[1], tile_shape[2]))
k_tiles_per_split
comptime k_tiles_per_split = (SIMD(ceildiv(problem_shape_nk[1], tile_shape[2])) // splits)
log_cluster_size
comptime log_cluster_size = log2_floor((cluster_shape[0] * cluster_shape[1]))
WorkTileType
comptime WorkTileType[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin]
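Parameters
- problem_shape_nk (IndexList[2])
- tile_shape (IndexList[3])
- splits (UInt32)
- num_consumer (UInt32)
- num_pipeline_stages (UInt32)
- cluster_shape (IndexList[2])
- raster_order (RasterOrder)
- reduction_mode (ReductionMode): defaults to ReductionMode.Deterministic.
Methods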
__init__
__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: UnsafePointer[UInt8]) -> Self
get_sm_num
get_problem_blocks_shape
static get_problem_blocks_shape(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> IndexList[2]
Returns: IndexList[2]
initial_work_tile_info
get_current_work_info
get_worktile_m_n_idx
get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)
assign_work
assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)
get_k_start_and_linear_tile_id
get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32
Returns: UInt32
fetch_next_work
requires_reduction
advance_to_next_work
advance_to_next_work(mut self)
is_last_split
get_grid_shape
static get_grid_shape(dyn_cluster_shape: IndexList[3], dyn_raster_order: RasterOrder = RasterOrder.AlongN) -> IndexList[3]
Returns: IndexList[3]
get_num_tiles
static get_num_tiles(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int
Returns: Int
get_required_locks_buffer_size_bytes
static get_required_locks_buffer_size_bytes[accum_type: DType, dyn_num_consumer: UInt32](problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int
Returns: Int
get_linear_idx_from_m_and_n
output_tile_index
reduction
reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)
wait_eq
static wait_eq(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
wait_lt
static wait_lt(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)
arrive_set
static arrive_set(lock_ptr: UnsafePointer[Int32, MutAnyOrigin], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)
store_accumulator
store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)
reduce_add
reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!