
Mojo function

mla_prefill_decode_graph_fp8

mla_prefill_decode_graph_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], kv_input_fn: def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width], target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](output: TileTensor[dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], freqs_cis: TileTensor[freqs_cis.dtype, freqs_cis.LayoutType, freqs_cis.origin, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], kv_norm_gamma: TileTensor[kv_norm_gamma.dtype, kv_norm_gamma.LayoutType, kv_norm_gamma.origin, linear_idx_type=kv_norm_gamma.linear_idx_type, element_size=kv_norm_gamma.element_size], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, buffer_row_offsets: TileTensor[DType.uint32, buffer_row_offsets.LayoutType, buffer_row_offsets.origin, linear_idx_type=buffer_row_offsets.linear_idx_type, element_size=buffer_row_offsets.element_size], cache_offsets: TileTensor[DType.uint32, cache_offsets.LayoutType, cache_offsets.origin, linear_idx_type=cache_offsets.linear_idx_type, element_size=cache_offsets.element_size], buffer_length: Int, max_seq_len: Int, w_k: TileTensor[fp8_dtype, w_k.LayoutType, w_k.origin, linear_idx_type=w_k.linear_idx_type, element_size=w_k.element_size], w_k_scale: TileTensor[fp8_scale_dtype, w_k_scale.LayoutType, w_k_scale.origin, linear_idx_type=w_k_scale.linear_idx_type, element_size=w_k_scale.element_size], w_uk: TileTensor[fp8_dtype, w_uk.LayoutType, w_uk.origin, linear_idx_type=w_uk.linear_idx_type, element_size=w_uk.element_size], w_uk_scale: TileTensor[fp8_scale_dtype, w_uk_scale.LayoutType, w_uk_scale.origin, linear_idx_type=w_uk_scale.linear_idx_type, element_size=w_uk_scale.element_size], w_uv: TileTensor[fp8_dtype, w_uv.LayoutType, w_uv.origin, linear_idx_type=w_uv.linear_idx_type, element_size=w_uv.element_size], w_uv_scale: TileTensor[fp8_scale_dtype, w_uv_scale.LayoutType, w_uv_scale.origin, linear_idx_type=w_uv_scale.linear_idx_type, element_size=w_uv_scale.element_size], scalar_args_buf: TileTensor[DType.int64, scalar_args_buf.LayoutType, scalar_args_buf.origin, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size], ctx: DeviceContext)

This is a manually fused kernel that performs the following operations:

- Performs MLA prefill or MLA decode, selecting between the two based on the maximum sequence length (max_seq_len).
