The `LayerNormalization` docstring describes `rms_scaling` as follows:
rms_scaling: If True, center and scale are ignored, and the inputs are scaled by gamma and the inverse square root of the square of all inputs. This is an approximate and faster approach that avoids ever computing the mean of the input.
However, the implementation actually does the following:
if self.rms_scaling:
    # Calculate outputs with only variance and gamma if rms scaling
    # is enabled
    # Calculate the variance along self.axis (layer activations).
    variance = ops.var(inputs, axis=self.axis, keepdims=True)
    inv = ops.rsqrt(variance + self.epsilon)
    outputs = (
        inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
    )
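So the mean is in fact computed: `ops.var` subtracts the mean internally, which means the inputs are scaled by the inverse standard deviation rather than by the inverse root mean square (RMS).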
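To make the difference concrete, here is a minimal pure-Python sketch contrasting scaling by the root mean square, as the docstring describes, with scaling by the inverse standard deviation, as the quoted branch does. The function names `rms_scale` and `variance_scale` are hypothetical, both operate on a single 1-D list rather than along `self.axis`, and `epsilon` defaults to 0 here purely for readability.

```python
import math


def rms_scale(x, gamma=1.0, epsilon=0.0):
    # RMS scaling as the docstring describes it: divide by sqrt(mean(x**2)).
    # The mean of x is never computed.
    mean_square = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(mean_square + epsilon)
    return [v * inv * gamma for v in x]


def variance_scale(x, gamma=1.0, epsilon=0.0):
    # Mirrors the quoted branch: the (population) variance subtracts the
    # mean internally, but the inputs themselves are not centered.
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    inv = 1.0 / math.sqrt(variance + epsilon)
    return [v * inv * gamma for v in x]


x = [1.0, 2.0, 3.0]
print(rms_scale(x))       # divides by sqrt(14/3)
print(variance_scale(x))  # divides by sqrt(2/3)
```

The two agree only when the input already has zero mean (for example `[-1.0, 1.0]`); for `[1.0, 2.0, 3.0]` they differ because the variance path depends on the mean.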
A discussion on the pull request that added RMS Normalization (https://github.com/keras-team/keras/pull/20911#issuecomment-2687658774) also confirms this behavior.
I think the docs should be updated to clarify this. As written, the docstring says the mean is never computed when `rms_scaling` is enabled, but the code shows otherwise.
Comment From: dhantule
Hi @jennifermcguire76, thanks for reporting this.
We are working on this issue, and you can track the progress in #21234. I am closing this as a duplicate of #21234; feel free to raise another issue if needed. Thanks!