Mojo function
allreduce_residual_rmsnorm_fp8
allreduce_residual_rmsnorm_fp8[in_dtype: DType, out_dtype: DType, scales_dtype: DType, ngpus: Int, in_layout: TensorLayout, in_origin: Origin[mut=in_origin.mut], //](input_buffers: InlineArray[TileTensor[in_dtype, in_layout, in_origin], ngpus], residual: TileTensor[in_dtype, residual.LayoutType, residual.origin, address_space=residual.address_space, linear_idx_type=residual.linear_idx_type, element_size=residual.element_size], output: TileTensor[out_dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], residual_output: TileTensor[in_dtype, residual_output.LayoutType, residual_output.origin, address_space=residual_output.address_space, linear_idx_type=residual_output.linear_idx_type, element_size=residual_output.element_size], gamma: TileTensor[in_dtype, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], epsilon: Scalar[in_dtype], weight_offset: Scalar[in_dtype], scale_ub: Float32, scale_output: TileTensor[scales_dtype, scale_output.LayoutType, scale_output.origin, address_space=scale_output.address_space, linear_idx_type=scale_output.linear_idx_type, element_size=scale_output.element_size], rank_sigs: InlineArray[UnsafePointer[Signal, MutAnyOrigin], 8], ctx: DeviceContext)
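Primary implementation of allreduce_residual_rmsnorm_fp8 for TileTensor arguments. It allreduces the per-GPU input buffers, adds the residual, applies RMSNorm, and quantizes the normalized result to FP8 with one scale per row.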
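The fused pipeline implied by the arguments can be modeled on the CPU. The sketch below is a plain-Python reference model, not Modular's GPU kernel. It sums the per-GPU inputs row by row, adds the residual (the value written to residual_output), applies RMSNorm with gamma + weight_offset as the weight, and computes one FP8 scale per row. Several points are assumptions rather than documented behavior: inputs are 2D [rows, cols]; epsilon is added inside the square root; the scale is max|y| / 448 capped at scale_ub (reading scale_ub literally as an upper bound on the scale); and rounding to the FP8 grid is omitted. The function name allreduce_residual_rmsnorm_fp8_ref is hypothetical.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def allreduce_residual_rmsnorm_fp8_ref(input_buffers, residual, gamma,
                                       epsilon, weight_offset, scale_ub):
    """Hypothetical CPU reference model; a sketch, not the library kernel.

    input_buffers: list (one per GPU) of matrices given as lists of rows.
    Returns (fp8_values, residual_output, per_row_scales).
    """
    rows = len(residual)
    cols = len(residual[0])
    output, residual_output, scales = [], [], []
    for r in range(rows):
        # 1. Allreduce: element-wise sum of this row across all GPUs.
        summed = [sum(buf[r][c] for buf in input_buffers) for c in range(cols)]
        # 2. Residual add; this pre-norm sum is what residual_output holds.
        pre_norm = [summed[c] + residual[r][c] for c in range(cols)]
        residual_output.append(pre_norm)
        # 3. RMSNorm (assumed form): epsilon inside the sqrt keeps inv_rms
        #    finite for all-zero rows; weight is gamma + weight_offset.
        mean_sq = sum(x * x for x in pre_norm) / cols
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        normed = [pre_norm[c] * inv_rms * (gamma[c] + weight_offset)
                  for c in range(cols)]
        # 4. Dynamic per-row scale, capped at scale_ub (assumed semantics).
        amax = max(abs(x) for x in normed)
        scale = min(amax / FP8_E4M3_MAX, scale_ub)
        scales.append(scale)
        # Quantize: divide by the scale and clamp to the FP8 range.
        # Rounding to representable FP8 values is omitted for simplicity.
        if scale == 0.0:
            output.append([0.0] * cols)
        else:
            output.append([max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x / scale))
                           for x in normed])
    return output, residual_output, scales
```

With a generous scale_ub, the largest element of each row maps to the edge of the FP8 range; a small scale_ub forces clamping instead.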
Parameters:
- in_dtype (DType): Input data type (e.g. bfloat16).
- out_dtype (DType): FP8 output data type (e.g. float8_e4m3fn).
- scales_dtype (DType): Scale factor data type (e.g. float32).
- ngpus (Int): Number of GPUs participating.
- in_layout (TensorLayout): Layout of the input TileTensors.
- in_origin (Origin): Origin of the input TileTensors.
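The example out_dtype, float8_e4m3fn, covers a far narrower range than bfloat16, which is why a scale is computed before the cast. As a worked check based on the OCP FP8 E4M3 encoding (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities), the decoder below enumerates every bit pattern and recovers the largest finite value. It does not use any Mojo API; fp8_e4m3fn_value is a hypothetical helper name.

```python
import math


def fp8_e4m3fn_value(bits):
    """Decode one float8_e4m3fn bit pattern (1 sign, 4 exponent, 3 mantissa bits; bias 7)."""
    sign = -1.0 if bits & 0x80 else 1.0
    exponent = (bits >> 3) & 0xF
    mantissa = bits & 0x7
    if exponent == 0xF and mantissa == 0x7:
        # S.1111.111 is the only NaN encoding; the "fn" variant has no infinities.
        return float("nan")
    if exponent == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias = -6.
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


# Enumerate all 256 encodings and keep the largest finite one.
values = [fp8_e4m3fn_value(b) for b in range(256)]
max_finite = max(v for v in values if not math.isnan(v))
print(max_finite)  # 448.0 = 2**8 * 1.75 (exponent 1111, mantissa 110)
```

Because the pattern with exponent 1111 and mantissa 111 is reserved for NaN, the top finite value uses mantissa 110, giving 448 rather than 480.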
Args:
- input_buffers (InlineArray): Per-GPU input buffers, one TileTensor per GPU.
- residual (TileTensor): Residual buffer added to the allreduced input.
- output (TileTensor): Output buffer for the FP8-quantized values.
- residual_output (TileTensor): Output buffer for the pre-normalization sum (allreduced input plus residual).
- gamma (TileTensor): RMSNorm gamma weights (1D TileTensor).
- epsilon (Scalar): RMSNorm epsilon for numerical stability.
- weight_offset (Scalar): Additive offset applied to the gamma weights.
- scale_ub (Float32): Upper bound used when clamping the FP8 scale.
- scale_output (TileTensor): Output buffer for the per-row FP8 scales.
- rank_sigs (InlineArray): Per-GPU signal pointers used for cross-GPU synchronization.
- ctx (DeviceContext): Device context for this GPU.
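Consumers of the result need both output and scale_output to recover real values. Assuming the common dynamic per-row convention q = y / scale, with exactly one scale per row (this page does not state the formula explicitly), dequantization is a row-wise multiply. A minimal sketch; dequantize_rows is a hypothetical helper, not part of the library:

```python
def dequantize_rows(fp8_values, row_scales):
    """Recover approximate normalized values: y[r][c] ~= q[r][c] * scale[r].

    Assumes one scale per row, matching the per-row scale_output buffer.
    """
    if len(fp8_values) != len(row_scales):
        raise ValueError("expected exactly one scale per row")
    return [[q * s for q in row] for row, s in zip(fp8_values, row_scales)]


print(dequantize_rows([[448.0, -224.0], [2.0, 0.0]], [0.5, 0.25]))
# [[224.0, -112.0], [0.5, 0.0]]
```

Keeping one scale per row (rather than one per tensor) lets rows with very different magnitudes each use the full FP8 range.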