Mojo module
rms_norm_fp8
Fused RMSNorm + FP8 quantization kernel.
This module provides the fused RMSNorm + FP8 quantization primitive used by both the standalone normalization layer and the fused allreduce + RMSNorm + FP8 kernel. It lives in comm/ so that allreduce_residual_rmsnorm_fp8 can depend on it without introducing a comm → nn → comm circular dependency.
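As a rough orientation, the arithmetic a fused RMSNorm + FP8 kernel performs can be written as a short pure-Python reference model. This is a sketch under stated assumptions, not the Mojo API. The function name rms_norm_fp8_reference, the eps default, the dynamic per-row scale, and the E4M3 target format (largest finite value 448) are all assumptions made for illustration. Rounding onto the real FP8 grid is not modeled; values are only scaled and clipped to the E4M3 range.

```python
import math

# Assumption: target format is float8 E4M3 ("fn" variant), largest finite value 448.
FP8_E4M3_MAX = 448.0


def rms_norm_fp8_reference(x, gamma, eps=1e-6):
    """Hypothetical reference model: y = x / rms(x) * gamma, then scale into FP8 range.

    Returns (q, scale), where q * scale approximates y. q is clipped to
    [-FP8_E4M3_MAX, FP8_E4M3_MAX] but NOT rounded to actual FP8 values.
    """
    n = len(x)
    # RMS over the row: sqrt(mean(x^2) + eps); eps guards against all-zero rows.
    mean_sq = sum(v * v for v in x) / n
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Normalize and apply the per-element weight gamma.
    y = [v * inv_rms * g for v, g in zip(x, gamma)]
    # Dynamic per-row scale so that the largest |y| maps to FP8_E4M3_MAX.
    amax = max(abs(v) for v in y)
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    # Quantize: divide by the scale and clip to the representable range.
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in y]
    return q, scale


q, scale = rms_norm_fp8_reference([1.0, -2.0, 3.0, -4.0], [1.0, 1.0, 1.0, 1.0])
# The largest-magnitude element lands at (approximately) +/-448, and q * scale recovers y.
```

The usual motivation for fusing the two steps is that the normalized row does not have to be written back to memory in higher precision and read again for quantization; the kernel can read the input once and emit FP8 values together with their scale.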
Functions
- block_reduce_sum_and_max: Combined block reduction for sum and max in a single barrier pass.
- rms_norm_fused_fp8: Fused RMSNorm + FP8 quantization kernel (TileTensor overload).
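The idea behind a combined sum-and-max block reduction can be modeled as a tree reduction over (sum, max) pairs, where each tree level stands in for one barrier-separated step. Reducing both quantities together needs log2(n) steps instead of the 2·log2(n) that two separate reductions would take. This is a minimal Python sketch under assumptions, not the Mojo implementation: the function name, the power-of-two thread count, and the (sum, max) tuple layout are hypothetical.

```python
def block_reduce_sum_and_max(partials):
    """Hypothetical model of a tree reduction over (sum, max) pairs.

    `partials` holds one (sum, max) pair per simulated thread; its length must
    be a power of two. Each loop iteration stands in for one barrier-separated
    step. Returns ((total_sum, global_max), barrier_steps).
    """
    n = len(partials)
    assert n > 0 and (n & (n - 1)) == 0, "length must be a power of two"
    vals = list(partials)
    steps = 0
    stride = n // 2
    while stride > 0:
        # "Thread" i combines with i + stride: sums add, maxima take the max.
        for i in range(stride):
            s0, m0 = vals[i]
            s1, m1 = vals[i + stride]
            vals[i] = (s0 + s1, max(m0, m1))
        steps += 1  # one barrier per tree level covers both reductions
        stride //= 2
    return vals[0], steps


(total, mx), steps = block_reduce_sum_and_max(
    [(1.0, 3.0), (2.0, 7.0), (3.0, 1.0), (4.0, 5.0)]
)
# total == 10.0, mx == 7.0, steps == 2
```

In a fused RMSNorm + FP8 kernel, each thread's pair could plausibly be its partial sum of squares (for the RMS) and its partial maximum of |x·γ| (for the FP8 scale), since max|x·γ / rms| equals max|x·γ| / rms; this pairing is an assumption, not something the module documentation states.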