Mojo struct
SM100MHADepth512
struct SM100MHADepth512[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: Depth512SM100Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_type
comptime accum_type = DType.float32
BM
comptime BM = Depth512SM100Config[KVLUTType.dtype].BM
BN
comptime BN = config.BN
cta_group
comptime cta_group = 2
group
comptime group = config.group
num_q_heads
comptime num_q_heads = config.num_q_heads
page_size
comptime page_size = KVLUTType.page_size
PairBM
comptime PairBM = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM * 2)
PositionType
comptime PositionType = MHAPosition[SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].PairBM, config.BN, config.qk_depth, config.qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]
qkv_type
comptime qkv_type = KVLUTType.dtype
ragged
comptime ragged = not ValidLengthType.is_null.__bool__()
SmemType
comptime SmemType = Depth512AttentionSMem[config]
TmemAllocType
comptime TmemAllocType = TmemAllocation[2]
Methods
kernel
static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=Depth512SM100Config[KVLUTType.dtype].BM, group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=Depth512SM100Config[KVLUTType.dtype].num_qk_stages](), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=Depth512SM100Config[KVLUTType.dtype].BM, group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=Depth512SM100Config[KVLUTType.dtype].num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config, 1, (config // 4), __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config, 1, (config // 4), __list_literal__=Tuple()), config.swizzle_mode]()], ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, Depth512SM100Config[KVLUTType.dtype].BM, config.ov_depth], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, NullPointer[DType.float32], KVRowOffsetsType, MaxSeqLenType, PartitionType])
mask_status
static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus
Returns:
`TileMaskStatus` — the mask status computed for the given `score_row` / `kv_row` position.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!