Mojo function
rms_norm_rope_gpu
rms_norm_rope_gpu[input_dtype: DType, cos_sin_dtype: DType, rank: Int, //, input_fn: def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[input_dtype, width], cos_fn: def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[cos_sin_dtype, width], sin_fn: def[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[cos_sin_dtype, width], output_fn: def[width: Int, alignment: Int](IndexList[rank], SIMD[input_dtype, width]) capturing -> None, multiply_before_cast: Bool, pdl_level: PDLLevel = PDLLevel(1)](shape: IndexList[rank, element_type=shape.element_type], gamma: TileTensor[input_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[input_dtype], weight_offset: Scalar[input_dtype], cos_vals: TileTensor[cos_sin_dtype, cos_vals.LayoutType, cos_vals.origin, address_space=cos_vals.address_space, linear_idx_type=cos_vals.linear_idx_type, element_size=cos_vals.element_size], sin_vals: TileTensor[cos_sin_dtype, sin_vals.LayoutType, sin_vals.origin, address_space=sin_vals.address_space, linear_idx_type=sin_vals.linear_idx_type, element_size=sin_vals.element_size], ctx: DeviceContext)
Fused RMS normalization followed by Rotary Position Embedding (RoPE) for GPU.
Computes:
    normed = rms_norm(input, gamma, epsilon, weight_offset)
    x1, x2 = split(normed, axis=-1)  # halves along last dim
    rotated = concat(-x2, x1, axis=-1)
    output = normed * cos_vals + rotated * sin_vals
The last dimension must be a known even number.
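To make the computation concrete, here is a minimal CPU reference sketch in plain Python for a single row (one slice along the last dimension). It is not the GPU kernel and does not use the Mojo API. The names `rms_norm` and `rms_norm_rope_reference` are hypothetical. The RMS norm formula used here, `x / sqrt(mean(x^2) + epsilon) * (gamma + weight_offset)`, is an assumption based on the common definition and is not stated on this page. The sketch also ignores dtype casting, so it does not model `multiply_before_cast`.

```python
import math


def rms_norm(row, gamma, epsilon, weight_offset):
    """Assumed RMS norm: x / sqrt(mean(x^2) + epsilon) * (gamma + weight_offset)."""
    mean_sq = sum(x * x for x in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [x * inv_rms * (g + weight_offset) for x, g in zip(row, gamma)]


def rms_norm_rope_reference(row, gamma, epsilon, weight_offset, cos_vals, sin_vals):
    """Hypothetical reference for one row: RMS norm, then half-split rotation."""
    if len(row) % 2 != 0:
        raise ValueError("last dimension must be even")

    normed = rms_norm(row, gamma, epsilon, weight_offset)

    # Split into halves along the last dimension.
    half = len(normed) // 2
    x1, x2 = normed[:half], normed[half:]

    # rotated = concat(-x2, x1)
    rotated = [-v for v in x2] + x1

    # output = normed * cos_vals + rotated * sin_vals (elementwise)
    return [n * c + r * s for n, r, c, s in zip(normed, rotated, cos_vals, sin_vals)]
```

With `cos_vals` all ones and `sin_vals` all zeros the result reduces to the plain normalized row, which is a convenient sanity check when comparing against a kernel's output.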