
Mojo struct

QueuedTileScheduler

struct QueuedTileScheduler[tile_shape: UInt32, num_heads: UInt32, /, decoding: Bool, num_ctas: UInt32 = H100.sm_count, schedule: MHASchedule = MHASchedule.DEFAULT]

If decoding is False, num_heads is the number of query heads (q_num_heads). If decoding is True, num_heads is the number of key/value heads (kv_num_heads).
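As a minimal illustration of the rule above, the sketch below uses a hypothetical helper, scheduler_num_heads (not part of the library), to choose the value to pass as the num_heads parameter. The head counts used in main (32 query heads, 8 key/value heads) are assumed example values.

```mojo
# Hypothetical helper (not part of the library): picks the value to pass as
# QueuedTileScheduler's `num_heads` parameter, following the documented rule.
fn scheduler_num_heads(
    decoding: Bool, q_num_heads: UInt32, kv_num_heads: UInt32
) -> UInt32:
    if decoding:
        # Decoding: num_heads counts key/value heads.
        return kv_num_heads
    # Not decoding: num_heads counts query heads.
    return q_num_heads


fn main():
    # Assumed example configuration: 32 query heads sharing 8 KV heads.
    print("not decoding, num_heads =", scheduler_num_heads(False, 32, 8))
    print("decoding, num_heads =", scheduler_num_heads(True, 32, 8))
```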

Fields

  • gidx_ptr (UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.GLOBAL])

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, MHATileScheduler, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = QueuedTileScheduler[tile_shape, num_heads, decoding, num_ctas, schedule]

may_advance

comptime may_advance = True

mha_schedule

comptime mha_schedule = schedule

Methods

__init__

__init__(gidx_ptr: UnsafePointer[UInt32, MutAnyOrigin]) -> Self

get_current_work_info

get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns:

WorkInfo

advance

advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

Advances the scheduler to the next work item. The returned OptionalReg[SeqInfo] holds a value only when the new index corresponds to a valid WorkInfo. If sync is MHASchedulerSynchronization.NONE, this method assumes it is called only by the thread with thread_idx.x == 0.

Returns:

OptionalReg[SeqInfo]
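The name QueuedTileScheduler and the gidx_ptr field suggest that thread blocks claim work by atomically incrementing a global index; this page does not state that, so treat it as an assumption. The CPU-only Mojo sketch below models the idea with a plain counter standing in for the value behind gidx_ptr. Each call to the hypothetical claim_next function takes the next index, and an index counts as valid work only while it is below the total number of tiles, loosely mirroring how advance returns an empty OptionalReg[SeqInfo] when there is no valid work left. The (batch, tile, head) decomposition order and all names in the code are illustrative assumptions, not the library's implementation.

```mojo
# CPU-only model of queue-based tile scheduling.
# Assumption: the real scheduler performs an atomic fetch-add on the global
# counter at gidx_ptr. All names here are hypothetical, not library APIs.


fn claim_next(mut counter: UInt32) -> UInt32:
    # Stand-in for an atomic fetch-add: return the old value, then increment.
    var old = counter
    counter += 1
    return old


fn main():
    var counter: UInt32 = 0  # plays the role of the value behind gidx_ptr
    var batch_size: UInt32 = 2
    var num_tiles: UInt32 = 3  # e.g. prompt tiles per sequence (assumed)
    var num_heads: UInt32 = 2
    var total = batch_size * num_tiles * num_heads

    var claimed = 0
    while True:
        var idx = claim_next(counter)
        if idx >= total:
            # Analogous to advance returning an empty OptionalReg[SeqInfo].
            break
        # Assumed decomposition: head varies fastest, then tile, then batch.
        var head = idx % num_heads
        var tile = (idx // num_heads) % num_tiles
        var batch = idx // (num_heads * num_tiles)
        print("work", idx, "-> batch", batch, "tile", tile, "head", head)
        claimed += 1
    print("claimed", claimed, "of", total, "tiles")
```

Because every block pulls its next index from one shared counter, blocks that finish early simply claim more work; that load-balancing property is the usual motivation for a queued design over a fixed, static assignment.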

grid_dim

static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Returns:

Tuple[Int, Int, Int]

initial_state

initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Returns:

MHATileState

unsafe_seq_info

unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait).

Returns:

String: The host type's name.