Mojo function
allreduce_rmsnorm_fp8
allreduce_rmsnorm_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, ngpus: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //](input_buffers: InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus], output: TileTensor[out_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], gamma: TileTensor[in_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], scale_ub: Float32, scale_output: TileTensor[scales_dtype, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext)
Fused allreduce + RMSNorm + FP8 quantization.
Combines P2P allreduce across GPUs, RMSNorm normalization, and FP8 dynamic quantization into a single kernel launch. Eliminates the global memory round-trip between allreduce output and RMSNorm input.
Note: This kernel does not issue an end barrier. The FP8 output and scale buffers are safe to read only on the local GPU (stream ordering guarantees visibility). If a remote GPU needs to read these outputs, the caller must insert an explicit barrier. The start barrier of the NEXT allreduce call protects the input buffers that are read by remote GPUs.
Signal buffer sizing:
- 1-stage path (payload < threshold): sizeof(Signal) only.
- 2-stage path (payload > threshold): sizeof(Signal) + ceildiv(rows, ngpus) * cols (FP8 data) + align_up(ceildiv(rows, ngpus) * sizeof(scales_dtype), simd_width) (scales plus padding).
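The sizing rule can be checked with a few lines of arithmetic. The following is a minimal Python sketch, not part of the library: the function name signal_buffer_bytes and the parameters signal_bytes, scales_itemsize, simd_width, and two_stage are hypothetical stand-ins for values the kernel obtains internally (the size of Signal, the byte width of scales_dtype, the SIMD width, and the result of the threshold comparison). It assumes one byte per FP8 element.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def align_up(value: int, alignment: int) -> int:
    """Round value up to the nearest multiple of alignment."""
    return ceildiv(value, alignment) * alignment


def signal_buffer_bytes(
    rows: int,
    cols: int,
    ngpus: int,
    signal_bytes: int,     # assumption: size of the Signal struct in bytes
    scales_itemsize: int,  # assumption: bytes per scales_dtype element (4 for float32)
    simd_width: int,       # assumption: alignment used for the scales region
    two_stage: bool,       # assumption: caller already compared payload to the threshold
) -> int:
    """Bytes needed per signal buffer, following the sizing rule above."""
    if not two_stage:
        return signal_bytes
    rows_per_gpu = ceildiv(rows, ngpus)
    fp8_bytes = rows_per_gpu * cols  # one byte per FP8 element
    scale_bytes = align_up(rows_per_gpu * scales_itemsize, simd_width)
    return signal_bytes + fp8_bytes + scale_bytes
```

For example, with rows=10, cols=4096, ngpus=4, a 1024-byte Signal, float32 scales, and simd_width=16, each GPU owns ceildiv(10, 4) = 3 rows, so the 2-stage buffer needs 1024 + 3 * 4096 + align_up(12, 16) bytes.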
Parameters:
- in_dtype (DType): Input data type (e.g. bfloat16).
- out_dtype (DType): FP8 output data type (e.g. float8_e4m3fn).
- scales_dtype (DType): Scale factor data type (e.g. float32).
- ngpus (Int): Number of GPUs participating.
- in_layout (TensorLayout): Layout of the input TileTensors.
- in_origin (Origin[mut=in_origin.mut]): Origin of the input TileTensors.
Args:
- input_buffers (InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus]): Per-GPU input buffers as TileTensors.
- output (TileTensor[out_dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output buffer for FP8 values as a TileTensor.
- gamma (TileTensor[in_dtype, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size]): RMSNorm gamma weights (1D TileTensor of length cols).
- epsilon (Scalar[in_dtype]): RMSNorm epsilon for numerical stability.
- weight_offset (Scalar[in_dtype]): Additive offset for gamma weights.
- scale_ub (Float32): Upper bound for FP8 scale clamping.
- scale_output (TileTensor[scales_dtype, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size]): Output buffer for per-row FP8 scales as a TileTensor.
- rank_sigs (InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8]): Per-GPU signal pointers for synchronization.
- ctx (DeviceContext): Device context for this GPU.
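To make the roles of the arguments concrete, here is a CPU reference sketch in plain Python of what the fused pipeline computes per row. It is an illustration under stated assumptions, not the kernel's implementation: the function name allreduce_rmsnorm_fp8_reference is hypothetical, the allreduce is assumed to be a sum, weight_offset is assumed to be added to gamma before scaling, and the quantization is assumed to follow the common dynamic per-row scheme where the row's absolute maximum is clamped to scale_ub and divided by the largest finite float8_e4m3fn value (448). Rounding to the FP8 grid is omitted; values are only scaled and clamped.

```python
import math

FP8_E4M3FN_MAX = 448.0  # largest finite float8_e4m3fn value


def allreduce_rmsnorm_fp8_reference(input_buffers, gamma, epsilon,
                                    weight_offset, scale_ub):
    """input_buffers: one rows x cols nested list per GPU.

    Returns (output, scales): scaled-and-clamped values (not rounded to
    FP8) and one scale per row.
    """
    rows, cols = len(input_buffers[0]), len(gamma)
    output, scales = [], []
    for r in range(rows):
        # 1. Allreduce (assumed sum): add the same row from every GPU.
        summed = [sum(buf[r][c] for buf in input_buffers) for c in range(cols)]

        # 2. RMSNorm: divide by the root mean square, then apply the
        #    (offset) gamma weights. Offset handling is an assumption.
        mean_sq = sum(v * v for v in summed) / cols
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        normed = [v * inv_rms * (gamma[c] + weight_offset)
                  for c, v in enumerate(summed)]

        # 3. Dynamic per-row quantization (assumed scheme): clamp the row
        #    maximum to scale_ub, derive the scale, then scale and clamp.
        row_max = min(max(abs(v) for v in normed), scale_ub)
        scale = row_max / FP8_E4M3FN_MAX if row_max > 0 else 1.0
        quantized = [max(-FP8_E4M3FN_MAX, min(FP8_E4M3FN_MAX, v / scale))
                     for v in normed]

        output.append(quantized)
        scales.append(scale)
    return output, scales
```

In this sketch, output plays the role of the output buffer and scales the role of scale_output; multiplying a quantized row by its scale approximately recovers the normalized row whenever scale_ub did not clamp the maximum.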