IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MHAPosition

struct MHAPosition[BM: Int, BN: Int, depth: Int, padded_depth: Int, q_num_heads: Int, group: Int, decoding: Bool]

Position of the MHA kernel. When `decoding=False`, `q_head_stride == q_num_heads`; when `decoding=True`, `q_head_stride == 1`.

Fields

  • q_row (UInt32):
  • q_col (UInt32):
  • q_out_offset (Int):
  • num_keys (UInt32):
  • start_pos (UInt32):
  • seq_len (UInt32):
  • head_idx (UInt32):
  • prompt_offset (UInt32):
  • prompt_idx (UInt32):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

num_q_heads_per_thread

comptime num_q_heads_per_thread = min(2, ceildiv(group, 8)) if decoding else 1

q_output_gmem_layout

comptime q_output_gmem_layout = Layout(IntTuple(BM, depth), IntTuple(MHAPosition[BM, BN, depth, padded_depth, q_num_heads, group, decoding].q_stride, 1))

q_stride

comptime q_stride = depth if decoding else (depth * q_num_heads)

split_gmem_layout

comptime split_gmem_layout = Layout(IntTuple((BM // 2), depth), IntTuple(MHAPosition[BM, BN, depth, padded_depth, q_num_heads, group, decoding].q_stride, 1))

Methods

__init__

def __init__(q_row: UInt32, q_col: UInt32, q_out_offset: Int, num_keys: UInt32, start_pos: UInt32, seq_info: SeqInfo) -> Self

__eq__

def __eq__(self, other: Self) -> Bool

Returns:

Bool

__ne__

def __ne__(self, other: Self) -> Bool

Returns:

Bool

q_head_idx

def q_head_idx(self) -> UInt32

Returns:

UInt32

kv_head_idx

def kv_head_idx(self) -> UInt32

Returns:

UInt32

write_to

def write_to(self, mut writer: T)

q_tile_num_rows

def q_tile_num_rows(self) -> UInt32

Returns:

UInt32

q_out_gmem_tensor

def q_out_gmem_tensor[dtype: DType](self, ptr: UnsafePointer[Scalar[dtype]]) -> LayoutTensor[dtype, Self.q_output_gmem_layout, ptr.origin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]

Returns:

LayoutTensor[dtype, Self.q_output_gmem_layout, ptr.origin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]

mask_status

def mask_status[MaskType: MHAMask](self, mask: MaskType, kv_tile_start_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

get_score_row

def get_score_row(self) -> UInt32

Returns:

UInt32

exp_sum_qk_max_ptr

def exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: UInt32) -> Tuple[UnsafePointer[Scalar[partition_t.accum_dtype], MutAnyOrigin], UnsafePointer[Scalar[partition_t.accum_dtype], MutAnyOrigin]]

Returns:

Tuple[UnsafePointer[Scalar[partition_t.accum_dtype], MutAnyOrigin], UnsafePointer[Scalar[partition_t.accum_dtype], MutAnyOrigin]]

get_start_and_end_for_partitions

def get_start_and_end_for_partitions[PartitionType: MHAPartitionScheme, MaskType: MHAMask, //, page_size: Int](self, partition: PartitionType, mask: MaskType) -> Tuple[UInt32, UInt32]

Returns:

Tuple[UInt32, UInt32]

get_q_gmem_row

static def get_q_gmem_row[MaxSeqLenType: OptionallyStaticInt, //, ragged: Bool](seq_info: SeqInfo, max_seq_len: MaxSeqLenType) -> UInt32

Returns:

UInt32

static def get_q_gmem_row[ragged: Bool](seq_info: SeqInfo, max_seq_len: UInt32) -> UInt32

Returns:

UInt32