Skip to main content

Mojo struct

SM100MHADepth512

struct SM100MHADepth512[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: Depth512SM100Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type

comptime accum_type = DType.float32

BM

comptime BM = config.BM

BM_eff

comptime BM_eff = config.BM_eff()

BN

comptime BN = config.BN

cta_group

comptime cta_group = 2

fuse_gqa

comptime fuse_gqa = config.fuse_gqa

group

comptime group = config.group

num_q_heads

comptime num_q_heads = config.num_q_heads

page_size

comptime page_size = KVLUTType.page_size

PairBM

comptime PairBM = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM * 2)

PairBM_mask

comptime PairBM_mask = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM_eff * 2)

PositionType

comptime PositionType = MHAPosition[SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].PairBM, config.BN, config.qk_depth, config.qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_type

comptime qkv_type = KVLUTType.dtype

ragged

comptime ragged = not ValidLengthType.is_null.__bool__()

SmemType

comptime SmemType = Depth512AttentionSMem[config]

Methods

kernel

static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, _padded_shape[4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config, 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config, 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=config.BM, BN=config.ov_depth, group=config.group if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 1], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, NullPointer[DType.float32], KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status

static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

Was this page helpful?