For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
MHAPosition
struct MHAPosition[BM: Int, BN: Int, depth: Int, padded_depth: Int, q_num_heads: Int, group: Int, decoding: Bool]
Position of the MHA-kernel. When decoding=False, q_head_stride == q_num_heads. When decoding=True, q_head_stride == 1.
Fields
- q_row (UInt32)
- q_col (UInt32)
- q_out_offset (Int)
- num_keys (UInt32)
- start_pos (UInt32)
- seq_len (UInt32)
- head_idx (UInt32)
- prompt_offset (UInt32)
- prompt_idx (UInt32)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
num_q_heads_per_thread
comptime num_q_heads_per_thread = min(Int(2), ceildiv(group, Int(8))) if decoding else Int(1)
q_output_gmem_layout
comptime q_output_gmem_layout = Layout(IntTuple(BM, depth), IntTuple(depth if decoding else Int(depth * q_num_heads), Int(1)))
q_stride
comptime q_stride = depth if decoding else (depth * q_num_heads)
split_gmem_layout
comptime split_gmem_layout = Layout(IntTuple((BM // Int(2)), depth), IntTuple(depth if decoding else Int(depth * q_num_heads), Int(1)))
Methods
__init__
def __init__(q_row: UInt32, q_col: UInt32, q_out_offset: Int, num_keys: UInt32, start_pos: UInt32, seq_info: SeqInfo) -> Self
__eq__
__ne__
q_head_idx
kv_head_idx
write_to
def write_to(self, mut writer: T)
q_tile_num_rows
q_out_gmem_tensor
def q_out_gmem_tensor[dtype: DType](self, ptr: UnsafePointer[Scalar[dtype]]) -> LayoutTensor[dtype, Self.q_output_gmem_layout, ptr.origin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]
Returns: LayoutTensor
mask_status
def mask_status[MaskType: MHAMask](self, mask: MaskType, kv_tile_start_row: UInt32) -> TileMaskStatus
Returns: TileMaskStatus
get_score_row
exp_sum_qk_max_ptr
def exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: UInt32) -> Tuple[UnsafePointer[Scalar[partition_t.accum_dtype], MutAnyOrigin], UnsafePointer[Scalar[partition_t.accum_dtype], MutAnyOrigin]]
Returns: Tuple
get_start_and_end_for_partitions
def get_start_and_end_for_partitions[PartitionType: MHAPartitionScheme, MaskType: MHAMask, //, page_size: Int](self, partition: PartitionType, mask: MaskType) -> Tuple[UInt32, UInt32]
Returns: Tuple[UInt32, UInt32]
get_q_gmem_row
static def get_q_gmem_row[MaxSeqLenType: OptionallyStaticInt, //, ragged: Bool](seq_info: SeqInfo, max_seq_len: MaxSeqLenType) -> UInt32
Returns: UInt32
static def get_q_gmem_row[ragged: Bool](seq_info: SeqInfo, max_seq_len: UInt32) -> UInt32
Returns: UInt32
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!