Mojo function
mla_decode_branch_bf16
```mojo
mla_decode_branch_bf16[
    collection_t: KVCollectionT, //,
    mask_str: StringSlice[StaticConstantOrigin],
    score_mod_str: StringSlice[StaticConstantOrigin],
    target: StringSlice[StaticConstantOrigin] = "cpu",
](
    output: TileTensor[DType.bfloat16, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_shape_types=output.element_shape_types],
    q: TileTensor[DType.bfloat16, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_shape_types=q.element_shape_types],
    input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, linear_idx_type=input_row_offsets.linear_idx_type, element_shape_types=input_row_offsets.element_shape_types],
    freqs_cis: TileTensor[freqs_cis.dtype, freqs_cis.LayoutType, freqs_cis.origin, linear_idx_type=freqs_cis.linear_idx_type, element_shape_types=freqs_cis.element_shape_types],
    kv_norm_gamma: TileTensor[kv_norm_gamma.dtype, kv_norm_gamma.LayoutType, kv_norm_gamma.origin, linear_idx_type=kv_norm_gamma.linear_idx_type, element_shape_types=kv_norm_gamma.element_shape_types],
    kv_collection: collection_t,
    layer_idx: UInt32,
    scale: Float32,
    epsilon: Float32,
    w_uk: TileTensor[DType.bfloat16, w_uk.LayoutType, w_uk.origin, linear_idx_type=w_uk.linear_idx_type, element_shape_types=w_uk.element_shape_types],
    w_uv: TileTensor[DType.bfloat16, w_uv.LayoutType, w_uv.origin, linear_idx_type=w_uv.linear_idx_type, element_shape_types=w_uv.element_shape_types],
    ctx: DeviceContext,
)
```
BF16 MLA (multi-head latent attention) decode path.

Applies RoPE and RMSNorm, projects `q_nope` into the latent space via `w_uk`, concatenates the result with `q_rope`, and runs the decode attention kernel against the latent KV cache.
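The sketch below shows how a call might look. It is illustrative only: the mask and score-mod names, tensor shapes in the comments, and the pre-constructed tensor and `ctx` values are assumptions, not part of the documented API beyond the signature above.

```mojo
# Hypothetical invocation sketch; all argument values are assumptions.
mla_decode_branch_bf16[
    mask_str="causal",         # assumed mask identifier
    score_mod_str="identity",  # assumed score modifier identifier
    target="gpu",
](
    output,             # bf16 attention output tensor
    q,                  # bf16 query containing q_nope and q_rope parts
    input_row_offsets,  # uint32 ragged-batch row offsets
    freqs_cis,          # RoPE frequency table
    kv_norm_gamma,      # RMSNorm gamma for the KV latent
    kv_collection,      # KV cache collection (satisfies KVCollectionT)
    layer_idx=UInt32(0),
    scale=Float32(0.088),   # e.g. 1/sqrt(head_dim)
    epsilon=Float32(1e-6),  # RMSNorm epsilon
    w_uk=w_uk,          # latent up-projection weights for K
    w_uv=w_uv,          # latent up-projection weights for V
    ctx=ctx,            # DeviceContext for the selected target
)
```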