Mojo function

generic_fused_qk_rope_bshd_continuous_batch

generic_fused_qk_rope_bshd_continuous_batch[dtype: DType, //, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype_, kv_params_], freqs_cis: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())

Applies RoPE (rotary position embedding) to the Q and K projections in a single fused kernel.

We have a manually fused QKV projection with mo.opaque dtypes in our Llama model. Due to a limitation in custom op definitions, we can't declare both a tensor and an opaque dtype as outputs of a custom kernel, so only Q_proj can be declared as an output of the QKV projection. If we immediately followed the QKV projection kernel with a RoPE kernel applied to K, we'd get a race condition, because the graph compiler can't see the dependency between these kernels in the graph definition. Here we fuse the RoPE applied to Q_proj with the RoPE applied to K_proj, so RoPE on K_proj executes only after the QKV projection completes.

Args:

  • q_proj (LayoutTensor): Query projection tensor of shape [batch, seq_len, n_heads, head_dim].
  • kv_collection (ContinuousBatchingKVCacheCollection): The continuous batching KV cache collection.
  • freqs_cis (LayoutTensor): Frequency tensor for RoPE of shape [max_seq_len, head_dim].
  • layer_idx (UInt32): The layer index for accessing the correct cache.
  • valid_lengths (LayoutTensor): Tensor of shape [batch] containing the valid length for each sequence. RoPE is only applied to positions within these lengths.
  • output (LayoutTensor): Output tensor for Q with RoPE applied, same shape as q_proj.
  • context (DeviceContextPtr): Device context pointer for execution.
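For intuition about what the kernel computes, the following NumPy sketch applies RoPE to a `[batch, seq_len, n_heads, head_dim]` tensor for both channel layouts selected by the `interleaved` parameter. This is an illustration only, not the Mojo kernel: the function name `apply_rope` and the assumption that `freqs_cis` is a `[seq_len, head_dim // 2]` complex tensor of unit rotations are choices made for this sketch, and details such as the cache update, `valid_lengths` masking, and the actual `freqs_cis` storage layout are omitted.

```python
import numpy as np

def apply_rope(x: np.ndarray, freqs_cis: np.ndarray,
               interleaved: bool = True) -> np.ndarray:
    """Rotate pairs of channels of x by position-dependent angles.

    x:         [batch, seq_len, n_heads, head_dim] real tensor.
    freqs_cis: [seq_len, head_dim // 2] complex tensor of unit
               rotations (cos + i*sin), one per channel pair per
               position (an assumed layout for this sketch).
    interleaved=True pairs channels (0, 1), (2, 3), ...;
    interleaved=False pairs (i, i + head_dim // 2) (split-half layout).
    """
    b, s, h, d = x.shape
    if interleaved:
        # Adjacent channels form (real, imag) pairs.
        pairs = x.reshape(b, s, h, d // 2, 2)
    else:
        # First half holds real parts, second half imaginary parts.
        pairs = np.stack((x[..., : d // 2], x[..., d // 2:]), axis=-1)
    xc = pairs[..., 0] + 1j * pairs[..., 1]       # [b, s, h, d // 2]
    rotated = xc * freqs_cis[None, :, None, :]    # broadcast batch/heads
    out = np.stack((rotated.real, rotated.imag), axis=-1)
    if interleaved:
        return out.reshape(b, s, h, d)
    return np.concatenate((out[..., 0], out[..., 1]), axis=-1)
```

Because each rotation is a unit complex multiply, RoPE preserves the magnitude of every channel pair, and hence the norm of each head vector; only relative phases between positions change.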