Mojo function

fused_qk_rope

fused_qk_rope[dtype: DType, collection_t: KVCollectionT, //, cache_t: KVCacheT, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: TileTensor[dtype, q_proj.LayoutType, q_proj.origin, address_space=q_proj.address_space, linear_idx_type=q_proj.linear_idx_type, element_size=q_proj.element_size], kv_collection: collection_t, freqs_cis: TileTensor[dtype, freqs_cis.LayoutType, freqs_cis.origin, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], layer_idx: UInt32, valid_lengths: TileTensor[DType.uint32, valid_lengths.LayoutType, valid_lengths.origin, address_space=valid_lengths.address_space, linear_idx_type=valid_lengths.linear_idx_type, element_size=valid_lengths.element_size], output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], context: Optional[DeviceContext])

Applies rotary position embedding (RoPE) to the query and key tensors in a single fused kernel. The rotated queries are written to output, while the keys are updated within the KV cache referenced by kv_collection.

Args:

  • q_proj (TileTensor): Query projection tensor of shape [batch, seq_len, n_heads, head_dim].
  • kv_collection (collection_t): The KV cache collection containing the key cache.
  • freqs_cis (TileTensor): Frequency tensor for RoPE of shape [max_seq_len, head_dim].
  • layer_idx (UInt32): The layer index for accessing the correct cache.
  • valid_lengths (TileTensor): Tensor of shape [batch] containing the valid length for each sequence. RoPE is only applied to positions within these lengths.
  • output (TileTensor): Output tensor for Q with RoPE applied, same shape as q_proj.
  • context (Optional[DeviceContext]): Optional device context for GPU execution.
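The rotation this kernel performs can be sketched outside of Mojo. The NumPy reference below is a hypothetical illustration of standard RoPE, not the kernel's actual implementation: each head vector is split into 2-element pairs, and each pair is rotated by a per-position, per-frequency angle. The interleaved parameter mirrors the kernel's compile-time flag and selects how elements are paired (adjacent pairs vs. first-half/second-half pairs); the freqs_cis layout here (complex values of shape [seq_len, head_dim // 2]) is an assumption for the sketch.

```python
import numpy as np

def rope_rotate(x, freqs_cis, interleaved=True):
    """Apply RoPE to x of shape [seq_len, n_heads, head_dim].

    freqs_cis: complex array of shape [seq_len, head_dim // 2]
    holding cos(theta) + i*sin(theta) per position and frequency.
    """
    seq_len, n_heads, head_dim = x.shape
    if interleaved:
        # Pair adjacent elements: (x0, x1), (x2, x3), ...
        pairs = x.reshape(seq_len, n_heads, head_dim // 2, 2)
    else:
        # Pair across halves: (x0, x_{d/2}), (x1, x_{d/2+1}), ...
        pairs = np.stack(np.split(x, 2, axis=-1), axis=-1)
    # View each pair as a complex number and rotate it by the
    # unit-magnitude angle for its position and frequency.
    z = (pairs[..., 0] + 1j * pairs[..., 1]) * freqs_cis[:, None, :]
    rotated = np.stack([z.real, z.imag], axis=-1)
    if interleaved:
        return rotated.reshape(seq_len, n_heads, head_dim)
    return np.concatenate([rotated[..., 0], rotated[..., 1]], axis=-1)

# Build angles theta_{p,k} = p / 10000^{2k / head_dim}, as in the
# original RoPE formulation (an assumption about freqs_cis contents).
seq_len, n_heads, head_dim = 4, 2, 8
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2) / head_dim))
freqs_cis = np.exp(1j * np.outer(np.arange(seq_len), inv_freq))

x = np.random.default_rng(0).standard_normal((seq_len, n_heads, head_dim))
out = rope_rotate(x, freqs_cis, interleaved=True)
```

Because every pair is multiplied by a unit complex number, the rotation preserves vector norms, and position 0 (all angles zero) is left unchanged; these two properties make a quick sanity check for any RoPE implementation.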