IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

TileScheduler

struct TileScheduler[tile_shape: UInt32, num_heads: UInt32, /, num_ctas: UInt32 = SIMD(GPUInfo.from_family(AcceleratorArchitectureFamily(Int(32), Int(2048), Int(233472), Int(65536), Int(1024)), StringSlice("H100"), Vendor(Int8(2)), StringSlice("cuda"), StringSlice("hopper"), SIMD(9), StringSlice("sm_90a"), Int(132)).sm_count), schedule: MHASchedule = MHASchedule.DEFAULT]

Implemented traits

AnyType, Copyable, Defaultable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHATileScheduler, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = TileScheduler[tile_shape, num_heads, num_ctas, schedule]

may_advance

comptime may_advance = True

mha_schedule

comptime mha_schedule = schedule

Methods

__init__

def __init__() -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

get_current_work_info

def get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns:

WorkInfo

fetch_next_work

def fetch_next_work(self, ts: MHATileSummary, mut state: MHATileState) -> WorkInfo

Returns:

WorkInfo

advance

def advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

Returns:

OptionalReg[SeqInfo]

grid_dim

static def grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Returns:

Tuple[Int, Int, Int]

initial_state

def initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Returns:

MHATileState

unsafe_seq_info

def unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo