Mojo function

mla_prefill_decode_graph_bf16

mla_prefill_decode_graph_bf16[dtype: DType, collection_t: KVCollectionT, //, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_norm_gamma: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, buffer_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], cache_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], buffer_length: Int, max_seq_len: Int, kv_b_proj: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uk: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uv: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)

BF16 Multi-head Latent Attention (MLA) graph covering both prefill and decode.

Dispatches to the prefill or decode path based on the maximum sequence length in the batch.
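The dispatch rule can be sketched in plain Python. This is not the kernel's implementation. It assumes that `input_row_offsets` holds `batch_size + 1` cumulative offsets starting at 0, so that sequence `i` occupies rows `offsets[i]` through `offsets[i + 1] - 1` of `q`. It also assumes a hypothetical threshold: a batch whose longest sequence has at most one new token is treated as decode, and anything longer as prefill. The helper names `max_seq_len_from_offsets` and `choose_mla_path` are illustrative only.

```python
from typing import Sequence


def max_seq_len_from_offsets(input_row_offsets: Sequence[int]) -> int:
    """Longest sequence in a ragged batch described by cumulative row offsets.

    Assumption: offsets has batch_size + 1 entries and starts at 0, so the
    length of sequence i is offsets[i + 1] - offsets[i].
    """
    if len(input_row_offsets) < 2:
        return 0  # empty batch: no sequences
    return max(
        input_row_offsets[i + 1] - input_row_offsets[i]
        for i in range(len(input_row_offsets) - 1)
    )


def choose_mla_path(max_seq_len: int) -> str:
    # Hypothetical threshold: in decode, each sequence contributes one new
    # token, so a batch whose longest sequence is <= 1 takes the decode path.
    return "decode" if max_seq_len <= 1 else "prefill"


# Three sequences with one new token each -> decode.
print(choose_mla_path(max_seq_len_from_offsets([0, 1, 2, 3])))
# Sequences of length 5, 1 and 8 -> longest is 8 -> prefill.
print(choose_mla_path(max_seq_len_from_offsets([0, 5, 6, 14])))
```

A single batch-wide decision keeps one code path per call. The trade-off is that a mixed batch containing even one multi-token sequence goes entirely through the prefill path.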

Was this page helpful?