Mojo function

mla_decode_branch_bf16

mla_decode_branch_bf16[dtype: DType, collection_t: KVCollectionT, //, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], q: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], input_row_offsets: TileTensor[DType.uint32, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], freqs_cis: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_norm_gamma: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, epsilon: Float32, w_uk: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], w_uv: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], ctx: DeviceContext)

BF16 decode path for MLA (Multi-head Latent Attention).

Applies RoPE and RMSNorm, projects q_nope to latent space, concatenates with q_rope, and runs decode.
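
The steps named above can be illustrated with a small plain-Python sketch for a single head and a single decode token. This is not the kernel's implementation: the helper names (`rms_norm`, `rope`, `matvec`, `mla_decode_one_head`), the interleaved RoPE pair layout, the weight shapes, and the choice of which tensors RMSNorm and RoPE act on are all assumptions made for illustration, loosely following the usual MLA formulation. The real function operates on `TileTensor` inputs and a paged KV collection.

```python
import math

def rms_norm(x, gamma, eps):
    # Scale x by the reciprocal root-mean-square, then by a learned gain.
    mean_sq = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv * g for v, g in zip(x, gamma)]

def rope(x, cos_sin):
    # Rotate consecutive pairs (x[2i], x[2i+1]) by the angle given as (cos, sin).
    # Assumption: interleaved pair layout; the kernel's layout may differ.
    out = []
    for i, (c, s) in enumerate(cos_sin):
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * c - b * s, a * s + b * c]
    return out

def matvec(x, w):
    # Row vector x (length n) times matrix w (n rows, m columns).
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def mla_decode_one_head(q_nope, q_rope, freqs, w_uk, w_uv, cache, scale):
    # cache: list of (latent, k_rope) pairs, assumed already normalized/rotated.
    q_latent = matvec(q_nope, w_uk)            # project q_nope to latent space
    q_full = q_latent + rope(q_rope, freqs)    # concatenate with rotated q_rope
    scores = [scale * sum(a * b for a, b in zip(q_full, lat + kr))
              for lat, kr in cache]
    top = max(scores)                          # subtract max for stable softmax
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    dim = len(cache[0][0])
    o_latent = [sum(p * lat[j] for p, (lat, _) in zip(probs, cache))
                for j in range(dim)]
    # Assumption: w_uv maps the latent-space result back to the value dimension.
    return matvec(o_latent, w_uv)

# Hypothetical toy inputs: latent dim 2, rope dim 2, nope dim 1, value dim 1.
freqs = [(1.0, 0.0)]                           # (cos, sin) for a zero angle
latent = rms_norm([3.0, 4.0], gamma=[1.0, 1.0], eps=1e-6)
cache = [(latent, rope([1.0, 0.0], freqs))]
out = mla_decode_one_head(
    q_nope=[1.0], q_rope=[1.0, 0.0], freqs=freqs,
    w_uk=[[1.0, 0.0]], w_uv=[[1.0], [1.0]], cache=cache, scale=1.0,
)
```

Projecting `q_nope` through `w_uk` lets the attention scores be computed directly against the cached latent vectors, so per-head keys never have to be materialized during decode; `w_uv` is then applied once to the attention-weighted latent result.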