Mojo function

fused_rope_rmsnorm_kernel

fused_rope_rmsnorm_kernel[dtype: DType, freq_dtype: DType, gamma_dtype: DType, q_rope_output_shape_types: Variadic[CoordLike], q_rope_output_stride_types: Variadic[CoordLike], q_rope_shape_types: Variadic[CoordLike], q_rope_stride_types: Variadic[CoordLike], input_row_offsets_shape_types: Variadic[CoordLike], input_row_offsets_stride_types: Variadic[CoordLike], freqs_cis_shape_types: Variadic[CoordLike], freqs_cis_stride_types: Variadic[CoordLike], gamma_shape_types: Variadic[CoordLike], gamma_stride_types: Variadic[CoordLike], cache_t: KVCacheT, block_size: Int, n_rope_blocks: Int, n_rms_blocks: Int](q_rope_output: TileTensor[dtype, MutExternalOrigin], q_rope: TileTensor[dtype, ImmutExternalOrigin], input_row_offsets: TileTensor[DType.uint32, ImmutExternalOrigin], freqs_cis: TileTensor[freq_dtype, ImmutExternalOrigin], gamma: TileTensor[gamma_dtype, ImmutExternalOrigin], k_cache: cache_t, epsilon: Float32)

Fused GPU kernel that applies RoPE to the query projections and to the rope part of the key cache, and RMSNorm to the key cache.

This kernel processes tokens in parallel across GPU blocks. A group of n_rope_blocks blocks handles RoPE, and a separate group of n_rms_blocks blocks handles RMSNorm. The RoPE blocks apply rotary position embeddings to the query rope part (reading q_rope and writing the result to q_rope_output) and to the rope part of the key cache (in place). The RMSNorm blocks normalize, in place, the first kv_norm_dim elements of each key cache entry, where kv_norm_dim is the length of gamma.
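The per-entry RMSNorm step can be written as a short CPU reference. This is a minimal Python sketch under stated assumptions, not the kernel's implementation: it assumes the standard RMSNorm form y_i = x_i / sqrt(mean(x[0:kv_norm_dim]^2) + epsilon) * gamma_i, treats one key cache entry as a flat list of floats, and uses a hypothetical helper name, `rms_norm_prefix`. Elements at index kv_norm_dim and beyond (such as the rope part) are left unchanged.

```python
import math


def rms_norm_prefix(row, gamma, epsilon):
    """Hypothetical reference: RMS-normalize the first len(gamma) elements of row in place.

    row:     one key cache entry as a list of floats (assumed flat layout)
    gamma:   RMSNorm weights; len(gamma) plays the role of kv_norm_dim
    epsilon: small constant added to the mean square for numerical stability
    """
    kv_norm_dim = len(gamma)
    head = row[:kv_norm_dim]  # copy of the part to normalize
    mean_sq = sum(x * x for x in head) / kv_norm_dim
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    for i in range(kv_norm_dim):
        # Assumes plain gamma scaling (x * gamma), not (1 + gamma).
        row[i] = head[i] * inv_rms * gamma[i]
    return row  # elements past kv_norm_dim are untouched


entry = [3.0, 4.0, 10.0]  # two "norm" elements followed by one "rope" element
rms_norm_prefix(entry, gamma=[1.0, 1.0], epsilon=1e-6)
```

Python floats are 64-bit, so results will differ slightly from a kernel that accumulates in Float32 or lower precision.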

Parameters:

  • dtype (DType): Data type of the query tensors (q_rope and q_rope_output).
  • freq_dtype (DType): Data type of the frequency cosine/sine values in freqs_cis.
  • gamma_dtype (DType): Data type of the RMSNorm gamma weights.
  • q_rope_output_shape_types (Variadic): Shape types of the output query rope tensor.
  • q_rope_output_stride_types (Variadic): Stride types of the output query rope tensor.
  • q_rope_shape_types (Variadic): Shape types of the input query rope tensor.
  • q_rope_stride_types (Variadic): Stride types of the input query rope tensor.
  • input_row_offsets_shape_types (Variadic): Shape types of the row offset indices tensor.
  • input_row_offsets_stride_types (Variadic): Stride types of the row offset indices tensor.
  • freqs_cis_shape_types (Variadic): Shape types of the frequency tensor.
  • freqs_cis_stride_types (Variadic): Stride types of the frequency tensor.
  • gamma_shape_types (Variadic): Shape types of the gamma weights tensor.
  • gamma_stride_types (Variadic): Stride types of the gamma weights tensor.
  • cache_t (KVCacheT): Type of the KV cache.
  • block_size (Int): Number of threads per block.
  • n_rope_blocks (Int): Number of blocks allocated for RoPE computation.
  • n_rms_blocks (Int): Number of blocks allocated for RMSNorm computation.
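How the two block groups divide the work is not specified above. One plausible arrangement, shown here as a hypothetical Python sketch, is a single 1-D grid of n_rope_blocks + n_rms_blocks blocks in which the first n_rope_blocks blocks perform RoPE, the remaining n_rms_blocks blocks perform RMSNorm, and each group walks the tokens with a grid-stride loop. The function name `assign_tokens` and this block ordering are assumptions. The real kernel presumably also splits each token's work across the block_size threads of a block, which this sketch omits.

```python
def assign_tokens(block_idx, n_rope_blocks, n_rms_blocks, num_tokens):
    """Hypothetical mapping from a flat block index to (role, tokens it handles).

    Assumes blocks [0, n_rope_blocks) do RoPE and blocks
    [n_rope_blocks, n_rope_blocks + n_rms_blocks) do RMSNorm.
    """
    if block_idx < n_rope_blocks:
        role, local, stride = "rope", block_idx, n_rope_blocks
    elif block_idx < n_rope_blocks + n_rms_blocks:
        role, local, stride = "rmsnorm", block_idx - n_rope_blocks, n_rms_blocks
    else:
        raise ValueError("block index outside the launched grid")
    # Grid-stride loop: each block in a group takes every stride-th token.
    return role, list(range(local, num_tokens, stride))


# With 2 RoPE blocks, 3 RMSNorm blocks, and 7 tokens:
assign_tokens(1, 2, 3, 7)  # ("rope", [1, 3, 5])
assign_tokens(3, 2, 3, 7)  # ("rmsnorm", [1, 4])
```

Under this arrangement every token is visited exactly once by each group, so the two groups can run concurrently without coordinating, provided the RoPE and RMSNorm regions of each cache entry do not overlap.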

Args:

  • q_rope_output (TileTensor): Output tensor for RoPE-applied query projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • q_rope (TileTensor): Input query rope projections. Shape: [tot_seq_len, num_heads, rope_dim].
  • input_row_offsets (TileTensor): Row offsets indicating request boundaries: request i owns token rows input_row_offsets[i] through input_row_offsets[i + 1] - 1. Shape: [num_batches + 1].
  • freqs_cis (TileTensor): Precomputed RoPE frequency values. Shape: [max_seq_len, rope_dim].
  • gamma (TileTensor): RMSNorm gamma weights. Shape: [kv_norm_dim].
  • k_cache (cache_t): Key cache, updated in place: RoPE is applied to the rope part of each entry and RMSNorm to its first kv_norm_dim elements.
  • epsilon (Float32): Small constant for numerical stability in RMSNorm.
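The query RoPE path combines three of the arguments above: input_row_offsets locates each token's request, the token's position selects a row of freqs_cis, and that row rotates the token's rope_dim values. The following is a minimal Python sketch under assumptions that are not confirmed by this page: each freqs_cis row stores interleaved (cos, sin) pairs, RoPE rotates adjacent element pairs (x[2i], x[2i+1]) as a complex multiplication, and a hypothetical per-request `start_pos` list (for example, the number of tokens already cached for that request) offsets positions. The helper names `token_to_request`, `apply_rope`, and `rope_query` are illustrative only.

```python
import bisect
import math


def token_to_request(token_idx, input_row_offsets):
    """Return (request index, index within request) for a flat token index.

    Request b owns the half-open range [offsets[b], offsets[b + 1]);
    bisect_right skips empty requests that share the same offset.
    """
    b = bisect.bisect_right(input_row_offsets, token_idx) - 1
    return b, token_idx - input_row_offsets[b]


def apply_rope(x, freqs_row):
    """Rotate pairs (x[2i], x[2i+1]) by (cos, sin) = (freqs_row[2i], freqs_row[2i+1]).

    Assumed interleaved layout; equivalent to (a + ib) * (c + is).
    """
    out = list(x)
    for i in range(0, len(x), 2):
        c, s = freqs_row[i], freqs_row[i + 1]
        a, b = x[i], x[i + 1]
        out[i] = a * c - b * s
        out[i + 1] = a * s + b * c
    return out


def rope_query(q_rope, input_row_offsets, freqs_cis, start_pos):
    """Reference for q_rope_output; q_rope is [tot_seq_len][num_heads][rope_dim]."""
    out = []
    for t, heads in enumerate(q_rope):
        b, local = token_to_request(t, input_row_offsets)
        pos = start_pos[b] + local  # hypothetical absolute position
        out.append([apply_rope(h, freqs_cis[pos]) for h in heads])
    return out


# Two requests (tokens 0-1 and token 2), one head, rope_dim = 2.
freqs = [[math.cos(0.5 * p), math.sin(0.5 * p)] for p in range(8)]
q = [[[1.0, 0.0]] for _ in range(3)]
q_out = rope_query(q, [0, 2, 3], freqs, start_pos=[0, 4])
```

The same `apply_rope` rotation would apply to the rope part of each key cache entry, written back in place rather than to a separate output.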
