Mojo struct
TransientScheduler
struct TransientScheduler[tile_shape: UInt32, num_heads: UInt32, flip_prompt_idx: Bool, pair_cta: Bool = False]
Implemented traits
AnyType,
Copyable,
Defaultable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
MHATileScheduler,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = TransientScheduler[tile_shape, num_heads, flip_prompt_idx, pair_cta]
may_advance
comptime may_advance = False
mha_schedule
comptime mha_schedule = MHASchedule.DEFAULT
Methods
__init__
__init__() -> Self
get_type_name
get_current_work_info
get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo
Returns:
WorkInfo
advance
advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]
Returns:
OptionalReg[SeqInfo]
grid_dim
static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]
Returns:
Tuple[Int, Int, Int]
initial_state
initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState
Returns:
MHATileState
unsafe_seq_info
unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo
Returns:
SeqInfo
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!