
Mojo trait

MHATileScheduler

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type

Indicate the type being used on accelerator devices.

may_advance

comptime may_advance

mha_schedule

comptime mha_schedule

The MHATileScheduler trait describes a schedule for the persistent kernel.
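To make "a schedule for the persistent kernel" concrete, the sketch below is a Python model (not Mojo) of one simple policy. A fixed number of resident blocks each walk the flattened (batch, prompt tile) work space with a stride equal to the grid size. The function static_schedule and the batch-major flattening order are assumptions for illustration. An actual MHATileScheduler implementation may order or distribute work differently.

```python
# Conceptual model (hypothetical), not the Mojo API: static round-robin
# scheduling for a persistent kernel. Each of `num_blocks` resident blocks
# strides over the flattened (batch, prompt_tile) work space.


def static_schedule(num_blocks, batch_size, num_prompt_tiles):
    """Return, for each block, the list of (batch, tile) items it processes."""
    total = batch_size * num_prompt_tiles
    schedule = []
    for block_idx in range(num_blocks):
        items = []
        # Grid-stride loop: start at this block's index and step by the grid size.
        for flat in range(block_idx, total, num_blocks):
            batch, tile = divmod(flat, num_prompt_tiles)
            items.append((batch, tile))
        schedule.append(items)
    return schedule
```

For example, static_schedule(3, 2, 4) assigns block 0 the items (0, 0), (0, 3) and (1, 2). Every item goes to exactly one block, and no block is launched more than once.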

Required methods

__init__

def __init__(out self: _Self, *, copy: _Self)

Create a new instance of the value by copying an existing one.

Args:

  • copy (_Self): The value to copy.

Returns:

_Self

def __init__(out self: _Self, *, deinit move: _Self)

Create a new instance of the value by moving the value of another.

Args:

  • move (_Self): The value to move.

Returns:

_Self

get_current_work_info

def get_current_work_info[ValidLengthType: OptionalPointer, //](self: _Self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns the current WorkInfo.

Returns:

WorkInfo

advance

def advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self: _Self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

Advances state to the next work item. The returned OptionalReg[SeqInfo] indicates whether there is more work: it presumably holds the next item's SeqInfo when more work remains and is empty otherwise.

Returns:

OptionalReg[SeqInfo]
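The methods above suggest a driver loop in which each block obtains an initial state and then repeatedly calls advance until it reports no more work. The Python model below sketches that loop under stated assumptions. None stands in for an empty OptionalReg[SeqInfo] (inferred from the return type, not confirmed by the description). Work items are plain (batch, prompt tile) pairs rather than SeqInfo values. The DynamicScheduler class, which hands out tiles from a shared counter modeling an atomic fetch-add, and the simulate helper are hypothetical and do not describe any shipped scheduler.

```python
# Conceptual Python model (hypothetical), not the Mojo API.
# None stands in for an empty OptionalReg[SeqInfo].
import itertools
from typing import Dict, List, Optional, Tuple

WorkItem = Tuple[int, int]  # (batch index, prompt-tile index)


class DynamicScheduler:
    """Hands out work from a shared counter, modeling an atomic fetch-add."""

    def __init__(self, batch_size: int, num_prompt_tiles: int):
        self.num_prompt_tiles = num_prompt_tiles
        self.total = batch_size * num_prompt_tiles
        self._counter = itertools.count()  # stand-in for a global atomic counter

    def _fetch(self) -> Optional[WorkItem]:
        idx = next(self._counter)
        if idx >= self.total:
            return None
        batch, tile = divmod(idx, self.num_prompt_tiles)
        return (batch, tile)

    def initial_state(self) -> Dict[str, Optional[WorkItem]]:
        return {"current": self._fetch()}

    def advance(self, state: Dict[str, Optional[WorkItem]]) -> Optional[WorkItem]:
        # Move this block to its next work item; None means no more work.
        state["current"] = self._fetch()
        return state["current"]


def simulate(num_blocks: int, batch_size: int, num_prompt_tiles: int) -> Dict[int, List[WorkItem]]:
    """Interleave resident blocks one step at a time until all work is done."""
    sched = DynamicScheduler(batch_size, num_prompt_tiles)
    states = [sched.initial_state() for _ in range(num_blocks)]
    processed: Dict[int, List[WorkItem]] = {b: [] for b in range(num_blocks)}
    active = [b for b in range(num_blocks) if states[b]["current"] is not None]
    while active:
        still_active = []
        for b in active:
            processed[b].append(states[b]["current"])  # "process" the current tile
            if sched.advance(states[b]) is not None:
                still_active.append(b)
        active = still_active
    return processed
```

With two blocks and five tiles, simulate(2, 1, 5) has block 0 process tiles 0, 2 and 4 and block 1 process tiles 1 and 3. A shared counter lets faster blocks pick up more work; in a real kernel this would typically cost one atomic operation per advance.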

grid_dim

static def grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Returns the grid dimensions (grid_dim) required to launch the kernel.

Returns:

Tuple[Int, Int, Int]
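The trait leaves the grid shape to each implementation. As a hedged illustration, the Python functions below contrast two hypothetical policies: a non-persistent launch with one block per (prompt tile, batch) pair, and a persistent launch capped at the number of streaming multiprocessors. Both function names and the extra num_sms parameter are assumptions; the trait's grid_dim takes only batch_size and max_num_prompt_tiles.

```python
# Hypothetical grid_dim policies; not the actual Mojo implementations.


def grid_dim_one_block_per_tile(batch_size, max_num_prompt_tiles):
    # Non-persistent launch: one block per (prompt tile, batch) pair.
    return (max_num_prompt_tiles, batch_size, 1)


def grid_dim_persistent(batch_size, max_num_prompt_tiles, num_sms):
    # Persistent launch: at most one block per SM, and never more blocks
    # than there are work items (but always at least one block).
    total_tiles = batch_size * max_num_prompt_tiles
    return (max(1, min(num_sms, total_tiles)), 1, 1)
```

For instance, with 132 SMs, a batch of 4 with 16 prompt tiles launches 64 persistent blocks, while a batch of 64 launches 132 blocks that each loop over several tiles.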

initial_state

def initial_state[ValidLengthType: OptionalPointer, //](self: _Self, ptr: UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Create the initial state object.

Returns:

MHATileState

unsafe_seq_info

def unsafe_seq_info[ValidLengthType: OptionalPointer, //](self: _Self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo

get_type_name

static def get_type_name() -> String

Gets the name of the host type (the type implementing this trait). For example, Int returns "Int" and DeviceBuffer[DType.float32] returns "DeviceBuffer[DType.float32]". The name is used in error messages when passing types to the device. This method is expected to be retired once better kernel-call error messages are available.

Returns:

String: The host type's name.

Provided methods

copy

def copy(self: _Self) -> _Self

Explicitly constructs a copy of self. This is a convenience method for Self(copy=self) when the type is inconvenient to write out.

Overriding this method is not allowed.

Returns:

_Self: A copy of this value.
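The relationship between copy() and the copy constructor can be modeled in Python, with a keyword-only copy argument standing in for Mojo's __init__(out self, *, copy: _Self). The TileState class is hypothetical and only illustrates the delegation; it is not part of the Mojo API.

```python
class TileState:
    """Hypothetical value type used only to illustrate copy() delegation."""

    def __init__(self, idx=0, *, copy=None):
        # Keyword-only `copy` mirrors Mojo's `__init__(out self, *, copy: _Self)`.
        if copy is not None:
            self.idx = copy.idx
        else:
            self.idx = idx

    def copy(self):
        # Convenience wrapper equivalent to type(self)(copy=self), so callers
        # never need to spell out a (possibly long) type name.
        return type(self)(copy=self)
```

Because copy() only forwards to the copy constructor, the two spellings always produce equivalent values, which is consistent with the rule that copy() may not be overridden.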