Mojo trait

MHATileScheduler

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type

Indicates the type used on accelerator devices.

may_advance

comptime may_advance

mha_schedule

comptime mha_schedule

The MHATileScheduler trait describes a schedule for the persistent kernel.

Required methods

__init__

__init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

__init__(out self: _Self, *, deinit take: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • take (_Self): The value to move.

Returns:

_Self

get_current_work_info

get_current_work_info[ValidLengthType: OptionalPointer, //](self: _Self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns the current WorkInfo.

Returns:

WorkInfo

advance

advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self: _Self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

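Advances state to the next work item. The method takes no func argument and does not return a Bool: per the signature, it returns an OptionalReg[SeqInfo], which presumably holds the SeqInfo of the next work item when more work remains and is empty otherwise.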

Returns:

OptionalReg[SeqInfo]

grid_dim

static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Return the grid_dim required for the kernel.

Returns:

Tuple[Int, Int, Int]
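
The page does not say how an implementation derives the grid from batch_size and max_num_prompt_tiles. As a rough illustration only, the Python sketch below models two plausible policies: a one-block-per-tile grid and a persistent grid capped by the number of streaming multiprocessors. The function names and the num_heads and sm_count parameters are hypothetical and do not appear in the trait; real implementations may lay out the grid differently.

```python
from typing import Tuple


def transient_grid_dim(
    batch_size: int, max_num_prompt_tiles: int, num_heads: int
) -> Tuple[int, int, int]:
    """Hypothetical policy: launch one block per (tile, head, batch) triple."""
    return (max_num_prompt_tiles, num_heads, batch_size)


def persistent_grid_dim(
    batch_size: int, max_num_prompt_tiles: int, num_heads: int, sm_count: int
) -> Tuple[int, int, int]:
    """Hypothetical policy: launch at most one long-lived block per SM.

    Each block then loops over several work items instead of exiting
    after a single tile.
    """
    total_work_items = batch_size * max_num_prompt_tiles * num_heads
    return (min(sm_count, total_work_items), 1, 1)
```

Under the second policy the grid no longer grows with the problem size, which is why a persistent kernel needs a scheduler to hand out work items at run time.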

initial_state

initial_state[ValidLengthType: OptionalPointer, //](self: _Self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Create the initial state object.

Returns:

MHATileState
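
Taken together, initial_state, get_current_work_info, and advance suggest a loop in which each persistent block creates a state, reads its current work item, and advances until no work remains. The Python sketch below models that control flow only. ToyState, ToyWorkInfo, ToyScheduler, and run_block are hypothetical stand-ins, and the grid-stride policy (each block starts at its own index and steps by the number of blocks) is an assumption, not the behavior of any particular MHATileScheduler implementation.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyState:
    """Hypothetical stand-in for MHATileState: one linear tile index."""
    idx: int


@dataclass(frozen=True)
class ToyWorkInfo:
    """Hypothetical stand-in for WorkInfo / SeqInfo: which tile to process."""
    batch: int
    tile: int


class ToyScheduler:
    """Toy grid-stride schedule (an assumption, not the real scheduler)."""

    def __init__(self, num_blocks: int, batch_size: int, tiles_per_seq: int):
        self.num_blocks = num_blocks
        self.batch_size = batch_size
        self.tiles_per_seq = tiles_per_seq

    def initial_state(self, block_idx: int) -> ToyState:
        # Each persistent block starts on the tile matching its own index.
        return ToyState(idx=block_idx)

    def get_current_work_info(self, state: ToyState) -> Optional[ToyWorkInfo]:
        # Return None once the index runs past the last tile.
        if state.idx >= self.batch_size * self.tiles_per_seq:
            return None
        return ToyWorkInfo(
            batch=state.idx // self.tiles_per_seq,
            tile=state.idx % self.tiles_per_seq,
        )

    def advance(self, state: ToyState) -> Optional[ToyWorkInfo]:
        # Mutate the state in place, then report the next item (or None).
        state.idx += self.num_blocks
        return self.get_current_work_info(state)


def run_block(sched: ToyScheduler, block_idx: int) -> List[ToyWorkInfo]:
    """Collect every work item one persistent block would process."""
    state = sched.initial_state(block_idx)
    work = sched.get_current_work_info(state)
    processed = []
    while work is not None:
        processed.append(work)
        work = sched.advance(state)
    return processed
```

With two blocks, a batch of two sequences, and three tiles per sequence, block 0 in this model handles linear tiles 0, 2, and 4 while block 1 handles 1, 3, and 5, so every tile is processed exactly once.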

unsafe_seq_info

unsafe_seq_info[ValidLengthType: OptionalPointer, //](self: _Self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo

get_type_name

static get_type_name() -> String

Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.

Returns:

String: The host type's name.

Provided methods

copy

copy(self: _Self) -> _Self

Explicitly constructs a copy of self. This is a convenience method equivalent to Self(copy=self), useful when the type is inconvenient to write out.

Returns:

_Self: A copy of this value.