Mojo function
generic_fused_qk_rope_bshd_paged
generic_fused_qk_rope_bshd_paged[
    dtype: DType,
    //,
    *,
    interleaved: Bool,
    target: StringSlice[StaticConstantOrigin],
](
    q_proj: TileTensor[dtype, address_space=q_proj.address_space, linear_idx_type=q_proj.linear_idx_type, element_size=q_proj.element_size],
    kv_collection: PagedKVCacheCollection,
    freqs_cis: TileTensor[dtype, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size],
    layer_idx: UInt32,
    valid_lengths: TileTensor[DType.uint32, address_space=valid_lengths.address_space, linear_idx_type=valid_lengths.linear_idx_type, element_size=valid_lengths.element_size],
    output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    context: DeviceContextPtr = DeviceContextPtr(),
)
Applies rotary position embedding (RoPE) to Q and K in a single fused operation, using a paged KV cache.
This is the paged equivalent of generic_fused_qk_rope_bshd_continuous_batch. It applies RoPE to Q, writing the result to output, and to K in place in the paged cache. Handling both in one operation ensures proper dependency ordering after fused_qkv_padded_matmul.
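The per-head rotation that RoPE performs can be sketched in plain Python as a reference for the math, not the kernel itself. This sketch assumes that the `interleaved` parameter selects between rotating adjacent pairs (x[2i], x[2i+1]) and split-half pairs (x[i], x[i + head_dim/2]). It also assumes the common angle schedule with theta = 10000. The helper names `rope_angles` and `rope_rotate_head` are hypothetical.

```python
import math


def rope_angles(pos, head_dim, theta=10000.0):
    """Return (cos, sin) lists with head_dim // 2 entries for position pos.

    Uses the common RoPE schedule angle_i = pos * theta ** (-2 * i / head_dim);
    theta=10000.0 is an assumed default, not taken from the kernel.
    """
    half = head_dim // 2
    angles = [pos * theta ** (-2.0 * i / head_dim) for i in range(half)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rope_rotate_head(x, cos, sin, interleaved):
    """Rotate one head vector x (length head_dim) pair by pair.

    interleaved=True rotates adjacent pairs (x[2i], x[2i+1]);
    interleaved=False rotates split-half pairs (x[i], x[i + head_dim // 2]).
    """
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        # Pick the two elements that form the i-th rotation pair.
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        # 2-D rotation of (x[a], x[b]) by the i-th angle.
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[a] * sin[i] + x[b] * cos[i]
    return out
```

At position 0 every angle is zero, so the rotation is the identity. At any position it preserves the vector's norm, which is why RoPE encodes position without rescaling Q or K.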
Args:
- q_proj (TileTensor[dtype, address_space=q_proj.address_space, linear_idx_type=q_proj.linear_idx_type, element_size=q_proj.element_size]): Query projection tensor of shape [batch, seq_len, n_heads, head_dim].
- kv_collection (PagedKVCacheCollection): The paged KV cache collection.
- freqs_cis (TileTensor[dtype, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size]): RoPE frequency tensor of shape [max_seq_len, head_dim].
- layer_idx (UInt32): The layer index for accessing the correct cache.
- valid_lengths (TileTensor[DType.uint32, address_space=valid_lengths.address_space, linear_idx_type=valid_lengths.linear_idx_type, element_size=valid_lengths.element_size]): Tensor of shape [batch] containing the valid length of each sequence. RoPE is only applied to positions within these lengths.
- output (TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor for Q with RoPE applied, with the same shape as q_proj.
- context (DeviceContextPtr): Device context pointer for execution.
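The Q path described by these arguments can be sketched as an unfused reference over nested Python lists shaped [batch][seq_len][n_heads][head_dim]. This is an illustration under several assumptions, not the kernel's behavior. It assumes each freqs_cis row stores (cos_0, sin_0, cos_1, sin_1, ...) and that tokens at or beyond valid_lengths[b] are copied to the output unchanged. `start_positions` is a hypothetical per-sequence position offset, for example the number of tokens already cached, that defaults to 0. The K path and paged cache addressing are not modeled, and all helper names are hypothetical.

```python
import math


def make_freqs_cis(max_seq_len, head_dim, theta=10000.0):
    """Build a [max_seq_len][head_dim] table laid out as (cos_0, sin_0, cos_1, sin_1, ...).

    This interleaved cos/sin layout is an assumption about freqs_cis.
    """
    table = []
    for pos in range(max_seq_len):
        row = []
        for i in range(head_dim // 2):
            angle = pos * theta ** (-2.0 * i / head_dim)
            row += [math.cos(angle), math.sin(angle)]
        table.append(row)
    return table


def _rotate(x, cos, sin, interleaved):
    # Pairing rule that the `interleaved` parameter is assumed to select.
    half = len(x) // 2
    out = list(x)
    for i in range(half):
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        out[a] = x[a] * cos[i] - x[b] * sin[i]
        out[b] = x[a] * sin[i] + x[b] * cos[i]
    return out


def rope_q_bshd_reference(q, freqs_cis, valid_lengths, interleaved, start_positions=None):
    """Unfused reference for the Q half; q is nested as [batch][seq_len][n_heads][head_dim].

    Tokens with s >= valid_lengths[b] are copied through unchanged (assumed).
    start_positions is a hypothetical per-sequence offset; it defaults to 0.
    """
    output = []
    for b, seq in enumerate(q):
        offset = 0 if start_positions is None else start_positions[b]
        out_seq = []
        for s, heads in enumerate(seq):
            if s >= valid_lengths[b]:
                # Padding position: leave the values as they are.
                out_seq.append([list(h) for h in heads])
                continue
            # Split the interleaved row into per-pair cos and sin lists.
            row = freqs_cis[offset + s]
            cos, sin = row[0::2], row[1::2]
            out_seq.append([_rotate(h, cos, sin, interleaved) for h in heads])
        output.append(out_seq)
    return output
```

A reference like this is mainly useful for spot-checking a fused kernel's Q output on small inputs, because it makes the padding and position rules explicit.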