
Mojo struct

QueuedTileScheduler

@register_passable(trivial) struct QueuedTileScheduler[tile_shape: SIMD[uint32, 1], num_heads: SIMD[uint32, 1], /, decoding: Bool, num_ctas: SIMD[uint32, 1] = SIMD(GPUInfo("H100", Vendor(2), "cuda", "hopper", 9, "sm_90a", 132, 32, 2048, 233472, 65536, 1024).sm_count), schedule: MHASchedule = MHASchedule(0)]

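When decoding is False, num_heads is the number of query heads (q_num_heads). When decoding is True, num_heads is the number of key/value heads (kv_num_heads). num_ctas defaults to the H100 SM count (132).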
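The num_heads rule can be modeled in a few lines of Python. This is a host-side sketch, not the Mojo API. The helper names scheduler_num_heads and total_work_items are hypothetical, and the assumption of one work item per (prompt tile, head, batch entry) is illustrative only. Scheduling per KV head while decoding is consistent with grouped-query attention, where several query heads share one KV head.

```python
def scheduler_num_heads(decoding: bool, q_num_heads: int, kv_num_heads: int) -> int:
    """Hypothetical helper: pick the value passed as the num_heads parameter."""
    # Prefill / context encoding: one work item per query head.
    # Decoding: one work item per key/value head (a KV head may serve a
    # group of query heads under grouped-query attention).
    return kv_num_heads if decoding else q_num_heads


def total_work_items(batch_size: int, max_num_prompt_tiles: int, num_heads: int) -> int:
    # Assumption for illustration: one work item per (tile, head, batch) triple.
    return batch_size * max_num_prompt_tiles * num_heads


# Example: 32 query heads sharing 8 KV heads.
prefill_heads = scheduler_num_heads(False, q_num_heads=32, kv_num_heads=8)
decode_heads = scheduler_num_heads(True, q_num_heads=32, kv_num_heads=8)
```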

Fields

  • gidx_ptr (UnsafePointer[SIMD[uint32, 1], address_space=AddressSpace(1)]): Pointer to a uint32 work index in global memory (AddressSpace(1)), shared by all CTAs.
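A "queued" scheduler is usually understood as CTAs pulling work from one shared counter instead of owning a fixed slice of the grid. The Python model below assumes that gidx_ptr plays this role and that CTAs claim items with an atomic fetch-and-add. Python threads stand in for CTAs, and a lock emulates the atomic. GlobalCounter and run_queued are hypothetical names, and starting the counter at 0 is an assumption.

```python
import threading


class GlobalCounter:
    """Hypothetical stand-in for the uint32 counter behind gidx_ptr."""

    def __init__(self, start: int = 0) -> None:
        self._value = start
        self._lock = threading.Lock()

    def fetch_add(self, delta: int = 1) -> int:
        # Emulates an atomic fetch-and-add: returns the value before the add.
        with self._lock:
            old = self._value
            self._value += delta
            return old


def run_queued(num_ctas: int, num_work_items: int) -> list[list[int]]:
    """Each 'CTA' keeps claiming indices until the queue is exhausted."""
    counter = GlobalCounter()
    claimed: list[list[int]] = [[] for _ in range(num_ctas)]

    def cta(cta_id: int) -> None:
        while True:
            idx = counter.fetch_add(1)
            if idx >= num_work_items:  # no work left for this CTA
                return
            claimed[cta_id].append(idx)

    threads = [threading.Thread(target=cta, args=(i,)) for i in range(num_ctas)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return claimed


claimed = run_queued(num_ctas=4, num_work_items=100)
```

Every index is claimed exactly once, but how indices are split across CTAs depends on timing. This is the load-balancing property that makes a queue attractive when tiles have uneven cost.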

Implemented traits

AnyType, Copyable, DevicePassable, ExplicitlyCopyable, MHATileScheduler, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

device_type

alias device_type = QueuedTileScheduler[tile_shape, num_heads, decoding, num_ctas, schedule]

may_advance

alias may_advance = True

mha_schedule

alias mha_schedule = schedule

Methods

__init__

__init__(gidx_ptr: UnsafePointer[SIMD[uint32, 1]]) -> Self

get_current_work_info

get_current_work_info(self, ts: MHATileSummary, state: MHATileState) -> WorkInfo

Returns:

WorkInfo
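This page does not describe how a work index becomes a WorkInfo. The Python sketch below shows one plausible decomposition of a linear index into a tile offset, a head, and a batch entry. The ordering (tiles fastest, then heads, then batch), the WorkInfoSketch type and its field names, and the validity rule are all hypothetical. The real layout may also depend on the schedule parameter.

```python
from typing import NamedTuple


class WorkInfoSketch(NamedTuple):
    """Hypothetical stand-in for WorkInfo; fields chosen for illustration."""
    prompt_offset: int   # first query row covered by the tile
    head_idx: int
    prompt_idx: int      # batch entry
    is_valid: bool


def decode_work_index(idx: int, tile_shape: int, num_heads: int,
                      max_num_prompt_tiles: int, batch_size: int) -> WorkInfoSketch:
    # Assumed ordering: prompt tiles vary fastest, then heads, then batch.
    tile = idx % max_num_prompt_tiles
    rest = idx // max_num_prompt_tiles
    head = rest % num_heads
    batch = rest // num_heads
    # An index past the last batch entry means the queue is exhausted.
    return WorkInfoSketch(tile * tile_shape, head, batch, batch < batch_size)
```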

advance

advance[ragged: Bool, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization(1)](self, ts: MHATileSummary, mut state: MHATileState, pipeline_idx: SIMD[uint32, 1]) -> OptionalReg[SeqInfo]

Advances state to the next work item. The returned OptionalReg[SeqInfo] holds the sequence info when the new index corresponds to a valid WorkInfo, and is empty otherwise. If sync is MHASchedulerSynchronization NONE, the method assumes it is called only by thread_idx.x == 0.

Returns:

OptionalReg
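The thread_idx.x == 0 note suggests a common GPU pattern. One thread performs the global fetch, stores the index in shared memory, and, after a barrier, every thread in the CTA reads it. The Python model below simulates that pattern for a single advance call in one CTA, with threads as GPU threads and threading.Barrier as the block barrier. Returning None mimics an empty OptionalReg. cta_advance_model is a hypothetical name, and this page does not document which MHASchedulerSynchronization values perform such a broadcast.

```python
import threading
from typing import Optional


def cta_advance_model(num_threads: int, counter_start: int,
                      num_work_items: int) -> list[Optional[int]]:
    """Model one advance() in a single CTA (hypothetical behavior)."""
    shared_slot = [0]                   # stands in for the AddressSpace(3) word
    global_counter = [counter_start]    # stands in for the value at gidx_ptr
    barrier = threading.Barrier(num_threads)
    results: list[Optional[int]] = [None] * num_threads

    def gpu_thread(tid: int) -> None:
        if tid == 0:
            # Only thread 0 touches the global counter (single CTA, so no race).
            shared_slot[0] = global_counter[0]
            global_counter[0] += 1
        barrier.wait()  # thread 0's write is done before anyone reads it
        idx = shared_slot[0]
        # Like an empty OptionalReg when the index is past the end of the work.
        results[tid] = idx if idx < num_work_items else None

    threads = [threading.Thread(target=gpu_thread, args=(t,)) for t in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

All threads in the CTA see the same index, so they agree on whether to keep working.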

grid_dim

static grid_dim(batch_size: SIMD[uint32, 1], max_num_prompt_tiles: SIMD[uint32, 1]) -> Tuple[Int, Int, Int]

Returns:

Tuple
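This page does not document the formula grid_dim uses. Queue-driven schedulers are typically paired with persistent kernels: a fixed number of CTAs (here, at most num_ctas) stay resident and keep pulling work. The Python sketch below shows that sizing rule. persistent_grid_dim is a hypothetical helper, and its formula is an assumption, not necessarily what QueuedTileScheduler.grid_dim returns.

```python
def persistent_grid_dim(batch_size: int, max_num_prompt_tiles: int,
                        num_heads: int, num_ctas: int = 132) -> tuple[int, int, int]:
    """Hypothetical persistent-kernel grid: a 1-D grid capped at num_ctas."""
    total = batch_size * max_num_prompt_tiles * num_heads
    # No more CTAs than work items, no more than num_ctas, and at least one
    # CTA so the launch grid is never empty.
    return (max(1, min(num_ctas, total)), 1, 1)
```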

initial_state

initial_state(self, ptr: UnsafePointer[SIMD[uint32, 1], address_space=AddressSpace(3)], tile_summary: MHATileSummary) -> MHATileState

Returns:

MHATileState

unsafe_seq_info

unsafe_seq_info[ragged: Bool](self, ts: MHATileSummary, state: MHATileState) -> SeqInfo

Returns:

SeqInfo

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait).

Returns:

String: The host type's name.

get_device_type_name

static get_device_type_name() -> String

Gets device_type's name.

Returns:

String: The device type's name.
