Mojo function

fused_rope_rmsnorm_quantization_kernel

fused_rope_rmsnorm_quantization_kernel[dtype: DType, freq_dtype: DType, gamma_dtype: DType, QRopeOutputLayoutType: TensorLayout, QRopeLayoutType: TensorLayout, KVLayoutType: TensorLayout, InputRowOffsetsLayoutType: TensorLayout, FreqsCisLayoutType: TensorLayout, GammaLayoutType: TensorLayout, cache_t: KVCacheT, block_size: Int, n_rope_blocks: Int, n_rms_blocks: Int](q_rope_output: TileTensor[dtype, QRopeOutputLayoutType, MutExternalOrigin], q_rope: TileTensor[dtype, QRopeLayoutType, ImmutExternalOrigin], kv: TileTensor[dtype, KVLayoutType, ImmutExternalOrigin], input_row_offsets: TileTensor[DType.uint32, InputRowOffsetsLayoutType, ImmutExternalOrigin], freqs_cis: TileTensor[freq_dtype, FreqsCisLayoutType, ImmutExternalOrigin], gamma: TileTensor[gamma_dtype, GammaLayoutType, ImmutExternalOrigin], k_cache: cache_t, epsilon: Float32)

Fused GPU kernel that applies RoPE to the query projections and RMSNorm to the KV cache entries, reading the inputs from a KV buffer and quantizing the final results before writing them to the KVCache object.

This kernel processes tokens in parallel across GPU blocks, with separate block groups handling the RoPE and RMSNorm operations. The RoPE blocks apply rotary position embeddings to the query rope part, writing the result to q_rope_output, and to the rope part of each key cache entry. The RMSNorm blocks normalize the first kv_norm_dim elements of each key cache entry.
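The per-token math described above can be sketched in plain Python as a CPU reference. This is only an illustration under stated assumptions, not the kernel's implementation: the helper names are hypothetical, the RoPE layout is assumed to be interleaved (cosine, sine) pairs, kv_norm_dim is assumed to mark where the rope part of a KV row begins, and the quantization step is not modeled because the page does not specify it.

```python
import math


def rms_norm(x, gamma, epsilon):
    """RMSNorm: scale x by gamma / sqrt(mean(x^2) + epsilon)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def rope_interleaved(x, freqs):
    """Rotate each adjacent pair (x[2i], x[2i+1]).

    Assumption: freqs[2i] holds the cosine and freqs[2i+1] the sine
    for that pair. The real freqs_cis layout may differ.
    """
    out = []
    for i in range(0, len(x), 2):
        re, im = x[i], x[i + 1]
        cos, sin = freqs[i], freqs[i + 1]
        out.append(re * cos - im * sin)
        out.append(re * sin + im * cos)
    return out


def reference_kv_row(kv_row, gamma, freqs, kv_norm_dim, epsilon):
    """Hypothetical reference for one KV row (quantization omitted).

    The first kv_norm_dim elements are normalized; the remaining
    elements are treated as the rope part and rotated.
    """
    normed = rms_norm(kv_row[:kv_norm_dim], gamma, epsilon)
    rotated = rope_interleaved(kv_row[kv_norm_dim:], freqs)
    return normed + rotated
```

A reference like this is mainly useful for checking a fused kernel's output against unfused math on small inputs, before any quantization is applied.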

Parameters:

  • dtype (DType): Data type of query tensors.
  • freq_dtype (DType): Data type of frequency cosine/sine values.
  • gamma_dtype (DType): Data type of RMSNorm gamma weights.
  • QRopeOutputLayoutType (TensorLayout): Layout type of the output query rope tensor.
  • QRopeLayoutType (TensorLayout): Layout type of the input query rope tensor.
  • KVLayoutType (TensorLayout): Layout type of the KV buffer tensor.
  • InputRowOffsetsLayoutType (TensorLayout): Layout type of the row offset indices tensor.
  • FreqsCisLayoutType (TensorLayout): Layout type of the frequency tensor.
  • GammaLayoutType (TensorLayout): Layout type of the gamma weights tensor.
  • cache_t (KVCacheT): Type of the KV cache.
  • block_size (Int): Number of threads per block.
  • n_rope_blocks (Int): Number of blocks allocated for RoPE computation.
  • n_rms_blocks (Int): Number of blocks allocated for RMSNorm computation.
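
To make the role of n_rope_blocks and n_rms_blocks concrete, here is a small Python sketch of how a launch of n_rope_blocks + n_rms_blocks blocks might be split into two groups. The ordering (RoPE blocks first) and the strided token assignment are assumptions made for illustration, and both function names are hypothetical; the page does not state how the kernel actually maps block indices to work.

```python
def block_role(block_idx, n_rope_blocks, n_rms_blocks):
    """Return (group, index within group) for a block.

    Assumption: blocks [0, n_rope_blocks) do RoPE and the rest do RMSNorm.
    """
    if not 0 <= block_idx < n_rope_blocks + n_rms_blocks:
        raise ValueError("block index out of range")
    if block_idx < n_rope_blocks:
        return ("rope", block_idx)
    return ("rms", block_idx - n_rope_blocks)


def tokens_for_block(local_idx, n_group_blocks, total_tokens):
    """Assumed strided assignment: each block in a group takes every
    n_group_blocks-th token, starting at its own index in the group."""
    return list(range(local_idx, total_tokens, n_group_blocks))
```

For example, with 2 RoPE blocks and 3 RMSNorm blocks, block 2 would be the first RMSNorm block under this assumed ordering.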

Args:

  • q_rope_output (TileTensor): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • q_rope (TileTensor): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • kv (TileTensor): KV latent tensor from the first projection. Shape: [tot_seq_len, cache_head_dim] where cache_head_dim = kv_lora_rank + qk_rope_head_dim.
  • input_row_offsets (TileTensor): Row offsets indicating request boundaries. Shape: [num_batches + 1].
  • freqs_cis (TileTensor): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
  • gamma (TileTensor): RMSNorm gamma weights. Shape: [kv_norm_dim].
  • k_cache (cache_t): Key cache that receives the quantized RoPE and RMSNorm results.
  • epsilon (Float32): Small constant for numerical stability in RMSNorm.
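
Because q_rope and kv are packed along tot_seq_len, input_row_offsets is what maps a global token index back to its request. The Python sketch below shows one way to do that lookup with a binary search. The function name is hypothetical, and the sketch only returns the position within the new tokens of a request; any offset for tokens already stored in the cache is not modeled here.

```python
from bisect import bisect_right


def locate_token(input_row_offsets, global_token_idx):
    """Map a global token index to (batch index, index within that request).

    input_row_offsets has num_batches + 1 entries; request b owns the
    tokens in [input_row_offsets[b], input_row_offsets[b + 1]).
    """
    total = input_row_offsets[-1]
    if not 0 <= global_token_idx < total:
        raise IndexError("token index out of range")
    # The last offset <= global_token_idx is the start of the owning request.
    batch_idx = bisect_right(input_row_offsets, global_token_idx) - 1
    return batch_idx, global_token_idx - input_row_offsets[batch_idx]
```

With offsets [0, 3, 3, 7] the second request is empty, so token 3 belongs to the third request at local index 0.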