Mojo struct
SM100MHADepth512
struct SM100MHADepth512[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, SchedulerType: MHATileScheduler, config: Depth512SM100Config[KVLUTType.dtype], ValidLengthType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_type
comptime accum_type = DType.float32
BM
comptime BM = config.BM
BM_eff
comptime BM_eff = config.BM_eff()
BN
comptime BN = config.BN
cta_group
comptime cta_group = 2
fuse_gqa
comptime fuse_gqa = config.fuse_gqa
group
comptime group = config.group
num_q_heads
comptime num_q_heads = config.num_q_heads
page_size
comptime page_size = KVLUTType.page_size
PairBM
comptime PairBM = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM * 2)
PairBM_mask
comptime PairBM_mask = (SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].BM_eff * 2)
PositionType
comptime PositionType = MHAPosition[SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].PairBM, config.BN, config.qk_depth, config.qk_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]
qkv_type
comptime qkv_type = KVLUTType.dtype
ragged
comptime ragged = not ValidLengthType.is_null.__bool__()
SmemType
comptime SmemType = Depth512AttentionSMem[config]
Methods
kernel
static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, 4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, _padded_shape[4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[4 if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config, 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config, 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], 
ragged_tma_store: RaggedTMA3DTile[output_type, config.swizzle_mode, BM=config.BM, BN=config.ov_depth, group=config.group if SM100MHADepth512[KVLUTType, output_type, MaskType, SchedulerType, config, ValidLengthType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType].fuse_gqa else 1], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, SchedulerType, ValidLengthType, NullPointer[DType.float32], KVRowOffsetsType, MaxSeqLenType, PartitionType])
mask_status
static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus
Returns: `TileMaskStatus`
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!