Mojo function

rms_norm_kv_cache_ragged_paged

```mojo
rms_norm_kv_cache_ragged_paged[dtype: DType, params: KVCacheStaticParams, page_size: Int, cache_dtype: DType, //, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool, per_head_norm: Bool](kv_collection: PagedKVCacheCollection[cache_dtype, params, page_size], gamma: TileTensor[dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], layer_idx: UInt32, total_seq_len: UInt32, input_row_offsets: TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], context: DeviceContextPtr)
```

Performs RMSNorm in place on new entries in the key cache.

This is done by first creating a ragged tensor view with shape (total_seq_len, num_heads, head_dim) over the new token entries. To do this, total_seq_len must be passed in on the host. Then, using input_row_offsets, we find the corresponding batch and token index, and use that together with the static head and channel indices to load from and store to the key cache. This uses the input/output lambdas on the RMSNorm kernel.
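The indexing scheme above can be illustrated with a minimal NumPy sketch (this is not the Mojo kernel itself; the array names and the `rms_norm` helper are illustrative). Each global row in the ragged buffer is mapped back to a (batch, token) pair via `input_row_offsets`, and RMSNorm is applied per head:

```python
import numpy as np

def rms_norm(x, gamma, eps, weight_offset=0.0):
    # Classic RMSNorm over the last axis: x / sqrt(mean(x^2) + eps),
    # scaled by the learned weights plus an optional constant offset.
    scale = 1.0 / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * scale * (gamma + weight_offset)

# Ragged layout: 3 sequences of new tokens packed into one buffer.
# input_row_offsets[b] is where batch b's tokens start; the final
# entry equals total_seq_len.
input_row_offsets = np.array([0, 2, 5, 6], dtype=np.uint32)
total_seq_len = 6
num_heads, head_dim = 2, 4

key_cache_new = np.random.randn(total_seq_len, num_heads, head_dim).astype(np.float32)
gamma = np.ones(head_dim, dtype=np.float32)

# Recover (batch, token) for each global row, then normalize each
# head in place, mirroring the kernel's per-head norm.
for global_tok in range(total_seq_len):
    batch = int(np.searchsorted(input_row_offsets, global_tok, side="right")) - 1
    tok_in_batch = global_tok - int(input_row_offsets[batch])
    for h in range(num_heads):
        key_cache_new[global_tok, h] = rms_norm(
            key_cache_new[global_tok, h], gamma, eps=1e-6
        )
```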

This function can apply RMSNorm to a subset of dimensions in each head, determined by the size of the gamma tensor. In that case, it operates on a ragged tensor view of the key cache with shape (total_seq_len, num_heads, rms_norm_cols), where rms_norm_cols is the length of gamma and must be <= head_size.
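A small sketch of the partial-dimension case, assuming (hypothetically) that only the first rms_norm_cols channels of each head participate in the norm while the rest are left untouched:

```python
import numpy as np

head_dim = 8
rms_norm_cols = 4  # len(gamma); must be <= head_dim

head = np.random.randn(head_dim).astype(np.float32)
gamma = np.ones(rms_norm_cols, dtype=np.float32)

# Normalize only the first rms_norm_cols channels; the mean of
# squares is taken over that prefix alone.
prefix = head[:rms_norm_cols]
scale = 1.0 / np.sqrt(np.mean(prefix * prefix) + 1e-6)
normalized = head.copy()
normalized[:rms_norm_cols] = prefix * scale * gamma
```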

weight_offset is a constant offset added to the learned weights at runtime. Here, we don't use any offset, so we pass in a zero scalar.

multiply_before_cast is a boolean parameter that determines whether the normalized values are multiplied by the gamma tensor before being cast to the output dtype. We set it to True by default.
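The distinction only matters because of rounding: multiplying in full precision and then casting (the default) rounds once, while casting first and then multiplying in the lower precision rounds twice. A NumPy sketch of the two orderings, with float16 standing in for a lower-precision cache dtype:

```python
import numpy as np

x = np.random.randn(4).astype(np.float32) * 3.0
gamma = np.full(4, 0.5, dtype=np.float32)
eps = 1e-6

normed = x / np.sqrt(np.mean(x * x) + eps)

# multiply_before_cast=True: scale by gamma in full precision, then cast.
mul_first = (normed * gamma).astype(np.float16)

# multiply_before_cast=False: cast the normalized values, then scale
# in the output dtype.
cast_first = normed.astype(np.float16) * gamma.astype(np.float16)
```

The two results agree up to float16 rounding error; the default ordering keeps the intermediate product in higher precision.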