Mojo function
allreduce_rmsnorm_fp8
allreduce_rmsnorm_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, rank: Int, ngpus: Int, //](
    input_buffers: InlineArray[NDBuffer[in_dtype, rank, ImmutAnyOrigin], ngpus],
    output: NDBuffer[out_dtype, rank, output.origin, output.shape, output.strides, alignment2=output.alignment2, address_space=output.address_space, exclusive=output.exclusive],
    gamma: TileTensor[in_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size],
    epsilon: Scalar[in_dtype],
    weight_offset: Scalar[in_dtype],
    scale_ub: Float32,
    scale_output: NDBuffer[scales_dtype, rank, scale_output.origin, scale_output.shape, scale_output.strides, alignment2=scale_output.alignment2, address_space=scale_output.address_space, exclusive=scale_output.exclusive],
    rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8],
    ctx: DeviceContext
)
Fused allreduce + RMSNorm + FP8 quantization.
Combines a P2P allreduce across GPUs, RMSNorm, and dynamic FP8 quantization in a single kernel launch. This eliminates the global-memory round-trip between the allreduce output and the RMSNorm input.
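For checking results on the host, the fused math can be written out unfused. The following is a minimal pure-Python sketch, not the kernel's implementation. It assumes RMSNorm uses (gamma + weight_offset) as the per-column weight. It also assumes the dynamic FP8 scale is the row's max |value|, clamped to scale_ub and divided by the float8_e4m3fn maximum (448.0). The function name `reference_allreduce_rmsnorm_fp8` is hypothetical. True FP8 rounding is not emulated; values are only scaled and clamped to the FP8 range.

```python
import math

# float8_e4m3fn largest finite magnitude.
FP8_E4M3_MAX = 448.0


def reference_allreduce_rmsnorm_fp8(per_gpu_rows, gamma, epsilon, weight_offset, scale_ub):
    """Hypothetical host reference. per_gpu_rows[g][r] is row r of GPU g's input.

    Every row has len(gamma) values. Returns (quantized_rows, per_row_scales).
    FP8 rounding is NOT emulated: values are only scaled and clamped.
    """
    num_rows = len(per_gpu_rows[0])
    cols = len(gamma)
    quantized, scales = [], []
    for r in range(num_rows):
        # Step 1: allreduce = elementwise sum of row r across all GPUs.
        summed = [sum(gpu[r][c] for gpu in per_gpu_rows) for c in range(cols)]
        # Step 2: RMSNorm, weighting each column by (gamma + weight_offset).
        mean_sq = sum(v * v for v in summed) / cols
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        normed = [v * inv_rms * (g + weight_offset) for v, g in zip(summed, gamma)]
        # Step 3 (assumed scheme): per-row scale from max |value|, clamped to scale_ub.
        row_max = min(max(abs(v) for v in normed), scale_ub)
        scale = max(row_max / FP8_E4M3_MAX, 1e-12)  # guard against a zero scale
        q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in normed]
        quantized.append(q)
        scales.append(scale)
    return quantized, scales
```

With two GPUs each holding the row [1, 1, 1, 1], the allreduced row is [2, 2, 2, 2]. RMSNorm (epsilon = 0, gamma = 1, weight_offset = 0) maps it to all ones. Under the assumed scheme, the scale is 1/448, so every quantized value lands at the FP8 maximum.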
Note: This kernel does not issue an end barrier. The FP8 output and scale buffers are safe to read only on the local GPU (stream ordering guarantees visibility). If a remote GPU needs to read these outputs, the caller must insert an explicit barrier. The start barrier of the NEXT allreduce call protects the input buffers that are read by remote GPUs.
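The ordering rule in the note can be illustrated with a CPU analogy. This is not GPU code. Python threads stand in for GPUs, and a `threading.Barrier` stands in for the explicit barrier the caller must add before any rank reads another rank's output. All names (`NGPUS`, `outputs`, `rank_worker`) are hypothetical. Reading one's own output needs no barrier: program order on the thread plays the role that stream ordering plays on a GPU.

```python
import threading

NGPUS = 4  # hypothetical number of participating ranks

outputs = [None] * NGPUS       # stands in for each rank's FP8 output/scale buffers
local_reads = [None] * NGPUS   # what each rank read from its own output
remote_reads = [None] * NGPUS  # what each rank read from its neighbor's output
caller_barrier = threading.Barrier(NGPUS)  # the explicit barrier added by the caller


def rank_worker(rank):
    # "Kernel" phase: each rank writes only its own output; there is no end barrier.
    outputs[rank] = f"fp8-output-from-rank-{rank}"
    # Local read: safe without a barrier (program order, like stream ordering).
    local_reads[rank] = outputs[rank]
    # Remote read: wait until every rank has finished writing first.
    caller_barrier.wait()
    neighbor = (rank + 1) % NGPUS
    remote_reads[rank] = outputs[neighbor]


threads = [threading.Thread(target=rank_worker, args=(r,)) for r in range(NGPUS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Without `caller_barrier.wait()`, a rank could read its neighbor's slot before the neighbor had written it. That is the hazard the note warns about for remote readers.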
Signal buffer sizing:
- 1-stage path (payload < threshold): size_of[Signal]() only.
- 2-stage path (payload > threshold): size_of[Signal]() + ceildiv(rows, ngpus) * cols (FP8 data) + align_up(ceildiv(rows, ngpus) * sizeof(scales_dtype), simd_width) (scales + padding).
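The sizing rule above can be evaluated with a small Python helper. This is a sketch under stated assumptions. `signal_size` is a caller-supplied stand-in for size_of[Signal](), and its value here is hypothetical. `simd_width` is assumed to be in bytes. Each FP8 element is assumed to occupy 1 byte. Because the page does not give the threshold value, the path is chosen with an explicit `two_stage` flag.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return ceildiv(value, alignment) * alignment


def signal_buffer_bytes(rows, cols, ngpus, signal_size, scales_itemsize, simd_width, two_stage):
    """Hypothetical helper: bytes needed per signal buffer, following the rule above."""
    if not two_stage:
        return signal_size  # 1-stage path: just the Signal header
    rows_per_gpu = ceildiv(rows, ngpus)
    fp8_data = rows_per_gpu * cols  # 1 byte per FP8 element
    scales = align_up(rows_per_gpu * scales_itemsize, simd_width)  # scales + pad
    return signal_size + fp8_data + scales
```

For example, with rows = 1000, cols = 4096, ngpus = 8, a float32 scales_dtype (4 bytes), simd_width = 16, and an assumed 1024-byte Signal, the calculation runs as follows. Each GPU owns 125 rows, giving 512000 bytes of FP8 data. The scales take 500 bytes, padded to 512. The total is 1024 + 512000 + 512 = 513536 bytes.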
Parameters:
- in_dtype (DType): Input data type (e.g. bfloat16).
- out_dtype (DType): FP8 output data type (e.g. float8_e4m3fn).
- scales_dtype (DType): Scale factor data type (e.g. float32).
- rank (Int): Tensor rank of the input, output, and scale buffers.
- ngpus (Int): Number of participating GPUs.
Args:
- input_buffers (InlineArray): Per-GPU input buffers (last dim = cols).
- output (NDBuffer): Output buffer for FP8 values (same shape as input).
- gamma (TileTensor): RMSNorm gamma weights (1D TileTensor of length cols).
- epsilon (Scalar): RMSNorm epsilon for numerical stability.
- weight_offset (Scalar): Additive offset for gamma weights.
- scale_ub (Float32): Upper bound for FP8 scale clamping.
- scale_output (NDBuffer): Output buffer for per-row FP8 scales (last dim = 1).
- rank_sigs (InlineArray): Per-GPU signal pointers for synchronization.
- ctx (DeviceContext): Device context for this GPU.
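Before launching, a caller may want to confirm that the argument shapes agree with the descriptions above. The following hypothetical Python helper is a sketch of such a check and is not part of the API. It assumes three things beyond what the page states: every per-GPU input has the same shape; scale_output shares the input's leading dimensions; and ngpus cannot exceed 8, because rank_sigs holds 8 signal pointers.

```python
def check_allreduce_rmsnorm_fp8_shapes(input_shapes, output_shape, gamma_len,
                                       scale_output_shape, ngpus, max_ranks=8):
    """Hypothetical pre-launch check. Returns cols or raises ValueError."""
    if len(input_shapes) != ngpus:
        raise ValueError("need exactly one input buffer per GPU")
    if ngpus > max_ranks:
        raise ValueError("rank_sigs holds only 8 signal pointers")
    shape = tuple(input_shapes[0])
    if any(tuple(s) != shape for s in input_shapes):
        raise ValueError("all per-GPU input buffers must have the same shape")
    cols = shape[-1]  # last dim = cols
    if tuple(output_shape) != shape:
        raise ValueError("output must have the same shape as the inputs")
    if gamma_len != cols:
        raise ValueError("gamma length must equal cols")
    # Assumption: leading dims match the input; the last dim is 1 (one scale per row).
    if tuple(scale_output_shape) != shape[:-1] + (1,):
        raise ValueError("scale_output must be the input shape with last dim = 1")
    return cols
```

For instance, four GPUs with inputs of shape (2, 16, 4096) pass the check when output is (2, 16, 4096), gamma has length 4096, and scale_output is (2, 16, 1).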