swfsql (contributor) commented Feb 1, 2024

  • Add the `try_normalize_rms`-related functions.
  • Add the `LayerRMSNorm1D` module.

Implements RMS layer normalization as described in the paper "Root Mean Square Layer Normalization".
The layer rescales a tensor axis so that its root mean square is 1.0. Unlike standard layer normalization, the mean is not subtracted, so it is not forced to zero (and the standard deviation is therefore not necessarily 1.0).
Computes `tensor / (tensor.square().mean() + epsilon).sqrt()`, with the mean taken over the normalized axis.
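For illustration only, here is a minimal sketch of that formula in plain Rust, assuming the normalized axis is stored as a 1-D `&[f32]`. The name `normalize_rms` is hypothetical and is not the PR's `try_normalize_rms` API; it only shows the arithmetic.

```rust
/// Hypothetical helper: RMS-normalizes one axis (here a 1-D slice),
/// computing `x / sqrt(mean(x^2) + epsilon)` element-wise.
/// Unlike standard layer normalization, the mean is NOT subtracted.
fn normalize_rms(x: &[f32], epsilon: f32) -> Vec<f32> {
    assert!(!x.is_empty(), "cannot normalize an empty axis");
    // Mean of the squared values along the axis.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    // Root mean square; epsilon goes inside the sqrt for numerical stability.
    let rms = (mean_sq + epsilon).sqrt();
    // Divide every element by the same RMS value.
    x.iter().map(|v| v / rms).collect()
}
```

For example, normalizing `[1.0, 2.0, 3.0, 4.0]` with `epsilon = 0.0` gives an output whose RMS is 1.0, while its mean stays at about 0.913 rather than being pulled to zero.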

  • Open question: the bias (delta) may be unnecessary for this layer and could possibly be removed.
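To make the bias question concrete, here is a minimal sketch of a layer with a learnable gain and an optional bias. The struct `RmsNorm1D`, its fields `gamma`/`beta`, and the initialization to gain 1.0 and bias 0.0 are all hypothetical assumptions for illustration, not the PR's `LayerRMSNorm1D` definition. Setting `beta` to `None` models removing the delta.

```rust
/// Hypothetical RMS-norm layer over a 1-D axis of length `n`.
/// `gamma` is the learnable gain; `beta` is the optional learnable bias
/// (the "delta" in question). `None` models a layer without a bias.
struct RmsNorm1D {
    gamma: Vec<f32>,
    beta: Option<Vec<f32>>,
    epsilon: f32,
}

impl RmsNorm1D {
    /// Assumed initialization: gain 1.0 and (if kept) bias 0.0, so a fresh
    /// layer behaves like plain RMS normalization.
    fn new(n: usize, with_bias: bool, epsilon: f32) -> Self {
        Self {
            gamma: vec![1.0; n],
            beta: if with_bias { Some(vec![0.0; n]) } else { None },
            epsilon,
        }
    }

    /// Normalizes `x` by its RMS, scales by `gamma`, then adds `beta` if present.
    fn forward(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.gamma.len(), "axis length mismatch");
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let rms = (mean_sq + self.epsilon).sqrt();
        x.iter()
            .zip(&self.gamma)
            .enumerate()
            .map(|(i, (v, g))| {
                let scaled = v / rms * g;
                match &self.beta {
                    Some(b) => scaled + b[i],
                    None => scaled,
                }
            })
            .collect()
    }
}
```

With this initialization, the variants with and without a bias produce identical outputs until the bias is trained. Whether the extra parameters are worth keeping is exactly the open question above.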

Note: I haven't written a test comparing against PyTorch, not even locally, but a larger test that depends on this functionality appeared to work correctly. This PR should therefore be considered a draft.
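In place of the missing PyTorch comparison, one option is a hand-checkable reference test. For `x = [3, 4]` with `epsilon = 0`, mean(x²) = (9 + 16) / 2 = 12.5, so RMS = √12.5 ≈ 3.5355 and the expected output is ≈ `[0.8485, 1.1314]`. The sketch below assumes hypothetical helpers `rms_reference` and `all_close`. A real test would compare the layer's output against the same expected values with a small tolerance.

```rust
/// Hypothetical straightforward reference implementation used as an oracle,
/// computed in f64 to reduce rounding error.
fn rms_reference(x: &[f64], epsilon: f64) -> Vec<f64> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    let rms = (mean_sq + epsilon).sqrt();
    x.iter().map(|v| v / rms).collect()
}

/// Hypothetical element-wise comparison with an absolute tolerance.
fn all_close(a: &[f64], b: &[f64], tol: f64) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() <= tol)
}
```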

swfsql changed the title from "add RMS normalization" to "Add RMS normalization" Feb 1, 2024
swfsql mentioned this pull request Feb 2, 2024
swfsql marked this pull request as draft March 1, 2024 14:54
swfsql added 3 commits March 1, 2024 15:46