Mojo module
rms_norm_fp8
Fused RMSNorm + FP8 quantization kernel.
Provides the fused RMSNorm + FP8 quantization primitive used by both the standalone normalization layer and the fused allreduce + RMSNorm + FP8 kernel. The module lives in comm/ so that allreduce_residual_rmsnorm_fp8 can depend on it without introducing a comm → nn → comm circular dependency.
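The numerics of the fused operation can be modeled in plain Python as a reference. This is a sketch and not the Mojo kernel. It rests on several assumptions: the E4M3 FP8 format (largest finite magnitude 448), a per-element weight gamma, an epsilon of 1e-6, and one dynamic scale per row derived from the largest normalized magnitude. The names rms_norm_fp8 and round_to_e4m3 are hypothetical. The real kernel's scaling scheme, epsilon, and rounding may differ.

```python
import math

# Assumed FP8 format: E4M3 ("FN" variant), 4 exponent bits, 3 mantissa bits.
FP8_E4M3_MAX = 448.0                # largest finite magnitude
FP8_E4M3_MIN_NORMAL = 2.0 ** -6     # smallest normal magnitude
FP8_E4M3_SUBNORMAL_STEP = 2.0 ** -9 # spacing of subnormal values


def round_to_e4m3(v: float) -> float:
    """Round v to the nearest E4M3 value (ties to even), saturating at +/-448."""
    a = abs(v)
    if a >= FP8_E4M3_MIN_NORMAL:
        _, e = math.frexp(a)           # a = m * 2**e with 0.5 <= m < 1
        step = math.ldexp(1.0, e - 4)  # 3 mantissa bits -> 8 values per binade
    else:
        step = FP8_E4M3_SUBNORMAL_STEP
    # Python's round() rounds halves to even; min() saturates instead of overflowing.
    q = min(round(a / step) * step, FP8_E4M3_MAX)
    return math.copysign(q, v)


def rms_norm_fp8(x, gamma, eps=1e-6):
    """Hypothetical reference: RMSNorm of one row, then per-row FP8 quantization.

    Returns (q, scale) such that the normalized row y is approximately q * scale.
    """
    n = len(x)
    sum_sq = math.fsum(v * v for v in x)
    inv_rms = 1.0 / math.sqrt(sum_sq / n + eps)
    y = [v * inv_rms * g for v, g in zip(x, gamma)]
    # Dynamic scale: map the largest |y| onto the largest FP8 magnitude.
    amax = max(abs(v) for v in y)
    scale = amax / FP8_E4M3_MAX if amax > 0.0 else 1.0
    q = [round_to_e4m3(v / scale) for v in y]
    return q, scale
```

For example, rms_norm_fp8([1.0, -2.0, 3.0, -4.0], [1.0] * 4) maps the largest-magnitude element to -448 and scales the rest proportionally before rounding. A fused kernel avoids writing y back to memory between the normalization and the quantization, which is the usual motivation for fusing the two steps.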
Functions
- block_reduce_sum_and_max: Combined block reduction for sum and max in a single barrier pass.
- rms_norm_fused_fp8: Fused RMSNorm + FP8 quantization kernel (TileTensor overload).
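A reduction that returns both a sum and a max "in a single barrier pass" can be pictured with a common two-level GPU pattern. This pattern is assumed here and is not taken from the Mojo source. Each warp first reduces its own lanes with register shuffles, which needs no barrier. Warp leaders then write one (sum, max) pair each to shared memory, and a single barrier publishes those pairs. Finally, one warp reduces the partials. The Python sketch below only simulates that data flow on a CPU. The warp size of 32, the helper warp_reduce, and the (partial_sum, partial_max) pair layout are illustrative assumptions.

```python
WARP_SIZE = 32  # assumed warp width (NVIDIA); other GPUs may differ


def warp_reduce(pairs):
    """Stand-in for a shuffle-based warp reduction. No barrier is needed
    because lanes of one warp exchange registers directly."""
    return sum(p[0] for p in pairs), max(p[1] for p in pairs)


def block_reduce_sum_and_max(thread_vals, warp_size=WARP_SIZE):
    """Simulate a two-level block reduction of (partial_sum, partial_max)
    pairs, one pair per thread. Requires at least one thread and assumes the
    number of warps fits in one warp (block size <= warp_size ** 2).

    Returns (block_sum, block_max)."""
    # Level 1: every warp reduces its own lanes; each warp leader writes one
    # (sum, max) pair into "shared memory".
    shared = [
        warp_reduce(thread_vals[i:i + warp_size])
        for i in range(0, len(thread_vals), warp_size)
    ]
    # On a GPU, one barrier would go here. Because sum and max travel
    # together as a pair, both reductions share that single barrier.
    # Level 2: one warp reduces the per-warp pairs.
    return warp_reduce(shared)


# Illustrative use: sum of squares and max |x| across 70 "threads".
xs = [float(i - 40) for i in range(70)]
block_sum, block_max = block_reduce_sum_and_max([(v * v, abs(v)) for v in xs])
```

One plausible use in an RMSNorm + FP8 kernel is to reduce each thread's sum of squares together with its max of |x · gamma|. Both the RMS and the quantization range then come from one synchronization point. This works because inv_rms is a positive per-row scalar, so max|y| = inv_rms · max|x · gamma|.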