Python module
rms_norm
Normalization layer.
DistributedRMSNorm
class max.pipelines.nn.norm.rms_norm.DistributedRMSNorm(rms_norms: list[max.pipelines.nn.norm.rms_norm.RMSNorm], devices: list[max.graph.type.DeviceRef])
devices
devices: list[max.graph.type.DeviceRef]
rms_norms
rms_norms: list[max.pipelines.nn.norm.rms_norm.RMSNorm]
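The constructor pairs a list of RMSNorm layers with a list of devices, which suggests one norm per device. The following is a minimal pure-Python sketch of that idea, not the MAX implementation: it assumes each device holds its own copy of the input and its own weight, and that the norms are applied independently. The names `rms_norm_ref` and `distributed_rms_norm_ref` are hypothetical helpers and are not part of `max.pipelines`.

```python
import math


def rms_norm_ref(x, weight, eps=1e-6):
    """Reference RMS norm over a flat list of floats (hypothetical helper)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def distributed_rms_norm_ref(inputs, weights, eps=1e-6):
    """Apply one norm per device-local input (assumed behavior, for illustration)."""
    if len(inputs) != len(weights):
        raise ValueError("expected one weight vector per device-local input")
    return [rms_norm_ref(x, w, eps) for x, w in zip(inputs, weights)]


# Two "devices", each holding the same input and the same weight.
per_device_inputs = [[3.0, 4.0], [3.0, 4.0]]
per_device_weights = [[1.0, 1.0], [1.0, 1.0]]
per_device_outputs = distributed_rms_norm_ref(per_device_inputs, per_device_weights)
```

Under these assumptions, every device produces the same output because every device sees the same input and weight.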
RMSNorm
class max.pipelines.nn.norm.rms_norm.RMSNorm(weight: max._mlir._mlir_libs._mlir.ir.Value | max.graph.value.TensorValue | max.graph.type.Shape | max.graph.type.Dim | int | float | numpy.integer | numpy.floating | numpy.ndarray, eps: float = 1e-06)
eps
eps: float = 1e-06
weight
weight: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
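To illustrate how `weight` and `eps` are typically used, here is a minimal pure-Python sketch of the standard RMS normalization formula, y = x / sqrt(mean(x^2) + eps) * weight. It is a reference computation on plain lists, not the MAX graph implementation. The function name `rms_norm_ref` is hypothetical, and whether MAX places `eps` inside the square root exactly as shown is an assumption.

```python
import math


def rms_norm_ref(x, weight, eps=1e-6):
    """Reference RMS norm over a flat list of floats (hypothetical helper).

    Assumes eps is added to the mean square before taking the square root.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale each normalized element by its learned weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


normalized = rms_norm_ref([3.0, 4.0], [1.0, 1.0])
scaled = rms_norm_ref([3.0, 4.0], [2.0, 2.0])
```

With unit weights, the output has a mean square close to 1. The small `eps` term keeps the division stable when the input is all zeros.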