Skip to main content

Mojo struct

SM100MHADepth512

struct SM100MHADepth512[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: Depth512SM100Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_type

comptime accum_type = DType.float32

BM

comptime BM = Depth512SM100Config[KVLUTType.dtype].BM

BN

comptime BN = config.BN

cta_group

comptime cta_group = 2

group

comptime group = config.group

num_q_heads

comptime num_q_heads = config.num_q_heads

page_size

comptime page_size = KVLUTType.page_size

PairBM

comptime PairBM = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM * 2)

PositionType

comptime PositionType = MHAPosition[SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].PairBM, config.BN, config.qk_depth, config.qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]

qkv_type

comptime qkv_type = KVLUTType.dtype

ragged

comptime ragged = not ValidLengthType.is_null.__bool__()

SmemType

comptime SmemType = Depth512AttentionSMem[config]

TmemAllocType

comptime TmemAllocType = TmemAllocation[2]

Methods

kernel

static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=Depth512SM100Config[KVLUTType.dtype].BM, group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=Depth512SM100Config[KVLUTType.dtype].num_qk_stages](), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=Depth512SM100Config[KVLUTType.dtype].BM, group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=Depth512SM100Config[KVLUTType.dtype].num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config, 1, (config // 4), __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config, 1, (config // 4), __list_literal__=Tuple()), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, Depth512SM100Config[KVLUTType.dtype].BM, config.ov_depth], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, NullPointer[DType.float32], KVRowOffsetsType, MaxSeqLenType, PartitionType])

mask_status

static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

Was this page helpful?