Mojo function
fused_rope_rmsnorm_quantization_kernel
fused_rope_rmsnorm_quantization_kernel[dtype: DType, freq_dtype: DType, gamma_dtype: DType, QRopeOutputLayoutType: TensorLayout, QRopeLayoutType: TensorLayout, InputRowOffsetsLayoutType: TensorLayout, FreqsCisLayoutType: TensorLayout, GammaLayoutType: TensorLayout, cache_t: KVCacheT, block_size: Int, n_rope_blocks: Int, n_rms_blocks: Int, out_rope_dtype: DType, kv_input_fn: def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]](q_rope_output: TileTensor[out_rope_dtype, QRopeOutputLayoutType, MutExternalOrigin], q_rope: TileTensor[dtype, QRopeLayoutType, ImmutExternalOrigin], input_row_offsets: TileTensor[DType.uint32, InputRowOffsetsLayoutType, ImmutExternalOrigin], freqs_cis: TileTensor[freq_dtype, FreqsCisLayoutType, ImmutExternalOrigin], gamma: TileTensor[gamma_dtype, GammaLayoutType, ImmutExternalOrigin], k_cache: cache_t, epsilon: Float32)
Fused GPU kernel that applies RoPE to query projections and RMSNorm to KV cache, reading the inputs from a KV buffer and quantizing the final results before writing to the KVCache object.
This kernel processes tokens in parallel across GPU blocks, with separate
block groups handling RoPE and RMSNorm operations. The RoPE blocks apply
rotary position embeddings to both the query rope part (in-place) and the
key cache rope part (in-place). The RMSNorm blocks normalize the first
kv_norm_dim elements of the key cache entries.
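The two operations the kernel fuses can be sketched in plain Python as a scalar reference of the math only (an illustrative sketch, not the kernel: the function names are hypothetical, an interleaved (real, imag) pair convention is assumed for RoPE, and the kernel's quantization step before writing to the KV cache is omitted):

```python
import math

def rope_ref(x, cos, sin):
    """Rotate consecutive (real, imag) element pairs by precomputed angles.

    x: flat list of rope_dim values for one head; cos/sin: rope_dim // 2
    frequency values each (one per complex pair), as in freqs_cis.
    """
    out = [0.0] * len(x)
    for i in range(len(x) // 2):
        xr, xi = x[2 * i], x[2 * i + 1]
        # Complex rotation: (xr + i*xi) * (cos + i*sin)
        out[2 * i] = xr * cos[i] - xi * sin[i]
        out[2 * i + 1] = xr * sin[i] + xi * cos[i]
    return out

def rmsnorm_ref(x, gamma, epsilon):
    """Scale each element by the row's reciprocal root-mean-square,
    then by the learned gamma weight; epsilon guards against divide-by-zero.
    """
    rms = math.sqrt(sum(v * v for v in x) / len(x) + epsilon)
    return [v / rms * g for v, g in zip(x, gamma)]
```

In the kernel itself these run in separate GPU block groups (n_rope_blocks and n_rms_blocks), so the two reference functions above correspond to independent, parallel work.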
Parameters:
- dtype (DType): Data type of query tensors.
- freq_dtype (DType): Data type of frequency cosine/sine values.
- gamma_dtype (DType): Data type of RMSNorm gamma weights.
- QRopeOutputLayoutType (TensorLayout): Layout type of the output query rope tensor.
- QRopeLayoutType (TensorLayout): Layout type of the input query rope tensor.
- InputRowOffsetsLayoutType (TensorLayout): Layout type of the row offset indices tensor.
- FreqsCisLayoutType (TensorLayout): Layout type of the frequency tensor.
- GammaLayoutType (TensorLayout): Layout type of the gamma weights tensor.
- cache_t (KVCacheT): Type of the KV cache.
- block_size (Int): Number of threads per block.
- n_rope_blocks (Int): Number of blocks allocated for RoPE computation.
- n_rms_blocks (Int): Number of blocks allocated for RMSNorm computation.
- out_rope_dtype (DType): Data type of the RoPE output.
- kv_input_fn (def[width: Int](IndexList[2]) capturing -> SIMD[DType.bfloat16, width]): Input lambda function that loads the KV latent values. Shape: [tot_seq_len, cache_head_dim], where cache_head_dim = kv_lora_rank - qk_rope_head_dim.
Args:
- q_rope_output (TileTensor[out_rope_dtype, QRopeOutputLayoutType, MutExternalOrigin]): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
- q_rope (TileTensor[dtype, QRopeLayoutType, ImmutExternalOrigin]): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
- input_row_offsets (TileTensor[DType.uint32, InputRowOffsetsLayoutType, ImmutExternalOrigin]): Row offsets indicating request boundaries. Shape: [num_batches + 1].
- freqs_cis (TileTensor[freq_dtype, FreqsCisLayoutType, ImmutExternalOrigin]): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
- gamma (TileTensor[gamma_dtype, GammaLayoutType, ImmutExternalOrigin]): RMSNorm gamma weights. Shape: [kv_norm_dim].
- k_cache (cache_t): Key cache to apply RoPE and RMSNorm to.
- epsilon (Float32): Small constant for numerical stability in RMSNorm.