Mojo struct
TileScheduler
struct TileScheduler[tile_shape: UInt32, num_heads: UInt32, /, num_ctas: UInt32 = H100.sm_count, schedule: MHASchedule = MHASchedule.DEFAULT]
Implemented traits
AnyType,
Copyable,
Defaultable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
MHATileScheduler,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = TileScheduler[tile_shape, num_heads, num_ctas, schedule]
may_advance
comptime may_advance = True
mha_schedule
comptime mha_schedule = schedule
Methods
__init__
__init__() -> Self
get_type_name
get_current_work_info
get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo
Returns:
WorkInfo
fetch_next_work
advance
advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]
Returns:
OptionalReg[SeqInfo]
grid_dim
static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]
Returns:
Tuple[Int, Int, Int]
initial_state
initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState
Returns:
MHATileState
unsafe_seq_info
unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo
Returns:
SeqInfo
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!