For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
SplitKTileScheduler
struct SplitKTileScheduler[locks_origin: MutOrigin, //, problem_shape_nk: IndexList[Int(2)], tile_shape: IndexList[Int(3)], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[Int(2)], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode.Deterministic]
Fields
- prob_shape (IndexList[Int(3)])
- block_id_in_cluster (IndexList[Int(2)])
- blocks_per_problem (UInt32)
- current_work_linear_idx (UInt32)
- log_cluster_shape_major (UInt32)
- log_cluster_shape_minor (UInt32)
- cluster_blk_major (UInt32)
- locks_ptr (UnsafePointer[Int32, locks_origin])
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
k_tiles_per_output_tile
comptime k_tiles_per_output_tile = SIMD(ceildiv(problem_shape_nk[Int(1)], tile_shape[Int(2)]))
k_tiles_per_split
comptime k_tiles_per_split = (SIMD(ceildiv(problem_shape_nk[Int(1)], tile_shape[Int(2)])) // splits)
log_cluster_size
comptime log_cluster_size = log2_floor(Int(cluster_shape[Int(0)] * cluster_shape[Int(1)]))
WorkTileType
comptime WorkTileType[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin]
Parameters
Methods
__init__
def __init__(prob_shape: IndexList[Int(3)], block_id_in_cluster: IndexList[Int(2)], locks_ptr: UnsafePointer[UInt8, locks_origin]) -> Self
get_sm_num
get_problem_blocks_shape
static def get_problem_blocks_shape(problem_shape: IndexList[Int(3)], dyn_tile_shape: IndexList[Int(3)], dyn_cluster_shape: IndexList[Int(2)]) -> IndexList[Int(2)]
Returns: IndexList[Int(2)]
initial_work_tile_info
get_current_work_info
get_worktile_m_n_idx
def get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)
assign_work
def assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)
get_k_start_and_linear_tile_id
def get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32
Returns: UInt32
fetch_next_work
requires_reduction
advance_to_next_work
def advance_to_next_work(mut self)
is_last_split
get_grid_shape
static def get_grid_shape(dyn_cluster_shape: IndexList[Int(3)], dyn_raster_order: RasterOrder = RasterOrder.AlongN) -> IndexList[Int(3)]
Returns: IndexList[Int(3)]
get_num_tiles
static def get_num_tiles(problem_shape: IndexList[Int(3)], dyn_tile_shape: IndexList[Int(3)], dyn_cluster_shape: IndexList[Int(2)]) -> Int
Returns: Int
get_required_locks_buffer_size_bytes
static def get_required_locks_buffer_size_bytes[accum_type: DType, dyn_num_consumer: UInt32](problem_shape: IndexList[Int(3)], dyn_tile_shape: IndexList[Int(3)], dyn_cluster_shape: IndexList[Int(2)]) -> Int
Returns: Int
get_linear_idx_from_m_and_n
output_tile_index
reduction
def reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: TensorLayout](self, reduction_workspace: TileTensor[accum_type, workspace_layout], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)
def reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)
wait_eq
static def wait_eq(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
wait_lt
static def wait_lt(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)
arrive_set
static def arrive_set(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)
store_accumulator
def store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)
reduce_add
def reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!