IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

fused_rope_rmsnorm_kernel

def fused_rope_rmsnorm_kernel[dtype: DType, freq_dtype: DType, gamma_dtype: DType, QRopeOutputLayoutType: TensorLayout, QRopeLayoutType: TensorLayout, InputRowOffsetsLayoutType: TensorLayout, FreqsCisLayoutType: TensorLayout, GammaLayoutType: TensorLayout, cache_t: KVCacheT, block_size: Int, n_rope_blocks: Int, n_rms_blocks: Int](q_rope_output: TileTensor[dtype, QRopeOutputLayoutType, MutUntrackedOrigin], q_rope: TileTensor[dtype, QRopeLayoutType, ImmutUntrackedOrigin], input_row_offsets: TileTensor[DType.uint32, InputRowOffsetsLayoutType, ImmutUntrackedOrigin], freqs_cis: TileTensor[freq_dtype, FreqsCisLayoutType, ImmutUntrackedOrigin], gamma: TileTensor[gamma_dtype, GammaLayoutType, ImmutUntrackedOrigin], k_cache: cache_t, epsilon: Float32)

Fused GPU kernel that applies RoPE to the query projections and RMSNorm to the KV cache.

This kernel processes tokens in parallel across GPU blocks, with separate groups of blocks handling the RoPE and RMSNorm operations. The RoPE blocks apply rotary position embeddings to the rope part of the query (reading q_rope and writing the rotated values to q_rope_output) and to the rope part of each key cache entry (in place). The RMSNorm blocks normalize the first kv_norm_dim elements of each key cache entry.
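The per-token math can be checked against a plain-Python reference. This is a sketch, not the kernel's code, and it rests on several assumptions: freqs_cis supplies one (cos, sin) pair per rotated pair of elements, RoPE rotates adjacent (interleaved) pairs, the RMSNorm scale is gamma itself (not 1 + gamma), and each key cache entry stores the kv_norm_dim normalized elements first and the rope part after them. The helpers apply_rope_interleaved, rms_norm, and reference_cache_entry are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple


def apply_rope_interleaved(
    x: Sequence[float], cos_sin: Sequence[Tuple[float, float]]
) -> List[float]:
    """Rotate each adjacent pair (x[2i], x[2i+1]) by the angle whose
    (cos, sin) is cos_sin[i], i.e. multiply the pair as a complex number.
    Assumes the interleaved-pair RoPE convention."""
    if len(x) != 2 * len(cos_sin):
        raise ValueError("rope part must hold exactly two values per frequency")
    out: List[float] = []
    for i, (c, s) in enumerate(cos_sin):
        re, im = x[2 * i], x[2 * i + 1]
        out.append(re * c - im * s)  # real part of (re + i*im) * (c + i*s)
        out.append(re * s + im * c)  # imaginary part
    return out


def rms_norm(x: Sequence[float], gamma: Sequence[float], epsilon: float) -> List[float]:
    """y_j = x_j / sqrt(mean(x^2) + epsilon) * gamma_j (assumed scale convention)."""
    if len(x) != len(gamma):
        raise ValueError("x and gamma must have the same length")
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def reference_cache_entry(
    entry: Sequence[float],
    cos_sin: Sequence[Tuple[float, float]],
    gamma: Sequence[float],
    epsilon: float,
) -> List[float]:
    """Hypothetical per-token reference: RMSNorm on the first kv_norm_dim
    elements (kv_norm_dim = len(gamma)), RoPE on the trailing rope part."""
    kv_norm_dim = len(gamma)
    normed = rms_norm(entry[:kv_norm_dim], gamma, epsilon)
    roped = apply_rope_interleaved(entry[kv_norm_dim:], cos_sin)
    return normed + roped


# Two normalized elements followed by one rope pair rotated by 90 degrees.
out = reference_cache_entry([3.0, 4.0, 1.0, 0.0], [(0.0, 1.0)], [1.0, 1.0], 0.0)
print(out)
```

A reference like this runs in double precision, so comparing it with the kernel's dtype output needs a tolerance that matches that dtype.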

Parameters:

  • dtype (DType): Data type of query tensors.
  • freq_dtype (DType): Data type of frequency cosine/sine values.
  • gamma_dtype (DType): Data type of RMSNorm gamma weights.
  • QRopeOutputLayoutType (TensorLayout): Layout type of the output query rope tensor.
  • QRopeLayoutType (TensorLayout): Layout type of the input query rope tensor.
  • InputRowOffsetsLayoutType (TensorLayout): Layout type of the row offset indices tensor.
  • FreqsCisLayoutType (TensorLayout): Layout type of the frequency tensor.
  • GammaLayoutType (TensorLayout): Layout type of the gamma weights tensor.
  • cache_t (KVCacheT): Type of the KV cache.
  • block_size (Int): Number of threads per block.
  • n_rope_blocks (Int): Number of blocks allocated for RoPE computation.
  • n_rms_blocks (Int): Number of blocks allocated for RMSNorm computation.
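This page does not spell out how block indices map to the two groups. The sketch below shows one plausible scheme, assuming the first n_rope_blocks blocks of the grid form the RoPE group, the remaining n_rms_blocks form the RMSNorm group, and each group strides over all tokens. It ignores the block_size threads inside a block. BlockWork, partition_blocks, and token_to_sequence are hypothetical names, and token_to_sequence assumes input_row_offsets stores batch_size + 1 prefix-sum offsets of a ragged batch.

```python
import bisect
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class BlockWork:
    block_idx: int
    role: str  # "rope" or "rmsnorm"
    tokens: List[int]


def partition_blocks(num_tokens: int, n_rope_blocks: int, n_rms_blocks: int) -> List[BlockWork]:
    """Hypothetical split: blocks [0, n_rope_blocks) do RoPE, the next
    n_rms_blocks do RMSNorm, and each group strides over every token."""
    if n_rope_blocks < 1 or n_rms_blocks < 1:
        raise ValueError("each block group needs at least one block")
    work: List[BlockWork] = []
    for b in range(n_rope_blocks + n_rms_blocks):
        if b < n_rope_blocks:
            role, local, group = "rope", b, n_rope_blocks
        else:
            role, local, group = "rmsnorm", b - n_rope_blocks, n_rms_blocks
        # Block-stride loop: this block handles tokens local, local + group, ...
        work.append(BlockWork(b, role, list(range(local, num_tokens, group))))
    return work


def token_to_sequence(token: int, input_row_offsets: Sequence[int]) -> Tuple[int, int]:
    """Map a flat token index to (sequence index, offset within that sequence),
    assuming input_row_offsets holds batch_size + 1 prefix-sum offsets."""
    seq = bisect.bisect_right(input_row_offsets, token) - 1
    return seq, token - input_row_offsets[seq]


for w in partition_blocks(num_tokens=5, n_rope_blocks=2, n_rms_blocks=1):
    print(w.block_idx, w.role, w.tokens)
# 0 rope [0, 2, 4]
# 1 rope [1, 3]
# 2 rmsnorm [0, 1, 2, 3, 4]
```

Under these assumptions, each group visits every token exactly once. Dedicating whole blocks to one role keeps control flow uniform within a block. If the rope part of a cache entry does not overlap its first kv_norm_dim elements, the two groups would also need no synchronization with each other.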

Args: