
Mojo function

rms_norm_kv_cache_ragged_continuous_batching

rms_norm_kv_cache_ragged_continuous_batching[type: DType, num_heads: Int, head_dim: Int, //, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool](kv_collection: ContinuousBatchingKVCacheCollection[type, KVCacheStaticParams(UInt(num_heads), UInt(head_dim))], gamma: NDBuffer[type, 1, origin, shape, strides], epsilon: SIMD[type, 1], weight_offset: SIMD[type, 1], layer_idx: SIMD[uint32, 1], total_seq_len: SIMD[uint32, 1], input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], context: DeviceContextPtr)

Performs RMSNorm in place on new entries in the key cache.

This is done by first forming a ragged tensor view of shape (total_seq_len, num_heads, head_dim) over the new token entries. Constructing that view requires total_seq_len to be passed in from the host. Then, using input_row_offsets, the kernel finds the batch and token index corresponding to each row, and combines them with the static head and channel indices to load from and store to the key cache. This is wired up through the input/output lambdas on the RMSNorm kernel.
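The batch/token lookup can be sketched in NumPy (a hypothetical illustration, not the Mojo kernel itself): input_row_offsets holds batch_size + 1 cumulative row offsets, and a flat row index into the (total_seq_len, num_heads, head_dim) view is mapped back to its sequence by finding the offset interval that contains it.

```python
import numpy as np

# Hypothetical ragged batch: 3 sequences with 2, 4, and 1 new tokens.
# input_row_offsets has batch_size + 1 entries; flat row i belongs to the
# sequence whose half-open offset interval [offsets[b], offsets[b+1]) contains i.
input_row_offsets = np.array([0, 2, 6, 7], dtype=np.uint32)
total_seq_len = int(input_row_offsets[-1])  # 7 rows in the ragged view

def batch_and_token(global_row: int) -> tuple[int, int]:
    """Map a flat row index to (batch index, token index within sequence)."""
    batch = int(np.searchsorted(input_row_offsets, global_row, side="right")) - 1
    token = global_row - int(input_row_offsets[batch])
    return batch, token

# Rows 0-1 belong to sequence 0; rows 2-5 to sequence 1; row 6 to sequence 2.
assert batch_and_token(0) == (0, 0)
assert batch_and_token(5) == (1, 3)
assert batch_and_token(6) == (2, 0)
```

The resulting (batch, token) pair, together with the static head and channel indices, addresses a single element of the key cache.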

This function can apply RMSNorm to a subset of dimensions in each head, determined by the size of the gamma tensor. In that case, it operates on a ragged tensor view of the key cache with shape (total_seq_len, num_heads, rms_norm_cols), where rms_norm_cols is the length of gamma and must be <= head_dim.

weight_offset is a constant offset added to the learned gamma weights at runtime. Here, no offset is used, so we pass in a zero scalar.

multiply_before_cast is a boolean parameter that determines whether the normalized values are multiplied by the gamma tensor before casting to the output type. We set it to True by default.
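The ordering difference can be illustrated with a small NumPy sketch (hypothetical names; float16 stands in for a low-precision output type): with multiply_before_cast=True, the gamma scaling happens in the accumulation precision and only the final result is rounded; with False, the normalized values are rounded first and then scaled.

```python
import numpy as np

def normalize(x, gamma, epsilon, multiply_before_cast, out_dtype=np.float16):
    """Illustrate multiply_before_cast ordering (hypothetical sketch)."""
    xf = x.astype(np.float64)
    normed = xf / np.sqrt(np.mean(xf * xf) + epsilon)
    if multiply_before_cast:
        # Scale by gamma in high precision, then cast once at the end.
        return (normed * gamma).astype(out_dtype)
    # Cast the normalized values first, then scale in the output precision.
    return normed.astype(out_dtype) * gamma.astype(out_dtype)

x = np.array([0.1, 0.2, 0.3], dtype=np.float32)
gamma = np.array([1.5, 1.5, 1.5], dtype=np.float32)
a = normalize(x, gamma, 1e-6, multiply_before_cast=True)
b = normalize(x, gamma, 1e-6, multiply_before_cast=False)
# The two orderings agree up to low-precision rounding error.
```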
