IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

SplitKTileScheduler

struct SplitKTileScheduler[locks_origin: MutOrigin, //, problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode.Deterministic]

Fields​

  • prob_shape (IndexList[3])
  • block_id_in_cluster (IndexList[2])
  • blocks_per_problem (UInt32)
  • current_work_linear_idx (UInt32)
  • log_cluster_shape_major (UInt32)
  • log_cluster_shape_minor (UInt32)
  • cluster_blk_major (UInt32)
  • locks_ptr (UnsafePointer[Int32, locks_origin])

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

k_tiles_per_output_tile​

comptime k_tiles_per_output_tile = SIMD(ceildiv(problem_shape_nk[1], tile_shape[2]))

k_tiles_per_split​

comptime k_tiles_per_split = (SIMD(ceildiv(problem_shape_nk[1], tile_shape[2])) // splits)

log_cluster_size​

comptime log_cluster_size = log2_floor((cluster_shape[0] * cluster_shape[1]))

WorkTileType​

comptime WorkTileType[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin]

Parameters

  • dtype (DType)
  • layout (Layout)

Methods​

__init__​

def __init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: UnsafePointer[UInt8, locks_origin]) -> Self

get_sm_num​

def get_sm_num(self) -> UInt32

Returns:

UInt32

get_problem_blocks_shape​

static def get_problem_blocks_shape(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> IndexList[2]

Returns:

IndexList[2]

initial_work_tile_info​

def initial_work_tile_info(mut self) -> WorkInfo

Returns:

WorkInfo

get_current_work_info​

def get_current_work_info(mut self) -> WorkInfo

Returns:

WorkInfo

get_worktile_m_n_idx​

def get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)

assign_work​

def assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)

get_k_start_and_linear_tile_id​

def get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32

Returns:

UInt32

fetch_next_work​

def fetch_next_work(mut self, mut work_tile_info: WorkInfo) -> WorkInfo

Returns:

WorkInfo

requires_reduction​

def requires_reduction(self, work_tile_info: WorkInfo) -> Bool

Returns:

Bool

advance_to_next_work​

def advance_to_next_work(mut self)

is_last_split​

def is_last_split(self, work_tile_info: WorkInfo) -> Bool

Returns:

Bool

get_grid_shape​

static def get_grid_shape(dyn_cluster_shape: IndexList[3], dyn_raster_order: RasterOrder = RasterOrder.AlongN) -> IndexList[3]

Returns:

IndexList[3]

get_num_tiles​

static def get_num_tiles(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int

Returns:

Int

get_required_locks_buffer_size_bytes​

static def get_required_locks_buffer_size_bytes[accum_type: DType, dyn_num_consumer: UInt32](problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int

Returns:

Int

get_linear_idx_from_m_and_n​

def get_linear_idx_from_m_and_n(self, tile_m: UInt32, tile_n: UInt32) -> UInt32

Returns:

UInt32

output_tile_index​

def output_tile_index(self, work_tile_info: WorkInfo) -> UInt32

Returns:

UInt32

reduction​

def reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: TensorLayout](self, reduction_workspace: TileTensor[accum_type, workspace_layout], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)

def reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)

wait_eq​

static def wait_eq(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)

wait_lt​

static def wait_lt(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)

arrive_set​

static def arrive_set(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)

store_accumulator​

def store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)

reduce_add​

def reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)