IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

SM100MHADepth512

struct SM100MHADepth512[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: Depth512SM100Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type​

comptime accum_type = DType.float32

BM​

comptime BM = config.BM

BM_eff​

comptime BM_eff = config.BM_eff()

BN​

comptime BN = config.BN

cta_group​

comptime cta_group = 2

fuse_gqa​

comptime fuse_gqa = config.fuse_gqa

group​

comptime group = config.group

k_sub_BN​

comptime k_sub_BN = kv_sub_tile_rows((config // Int(2)), KVLUTType.page_size)

num_q_heads​

comptime num_q_heads = config.num_q_heads

page_size​

comptime page_size = KVLUTType.page_size

PairBM​

comptime PairBM = (config.BM * Int(2))

PairBM_mask​

comptime PairBM_mask = (config.BM_eff() * Int(2))

PositionType​

comptime PositionType = MHAPosition[Int(config.BM * Int(2)), config.BN, config.qk_depth, config.qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_type​

comptime qkv_type = KVLUTType.dtype

ragged​

comptime ragged = not ValidLengthType.is_null.__bool__()

SmemType​

comptime SmemType = Depth512AttentionSMem[config]

v_sub_BN​

comptime v_sub_BN = kv_sub_tile_rows(config.BK1, KVLUTType.page_size)

Methods

kernel​

static def kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, Int(4) if Self.fuse_gqa else Int(3), _padded_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[Int(4) if Self.fuse_gqa else Int(3), KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=Self.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows((config // Int(2)), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows((config // Int(2)), KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, Int(3), _padded_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BK1, KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[Int(3), KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BK1, KVLUTType.page_size), Int(1), config, __list_literal__=NoneType(None)), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=config.BM, BN=config.ov_depth, group=config.group if config.fuse_gqa else Int(1)], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, NullPointer[DType.float32], KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status​

static def mask_status(mask: MaskType, seq_id: UInt32, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus