IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

QueuedTileScheduler

struct QueuedTileScheduler[tile_shape: UInt32, num_heads: UInt32, /, decoding: Bool, num_ctas: UInt32 = SIMD(GPUInfo.from_family(AcceleratorArchitectureFamily(Int(32), Int(2048), Int(233472), Int(65536), Int(1024)), StringSlice("H100"), Vendor(Int8(2)), StringSlice("cuda"), StringSlice("hopper"), SIMD(9), StringSlice("sm_90a"), Int(132)).sm_count), schedule: MHASchedule = MHASchedule.DEFAULT]

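If decoding is False, num_heads is the number of query heads (q_num_heads). If decoding is True, num_heads is the number of key/value heads (kv_num_heads). The default num_ctas is the streaming multiprocessor count of an H100 GPU (132).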
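The role of num_heads can be illustrated with a small host-side Python model that splits a flat work index into a prompt tile, a head, and a batch entry. This is a hypothetical sketch: the helpers heads_for_schedule and decompose_work_index, the num_prompt_tiles parameter, and the tile-fastest ordering are assumptions for illustration, not the ordering QueuedTileScheduler actually uses.

```python
from typing import NamedTuple


class WorkCoord(NamedTuple):
    """Hypothetical (prompt_tile, head, batch) coordinate for one work item."""
    prompt_tile: int
    head: int
    batch: int


def heads_for_schedule(decoding: bool, q_num_heads: int, kv_num_heads: int) -> int:
    # Mirrors the documented rule: prefill schedules per query head,
    # decoding schedules per key/value head.
    return kv_num_heads if decoding else q_num_heads


def decompose_work_index(idx: int, num_prompt_tiles: int, num_heads: int) -> WorkCoord:
    # Assumed ordering: prompt tile varies fastest, then head, then batch.
    prompt_tile = idx % num_prompt_tiles
    rest = idx // num_prompt_tiles
    head = rest % num_heads
    batch = rest // num_heads
    return WorkCoord(prompt_tile, head, batch)


# Example: 32 query heads sharing 8 key/value heads (grouped-query attention).
num_heads = heads_for_schedule(decoding=True, q_num_heads=32, kv_num_heads=8)
print(num_heads)  # 8
print(decompose_work_index(21, num_prompt_tiles=4, num_heads=num_heads))
# WorkCoord(prompt_tile=1, head=5, batch=0)
```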

Fields

  • gidx_ptr (UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.GLOBAL]): Pointer to a UInt32 value in global memory, set from the gidx_ptr argument of __init__.
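The name QueuedTileScheduler and the global gidx_ptr field suggest a persistent-kernel queue, in which a fixed number of CTAs repeatedly claim the next work index from a shared counter with an atomic fetch-and-add until the work runs out. The Python sketch below models that general technique on the host, assuming gidx_ptr points to such a counter: threads stand in for CTAs and a lock stands in for a GPU atomic. AtomicCounter and run_persistent_workers are hypothetical names, and this is not the Mojo implementation.

```python
import threading


class AtomicCounter:
    """Host-side stand-in for a UInt32 counter in GPU global memory."""

    def __init__(self, start: int = 0) -> None:
        self._value = start
        self._lock = threading.Lock()

    def fetch_add(self, delta: int) -> int:
        # Returns the old value, like a GPU atomic fetch-and-add.
        with self._lock:
            old = self._value
            self._value += delta
            return old


def run_persistent_workers(num_ctas: int, total_work: int) -> dict[int, list[int]]:
    """Each simulated CTA claims work indices until the queue is exhausted."""
    gidx = AtomicCounter(0)
    claimed: dict[int, list[int]] = {cta: [] for cta in range(num_ctas)}

    def cta_loop(cta_id: int) -> None:
        while True:
            idx = gidx.fetch_add(1)
            if idx >= total_work:
                break  # no more work; this simulated CTA exits
            claimed[cta_id].append(idx)

    threads = [threading.Thread(target=cta_loop, args=(c,)) for c in range(num_ctas)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return claimed


# 132 simulated CTAs (the H100 SM count used as the num_ctas default) share 1000 tiles.
result = run_persistent_workers(num_ctas=132, total_work=1000)
all_indices = sorted(i for idxs in result.values() for i in idxs)
print(all_indices == list(range(1000)))  # True: every tile claimed exactly once
```

Because a CTA that finishes early simply claims the next index, a dynamic queue like this tends to balance uneven tile costs (for example, variable sequence lengths) better than a fixed assignment of tiles to CTAs.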

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHATileScheduler, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = QueuedTileScheduler[tile_shape, num_heads, decoding, num_ctas, schedule]

may_advance

comptime may_advance = True

mha_schedule

comptime mha_schedule = schedule

Methods

__init__

def __init__(gidx_ptr: UnsafePointer[UInt32, MutAnyOrigin]) -> Self

get_current_work_info

def get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns:

WorkInfo

advance

def advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

Advances the scheduler to the next work item. The returned OptionalReg[SeqInfo] holds the sequence information when the new index corresponds to a valid WorkInfo, and is empty otherwise. If sync is MHASchedulerSynchronization.NONE, the method assumes it is called only by thread_idx.x == 0.

Returns:

OptionalReg[SeqInfo]
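Because advance returns an optional value that is empty once no valid work remains, callers naturally loop until the optional is empty. The Python model below is a hypothetical single-CTA sketch of that contract: None plays the role of an empty OptionalReg, and ToyScheduler, its two-field SeqInfo, and the batch-major ordering are illustrative assumptions rather than the Mojo API.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SeqInfo:
    """Illustrative stand-in; the real SeqInfo has its own fields."""
    batch: int
    prompt_tile: int


class ToyScheduler:
    """Hypothetical single-CTA model of the advance() return contract."""

    def __init__(self, batch_size: int, tiles_per_seq: int) -> None:
        self.batch_size = batch_size
        self.tiles_per_seq = tiles_per_seq
        self.idx = 0  # plays the role of the scheduler state

    def _seq_info(self) -> Optional[SeqInfo]:
        # Empty result once the index runs past the last work item.
        if self.idx >= self.batch_size * self.tiles_per_seq:
            return None
        batch, tile = divmod(self.idx, self.tiles_per_seq)
        return SeqInfo(batch=batch, prompt_tile=tile)

    def initial(self) -> Optional[SeqInfo]:
        return self._seq_info()

    def advance(self) -> Optional[SeqInfo]:
        self.idx += 1
        return self._seq_info()


sched = ToyScheduler(batch_size=2, tiles_per_seq=3)
visited = []
info = sched.initial()
while info is not None:  # mirrors checking whether the OptionalReg holds a value
    visited.append((info.batch, info.prompt_tile))
    info = sched.advance()
print(visited)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```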

grid_dim

static def grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Returns:

Tuple[Int, Int, Int]

initial_state

def initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Returns:

MHATileState

unsafe_seq_info

def unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo

get_type_name

static def get_type_name() -> String

Gets the name of the host type (the one implementing this trait).

Returns:

String: The host type's name.