IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

TransientScheduler

struct TransientScheduler[tile_shape: UInt32, num_heads: UInt32, flip_prompt_idx: Bool, pair_cta: Bool = False]

Implemented traits

AnyType, Copyable, Defaultable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHATileScheduler, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = TransientScheduler[tile_shape, num_heads, flip_prompt_idx, pair_cta]

may_advance

comptime may_advance = False

mha_schedule

comptime mha_schedule = MHASchedule.DEFAULT

Methods

__init__

def __init__() -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

get_current_work_info

def get_current_work_info(self, num_prompt_tiles: UInt32) -> WorkInfo

Returns:

WorkInfo

def get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns:

WorkInfo

advance

def advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

Returns:

OptionalReg[SeqInfo]

grid_dim

static def grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Returns:

Tuple[Int, Int, Int]

initial_state

def initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Returns:

MHATileState

unsafe_seq_info

def unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo