Mojo function
generic_fused_qk_rope_bshd_continuous_batch
generic_fused_qk_rope_bshd_continuous_batch[dtype: DType, //, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: TileTensor[dtype, q_proj.LayoutType, q_proj.origin, address_space=q_proj.address_space, linear_idx_type=q_proj.linear_idx_type, element_shape_types=q_proj.element_shape_types], kv_collection: ContinuousBatchingKVCacheCollection[kv_collection.dtype_, kv_collection.kv_params_], freqs_cis: TileTensor[dtype, freqs_cis.LayoutType, freqs_cis.origin, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_shape_types=freqs_cis.element_shape_types], layer_idx: UInt32, valid_lengths: TileTensor[DType.uint32, valid_lengths.LayoutType, valid_lengths.origin, address_space=valid_lengths.address_space, linear_idx_type=valid_lengths.linear_idx_type, element_shape_types=valid_lengths.element_shape_types], output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_shape_types=output.element_shape_types], context: DeviceContextPtr = DeviceContextPtr())
Performs fused RoPE on the Q and K projections.
We have a manually fused QKV projection with mo.opaque dtypes in our Llama model. Due to a limitation in custom op definitions, we can't declare both a tensor and an opaque dtype as outputs from a custom kernel. This forces us to declare only Q_proj as an output of the QKV projection. If we immediately followed the QKV projection kernel with a RoPE kernel applied to K, we would get a race condition, because the graph compiler doesn't know about the dependency between these kernels in the graph definition. Here we fuse the RoPE kernel applied to Q_proj with the one applied to K_proj, so the K_proj RoPE only executes after the QKV projection completes.
Args:

- q_proj (TileTensor): Query projection tensor of shape [batch, seq_len, n_heads, head_dim].
- kv_collection (ContinuousBatchingKVCacheCollection): The continuous batching KV cache collection.
- freqs_cis (TileTensor): Frequency tensor for RoPE of shape [max_seq_len, head_dim].
- layer_idx (UInt32): The layer index for accessing the correct cache.
- valid_lengths (TileTensor): Tensor of shape [batch] containing the valid length for each sequence. RoPE is only applied to positions within these lengths.
- output (TileTensor): Output tensor for Q with RoPE applied, same shape as q_proj.
- context (DeviceContextPtr): Device context pointer for execution.
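To make the semantics concrete, here is a minimal NumPy reference sketch of what this kernel computes for Q (not the Mojo implementation itself). It assumes the `interleaved` layout, where `freqs_cis` stores (cos, sin) pairs interleaved along `head_dim` and each even/odd element pair of a head is rotated together; positions at or beyond a sequence's valid length are left untouched, matching the `valid_lengths` behavior described above.

```python
import numpy as np

def rope_reference(q, freqs_cis, valid_lengths):
    """Hypothetical reference for the Q half of the fused QK RoPE kernel.

    q:             [batch, seq_len, n_heads, head_dim] real array
    freqs_cis:     [max_seq_len, head_dim]; assumed interleaved (cos, sin) pairs
    valid_lengths: [batch] valid length per sequence
    """
    out = q.copy()
    batch, seq_len, n_heads, head_dim = q.shape
    cos = freqs_cis[:, 0::2]  # [max_seq_len, head_dim // 2]
    sin = freqs_cis[:, 1::2]
    for b in range(batch):
        # RoPE is only applied to positions within the valid length.
        for pos in range(min(int(valid_lengths[b]), seq_len)):
            x_even = q[b, pos, :, 0::2]  # [n_heads, head_dim // 2]
            x_odd = q[b, pos, :, 1::2]
            # 2-D rotation of each (even, odd) pair by the per-position angle.
            out[b, pos, :, 0::2] = x_even * cos[pos] - x_odd * sin[pos]
            out[b, pos, :, 1::2] = x_even * sin[pos] + x_odd * cos[pos]
    return out
```

In the fused kernel, the same rotation is applied to the K projection stored in the KV cache collection at `layer_idx`; fusing both into one kernel is what guarantees the K rotation runs only after the QKV projection has produced its outputs.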