
Mojo function

generic_fused_qk_rope_bshd_paged_ragged

generic_fused_qk_rope_bshd_paged_ragged[dtype: DType, freq_dtype: DType, //, *, interleaved: Bool, has_position_ids: Bool, target: StringSlice[StaticConstantOrigin], mrope_types: Variadic[CoordLike] = , mrope_section: Optional[Coord[mrope_types]] = None](q_proj: TileTensor[dtype, q_proj.LayoutType, q_proj.origin, linear_idx_type=q_proj.linear_idx_type, element_size=q_proj.element_size], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], kv_collection: PagedKVCacheCollection[kv_collection.dtype_, kv_collection.kv_params_, kv_collection.page_size, kv_collection.scale_dtype_, kv_collection.quantization_granularity_], freqs_cis: TileTensor[freq_dtype, freqs_cis.LayoutType, freqs_cis.origin, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], position_ids: TileTensor[DType.uint32, position_ids.LayoutType, position_ids.origin, linear_idx_type=position_ids.linear_idx_type, element_size=position_ids.element_size], layer_idx: UInt32, output: TileTensor[dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size], context: DeviceContextPtr = DeviceContextPtr())

Applies rotary position embedding (RoPE) to the Q and K projections in a single fused kernel.

The Llama model uses a manually fused QKV projection with mo.opaque dtypes. Because of a limitation in custom op definitions, a custom kernel can't declare both a tensor and an opaque dtype as outputs, so the QKV projection declares only Q_proj as its output. If a RoPE kernel applied to K immediately followed the QKV projection kernel, the result would be a race condition, because the graph definition gives the graph compiler no dependency between the two kernels. This function fuses the RoPE applied to Q_proj with the RoPE applied to K_proj, so the K_proj RoPE executes only after the QKV projection completes.
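To make the rotation itself concrete, here is a minimal pure-Python reference sketch of RoPE applied to ragged Q and K rows in one pass. It is not the Mojo kernel: it ignores the paged KV cache, heads, and device execution. The names `make_freqs`, `rope_token`, `fused_qk_rope_ragged`, and `cache_lengths` are hypothetical. The sketch also assumes that the frequency table holds one (cos, sin) pair per rotated element pair, and that a token's position is its sequence's cached length plus its index within the sequence.

```python
import math


def make_freqs(max_len, head_dim, theta=10000.0):
    """Hypothetical helper: freqs[pos][i] = (cos, sin) of pos * theta^(-2i/head_dim)."""
    half = head_dim // 2
    return [
        [
            (
                math.cos(pos * theta ** (-2.0 * i / head_dim)),
                math.sin(pos * theta ** (-2.0 * i / head_dim)),
            )
            for i in range(half)
        ]
        for pos in range(max_len)
    ]


def rope_token(vec, freqs_row, interleaved):
    """Rotate one head vector. Pairs are (2i, 2i+1) if interleaved, else (i, i+half)."""
    half = len(vec) // 2
    out = [0.0] * len(vec)
    for i, (c, s) in enumerate(freqs_row):
        re_idx, im_idx = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        re, im = vec[re_idx], vec[im_idx]
        # Complex multiplication (re + i*im) * (c + i*s).
        out[re_idx] = re * c - im * s
        out[im_idx] = re * s + im * c
    return out


def fused_qk_rope_ragged(q, k, input_row_offsets, freqs, cache_lengths, interleaved):
    """Rotate Q and K rows together, walking the ragged batch via row offsets.

    Sequence b owns rows input_row_offsets[b] .. input_row_offsets[b + 1] - 1.
    cache_lengths[b] (an assumption of this sketch) is the number of tokens
    already cached for sequence b, so it offsets the token positions.
    """
    q_out, k_out = [], []
    for b in range(len(input_row_offsets) - 1):
        start, end = input_row_offsets[b], input_row_offsets[b + 1]
        for t in range(start, end):
            pos = cache_lengths[b] + (t - start)
            q_out.append(rope_token(q[t], freqs[pos], interleaved))
            k_out.append(rope_token(k[t], freqs[pos], interleaved))
    return q_out, k_out


# Two sequences packed into three rows: rows 0-1 and row 2.
freqs = make_freqs(max_len=8, head_dim=4)
q = [[1.0, 0.0, 0.0, 1.0] for _ in range(3)]
k = [[0.5, 0.5, -0.5, 0.5] for _ in range(3)]
q_rot, k_rot = fused_qk_rope_ragged(
    q, k, input_row_offsets=[0, 2, 3], freqs=freqs,
    cache_lengths=[0, 5], interleaved=True,
)
```

Because each pair is rotated by a unit complex number, every row keeps its norm, and a token at position 0 is left unchanged. Producing the Q and K results from the same call mirrors the idea described above: the K rotation cannot be scheduled separately from the Q rotation.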
