Mojo function
mla_prefill_branch_bf16
mla_prefill_branch_bf16[dtype: DType, collection_t: KVCollectionT, //, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_norm_gamma: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, buffer_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], cache_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], buffer_length: Int, kv_b_proj: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)
BF16 MLA prefill path.
Applies RoPE and RMSNorm, up-projects latent KV to full K/V, then runs prefill attention.
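The four steps in the description can be illustrated with a small plain-Python reference sketch of the math. It is not the Mojo kernel and does not reproduce its API. Several assumptions follow DeepSeek-style MLA rather than this page:

- It handles one sequence and one attention head.
- It uses a causal mask, whereas the real function selects its mask through `mask_str`.
- It computes RoPE on the fly over interleaved pairs instead of reading a `freqs_cis` table.
- RMSNorm (with `kv_norm_gamma` and `epsilon`) is applied to the latent KV vector, and RoPE only to the positional parts of Q and K.
- `kv_b_proj` is treated as one matrix whose first rows yield the non-RoPE key part and whose remaining rows yield the value.

KV cache handling (`kv_collection`, `cache_offsets`, `buffer_row_offsets`) is omitted. All helper names (`rms_norm`, `rope`, `matvec`, `mla_prefill_reference`) are hypothetical.

```python
import math


def rms_norm(x, gamma, eps):
    # RMSNorm: scale x by 1/sqrt(mean(x^2) + eps), then by the learned gamma.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def rope(x, pos, base=10000.0):
    # Rotate interleaved pairs (x[i], x[i+1]) by pos * base**(-i/d).
    # The interleaved layout and base are assumptions, not taken from the kernel.
    d = len(x)
    out = list(x)
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out


def matvec(w, x):
    # Dense matrix-vector product; w is a list of rows.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def mla_prefill_reference(q_nope, q_pe, c_kv, k_pe, kv_norm_gamma, kv_b_proj,
                          scale, epsilon):
    """Hypothetical helper: single sequence, single head, causal mask."""
    seq_len = len(q_nope)
    d_nope = len(q_nope[0])

    # 1. RMSNorm on the latent KV; RoPE on the positional parts of Q and K.
    c_kv = [rms_norm(c, kv_norm_gamma, epsilon) for c in c_kv]
    q_pe = [rope(q, t) for t, q in enumerate(q_pe)]
    k_pe = [rope(k, t) for t, k in enumerate(k_pe)]

    # 2. Up-project each latent vector: the first d_nope outputs form the
    #    non-RoPE key part (concatenated with the rotated k_pe), the rest is V.
    keys, values = [], []
    for c, kp in zip(c_kv, k_pe):
        kv = matvec(kv_b_proj, c)
        keys.append(kv[:d_nope] + kp)
        values.append(kv[d_nope:])
    queries = [qn + qp for qn, qp in zip(q_nope, q_pe)]

    # 3. Causal softmax attention: token i attends to tokens 0..i.
    out = []
    for i in range(seq_len):
        scores = [scale * sum(a * b for a, b in zip(queries[i], keys[j]))
                  for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(weights[j] * values[j][d] for j in range(i + 1)) / z
                    for d in range(len(values[0]))])
    return out
```

With one head, the key for each token is simply the up-projected non-RoPE part followed by the rotated RoPE part. In DeepSeek-style MLA, the RoPE key part is shared across heads, so a multi-head version would slice `kv_b_proj` per head while reusing the same rotated `k_pe`.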