Mojo function
mla_prefill_decode_graph_fp8
mla_prefill_decode_graph_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_norm_gamma: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, buffer_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], cache_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], buffer_length: Int, max_seq_len: Int, kv_b_proj: TileTensor[fp8_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_b_proj_scale: TileTensor[fp8_scale_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uk: TileTensor[fp8_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uk_scale: TileTensor[fp8_scale_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uv: TileTensor[fp8_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uv_scale: TileTensor[fp8_scale_dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)
This is a manually fused kernel that performs the following operations:

- Perform MLA prefill or decode based on the maximum sequence length.
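The choice between prefill and decode can be pictured as a single branch on `max_seq_len`. The sketch below is illustrative only: the helper `is_decode_step` is hypothetical, and the threshold (a maximum sequence length of exactly 1 meaning a decode step) is an assumption that the description above does not state. The actual kernel may use a different condition.

```mojo
fn is_decode_step(max_seq_len: Int) -> Bool:
    # Hypothetical helper, not part of the documented API.
    # Assumption: when the longest sequence in the batch has a single
    # new token, the batch is treated as a decode step; any longer
    # maximum is treated as prefill.
    return max_seq_len == 1


fn main():
    # One new token per sequence: decode under the assumption above.
    print(is_decode_step(1))
    # Many new tokens in at least one sequence: prefill under the assumption above.
    print(is_decode_step(128))
```

Because `max_seq_len` is a runtime argument in the signature rather than a compile-time parameter, a branch like this would be taken at call time, once per invocation.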