Mojo function
fused_rope_rmsnorm_kernel
fused_rope_rmsnorm_kernel[dtype: DType, freq_dtype: DType, gamma_dtype: DType, q_rope_output_shape_types: Variadic[CoordLike], q_rope_output_stride_types: Variadic[CoordLike], q_rope_shape_types: Variadic[CoordLike], q_rope_stride_types: Variadic[CoordLike], input_row_offsets_shape_types: Variadic[CoordLike], input_row_offsets_stride_types: Variadic[CoordLike], freqs_cis_shape_types: Variadic[CoordLike], freqs_cis_stride_types: Variadic[CoordLike], gamma_shape_types: Variadic[CoordLike], gamma_stride_types: Variadic[CoordLike], cache_t: KVCacheT, block_size: Int, n_rope_blocks: Int, n_rms_blocks: Int](q_rope_output: TileTensor[dtype, MutExternalOrigin], q_rope: TileTensor[dtype, ImmutExternalOrigin], input_row_offsets: TileTensor[DType.uint32, ImmutExternalOrigin], freqs_cis: TileTensor[freq_dtype, ImmutExternalOrigin], gamma: TileTensor[gamma_dtype, ImmutExternalOrigin], k_cache: cache_t, epsilon: Float32)
Fused GPU kernel that applies RoPE to the query projections and RMSNorm to the KV cache.
This kernel processes tokens in parallel across GPU blocks, with separate block groups handling the RoPE and RMSNorm operations. The RoPE blocks apply rotary position embeddings to the query rope part, writing the result to q_rope_output (q_rope itself is read-only), and to the key cache rope part in place. The RMSNorm blocks normalize the first kv_norm_dim elements of each key cache entry in place.
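The two per-row operations described above can be written as a small plain-Python reference, useful for checking kernel output on the host. This is a sketch under stated assumptions, not the kernel's implementation: it assumes each freqs_cis row stores interleaved (cos, sin) pairs and that RoPE rotates consecutive element pairs, and it assumes RMSNorm multiplies by gamma directly (no weight offset). The function names `rope_interleaved` and `rmsnorm_prefix` are hypothetical.

```python
import math
from typing import List, Sequence


def rope_interleaved(x: Sequence[float], freqs: Sequence[float]) -> List[float]:
    """Rotate each pair (x[2i], x[2i+1]) by the angle whose cosine and sine are
    freqs[2i] and freqs[2i+1] (ASSUMED interleaved cos/sin layout)."""
    if len(x) != len(freqs) or len(x) % 2:
        raise ValueError("x and freqs must have the same even length (rope_dim)")
    out = list(x)
    for i in range(0, len(x), 2):
        c, s = freqs[i], freqs[i + 1]
        x0, x1 = x[i], x[i + 1]
        out[i] = x0 * c - x1 * s      # real part of (x0 + i*x1) * (c + i*s)
        out[i + 1] = x0 * s + x1 * c  # imaginary part
    return out


def rmsnorm_prefix(row: Sequence[float], gamma: Sequence[float],
                   epsilon: float) -> List[float]:
    """Normalize the first len(gamma) (= kv_norm_dim) entries of row and scale
    them by gamma; entries past kv_norm_dim are returned unchanged."""
    n = len(gamma)
    mean_sq = sum(v * v for v in row[:n]) / n
    scale = 1.0 / math.sqrt(mean_sq + epsilon)
    return [row[i] * scale * gamma[i] for i in range(n)] + list(row[n:])
```

For example, `rmsnorm_prefix([3.0, 4.0, 7.0], [1.0, 1.0], 0.0)` normalizes only the first two values (mean square 12.5) and leaves the trailing 7.0 untouched, mirroring how only the first kv_norm_dim elements of a key cache entry are normalized.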
Parameters:
- dtype (DType): Data type of the query tensors.
- freq_dtype (DType): Data type of the frequency cosine/sine values.
- gamma_dtype (DType): Data type of the RMSNorm gamma weights.
- q_rope_output_shape_types (Variadic[CoordLike]): Shape types of the output query rope tensor.
- q_rope_output_stride_types (Variadic[CoordLike]): Stride types of the output query rope tensor.
- q_rope_shape_types (Variadic[CoordLike]): Shape types of the input query rope tensor.
- q_rope_stride_types (Variadic[CoordLike]): Stride types of the input query rope tensor.
- input_row_offsets_shape_types (Variadic[CoordLike]): Shape types of the row offsets tensor.
- input_row_offsets_stride_types (Variadic[CoordLike]): Stride types of the row offsets tensor.
- freqs_cis_shape_types (Variadic[CoordLike]): Shape types of the frequency tensor.
- freqs_cis_stride_types (Variadic[CoordLike]): Stride types of the frequency tensor.
- gamma_shape_types (Variadic[CoordLike]): Shape types of the gamma weights tensor.
- gamma_stride_types (Variadic[CoordLike]): Stride types of the gamma weights tensor.
- cache_t (KVCacheT): Type of the KV cache.
- block_size (Int): Number of threads per block.
- n_rope_blocks (Int): Number of blocks allocated to the RoPE computation.
- n_rms_blocks (Int): Number of blocks allocated to the RMSNorm computation.
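The split between n_rope_blocks and n_rms_blocks suggests a launch of n_rope_blocks + n_rms_blocks blocks of block_size threads each, with the block index selecting the role. The sketch below models one plausible work assignment in plain Python; the ordering (RoPE blocks first) and the grid-stride loop over tokens are assumptions for illustration, not documented behavior, and `block_work` is a hypothetical name.

```python
from typing import List, Tuple


def block_work(block_idx: int, n_rope_blocks: int, n_rms_blocks: int,
               num_tokens: int) -> Tuple[str, List[int]]:
    """Return the role of one block and the token indices it would visit.

    HYPOTHETICAL scheme: blocks [0, n_rope_blocks) do RoPE, the next
    n_rms_blocks do RMSNorm, and each group strides over all tokens."""
    total = n_rope_blocks + n_rms_blocks
    if not 0 <= block_idx < total:
        raise ValueError("block_idx outside the launched grid")
    if block_idx < n_rope_blocks:
        role, local, stride = "rope", block_idx, n_rope_blocks
    else:
        role, local, stride = "rmsnorm", block_idx - n_rope_blocks, n_rms_blocks
    # Grid-stride loop: block `local` handles tokens local, local + stride, ...
    return role, list(range(local, num_tokens, stride))
```

Under this scheme every token is visited exactly once by the RoPE group and once by the RMSNorm group, so the two group sizes can be tuned independently to balance the cheaper rotation against the reduction-heavy normalization.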
Args:
- q_rope_output (TileTensor): Output tensor for the RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
- q_rope (TileTensor): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
- input_row_offsets (TileTensor): Row offsets indicating request boundaries. Shape: [num_batches + 1].
- freqs_cis (TileTensor): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
- gamma (TileTensor): RMSNorm gamma weights. Shape: [kv_norm_dim].
- k_cache (cache_t): Key cache to which RoPE and RMSNorm are applied.
- epsilon (Float32): Small constant for numerical stability in RMSNorm.
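Because input_row_offsets has num_batches + 1 entries, request b owns the flattened tokens in [offsets[b], offsets[b+1]). The plain-Python sketch below maps a token index to its request and to a sequence position that could select a freqs_cis row. The per-request `cache_lengths` argument is a hypothetical stand-in for whatever the key cache reports about tokens already stored; it is an assumption, not part of this signature.

```python
import bisect
from typing import Sequence, Tuple


def token_position(token_idx: int, input_row_offsets: Sequence[int],
                   cache_lengths: Sequence[int]) -> Tuple[int, int]:
    """Map a flattened token index to (batch index, absolute position).

    input_row_offsets is non-decreasing with num_batches + 1 entries.
    cache_lengths is HYPOTHETICAL: tokens already cached per request."""
    if not input_row_offsets[0] <= token_idx < input_row_offsets[-1]:
        raise IndexError("token_idx outside [0, tot_seq_len)")
    # Last b with offsets[b] <= token_idx; empty requests are skipped naturally.
    batch = bisect.bisect_right(input_row_offsets, token_idx) - 1
    local = token_idx - input_row_offsets[batch]
    return batch, cache_lengths[batch] + local
```

With offsets [0, 3, 5], token 4 falls in request 1 at local index 1; if that request already had 10 cached tokens, its position would be 11, which must stay below max_seq_len to index freqs_cis.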