IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

generic_fused_qk_rope_bshd_continuous_batch

def generic_fused_qk_rope_bshd_continuous_batch[dtype: DType, //, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: TileTensor[dtype, Storage=q_proj.Storage, address_space=q_proj.address_space, linear_idx_type=q_proj.linear_idx_type, element_size=q_proj.element_size], kv_collection: ContinuousBatchingKVCacheCollection, freqs_cis: TileTensor[dtype, Storage=freqs_cis.Storage, address_space=freqs_cis.address_space, linear_idx_type=freqs_cis.linear_idx_type, element_size=freqs_cis.element_size], layer_idx: UInt32, valid_lengths: TileTensor[DType.uint32, Storage=valid_lengths.Storage, address_space=valid_lengths.address_space, linear_idx_type=valid_lengths.linear_idx_type, element_size=valid_lengths.element_size], output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], context: DeviceContext)

Applies rotary position embeddings (RoPE) to the Q and K projections in a single fused kernel.
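The per-pair rotation that RoPE applies can be sketched in plain Python (not Mojo). This is an illustrative sketch under stated assumptions: `rope_angles` and `rope_angles` are hypothetical helpers, each position is assumed to supply head_dim / 2 cosine/sine values (the real kernel reads precomputed values from `freqs_cis`, whose exact layout is not described here), and the angle schedule uses the common base of 10000. The `interleaved` flag is modeled on the kernel parameter of the same name and is assumed to mean: `True` rotates adjacent pairs `(x[2i], x[2i+1])`, while `False` rotates split-half pairs `(x[i], x[i + head_dim/2])`.

```python
import math


def rope_angles(pos, head_dim, theta=10000.0):
    """Return (cos, sin) lists for one position: angle_i = pos * theta^(-2i/head_dim).

    Assumption: the common RoPE schedule with base 10000; the real kernel
    instead reads precomputed values from freqs_cis.
    """
    freqs = [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
    return [math.cos(pos * f) for f in freqs], [math.sin(pos * f) for f in freqs]


def rope_rotate(x, cos, sin, interleaved):
    """Rotate one head vector x by RoPE, one 2-D rotation per pair.

    interleaved=True  pairs (x[2i], x[2i+1])
    interleaved=False pairs (x[i], x[i + head_dim // 2])
    """
    d = len(x)
    assert d % 2 == 0 and len(cos) == len(sin) == d // 2
    out = list(x)
    for i in range(d // 2):
        a_idx, b_idx = (2 * i, 2 * i + 1) if interleaved else (i, i + d // 2)
        a, b = x[a_idx], x[b_idx]
        # Equivalent to the complex multiply (a + ib) * (cos + i*sin).
        out[a_idx] = a * cos[i] - b * sin[i]
        out[b_idx] = a * sin[i] + b * cos[i]
    return out


if __name__ == "__main__":
    cos, sin = rope_angles(pos=3, head_dim=4)
    q = [1.0, 2.0, 3.0, 4.0]
    print(rope_rotate(q, cos, sin, interleaved=True))
    print(rope_rotate(q, cos, sin, interleaved=False))
```

At position 0 every angle is zero, so the rotation is the identity. For the same inputs the two pairings give different results, which is why the flag has to match the convention the model's weights were trained with.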

Our Llama model uses a manually fused QKV projection with mo.opaque dtypes. Because of a limitation in custom op definitions, a custom kernel can't declare both a tensor and an opaque dtype as outputs, so only Q_proj is declared as an output of the QKV projection. If a separate RoPE kernel for K ran immediately after the QKV projection kernel, we would get a race condition, because the graph compiler doesn't know about the dependency between the two kernels in the graph definition. This kernel therefore fuses the RoPE applied to Q_proj with the RoPE applied to K_proj: since it consumes Q_proj, it runs only after the QKV projection completes, and so does the RoPE on K_proj.
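The ordering argument can be illustrated with a toy, single-threaded Python analogue. Everything here is hypothetical: `fused_qk_rope`, the nested-list cache, and `cache_lengths` stand in for the real `TileTensor` inputs and `ContinuousBatchingKVCacheCollection`; angles are computed inline with a base of 10000 rather than read from `freqs_cis`; token `s` of sequence `b` is assumed to sit at absolute position `cache_lengths[b] + s`; and padded slots beyond `valid_lengths[b]` are simply copied through. The point is structural: because the function takes `q_proj` as an input, a graph compiler must schedule it after the QKV projection, and the K rotation performed inside it inherits that ordering.

```python
import math


def _rotate(x, pos, interleaved, theta=10000.0):
    """RoPE-rotate one head vector at position pos (angles computed inline)."""
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        t = pos * theta ** (-2.0 * i / d)
        c, s = math.cos(t), math.sin(t)
        j, k = (2 * i, 2 * i + 1) if interleaved else (i, i + d // 2)
        out[j] = x[j] * c - x[k] * s
        out[k] = x[j] * s + x[k] * c
    return out


def fused_qk_rope(q_proj, k_cache, cache_lengths, valid_lengths, interleaved):
    """Hypothetical toy analogue of a fused Q/K RoPE step (not the MAX API).

    q_proj[b][s][h]  : query head vectors in padded BSHD order.
    k_cache[b][p][h] : key head vectors by absolute position p; entries
                       cache_lengths[b] .. cache_lengths[b] + valid_lengths[b] - 1
                       were just written by the QKV projection, unrotated.
    Returns rotated Q and rotates the new K entries in place.
    """
    output = []
    for b, q_seq in enumerate(q_proj):
        out_seq = []
        for s, heads in enumerate(q_seq):
            if s >= valid_lengths[b]:
                out_seq.append(heads)  # padded slot: copied through unchanged
                continue
            pos = cache_lengths[b] + s
            out_seq.append([_rotate(h, pos, interleaved) for h in heads])
            # Rotate the matching new K entry in place. This happens inside the
            # op that consumes q_proj, so it cannot start before QKV wrote K.
            k_cache[b][pos] = [_rotate(h, pos, interleaved) for h in k_cache[b][pos]]
        output.append(out_seq)
    return output
```

Q and K for the same token are rotated with the same position in one pass. A standalone K-only RoPE kernel would have no input tying it to the QKV projection, which is the hidden dependency described above.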

Args: