Skip to main content

Mojo struct

SM100MHADepth512

struct SM100MHADepth512[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: Depth512SM100Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type

comptime accum_type = DType.float32

BM

comptime BM = config.BM

BM_eff

comptime BM_eff = config.BM_eff()

BN

comptime BN = config.BN

cta_group

comptime cta_group = 2

fuse_gqa

comptime fuse_gqa = config.fuse_gqa

group

comptime group = config.group

k_sub_BN

comptime k_sub_BN = kv_sub_tile_rows((SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BN // 2), SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].page_size)

num_q_heads

comptime num_q_heads = config.num_q_heads

page_size

comptime page_size = KVLUTType.page_size

PairBM

comptime PairBM = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM * 2)

PairBM_mask

comptime PairBM_mask = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM_eff * 2)

PositionType

comptime PositionType = MHAPosition[SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].PairBM, config.BN, config.qk_depth, config.qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_type

comptime qkv_type = KVLUTType.dtype

ragged

comptime ragged = not ValidLengthType.is_null.__bool__()

SmemType

comptime SmemType = Depth512AttentionSMem[config]

v_sub_BN

comptime v_sub_BN = kv_sub_tile_rows(config.BK1, SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].page_size)

Methods

kernel

static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, _padded_shape[4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].k_sub_BN, 1, config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].k_sub_BN, 1, config, __list_literal__=NoneType(None)), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].v_sub_BN, 1, config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].v_sub_BN, 1, config, __list_literal__=NoneType(None)), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=config.BM, BN=config.ov_depth, group=config.group if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 1], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, NullPointer[DType.float32], KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status

static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus