Add BatchNorm layers to CNN in MNIST tutorial for improved training stability #4773

Merged
copybara-service[bot] merged 5 commits into google:main from sanepunk:main on Aug 4, 2025

@sanepunk sanepunk commented Jun 7, 2025

Contributor

What does this PR do?

Integrates nnx.BatchNorm into the CNN model used in docs_nnx/mnist_tutorial.ipynb. The training and evaluation loops now switch the model between .train() and .eval() modes so that BatchNorm behaves correctly in each, which improves convergence and the metrics plots.

Highlights

  • Added nnx.BatchNorm() after each convolution
  • Updated training loop to call model.train() so running statistics are updated
  • Switched to model.eval() in the evaluation loop for deterministic inference
  • Observed smoother loss and accuracy curves in the notebook’s metrics graphs

Why?

BatchNorm stabilizes and accelerates training by normalizing activations between layers, leading to better gradient flow and faster convergence.
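
For readers unfamiliar with the train/eval distinction, here is a minimal pure-Python sketch of the idea. It is not nnx.BatchNorm and not the code in the notebook: the class name BatchNorm1D, the single-feature input, and the momentum value are all assumptions made for illustration, and the learnable scale and bias are left out.

```python
import math


class BatchNorm1D:
    """Toy batch norm over one feature (illustrative only, not nnx.BatchNorm)."""

    def __init__(self, momentum=0.9, eps=1e-5):
        self.momentum = momentum  # assumed value, chosen for the example
        self.eps = eps            # small constant to avoid division by zero
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def train(self):
        # Training mode: normalize with batch statistics and update running ones.
        self.training = True

    def eval(self):
        # Evaluation mode: normalize with the stored running statistics only.
        self.training = False

    def __call__(self, batch):
        if self.training:
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            # Exponential moving average of the batch statistics.
            m = self.momentum
            self.running_mean = m * self.running_mean + (1 - m) * mean
            self.running_var = m * self.running_var + (1 - m) * var
        else:
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


bn = BatchNorm1D()
bn.train()
train_out = bn([1.0, 2.0, 3.0])  # uses batch statistics, updates running ones
bn.eval()
eval_out = bn([10.0])            # uses running statistics, leaves them unchanged
```

In evaluation mode the output depends only on the input and the stored statistics, which is why the evaluation loop in this PR calls model.eval() first.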

Testing

Ran the MNIST tutorial notebook end-to-end and confirmed:

  • Training loss decreases more smoothly
  • Validation accuracy improves more quickly
  • Metrics plots clearly reflect these improvements

Closes #

@google-cla

google-cla Bot commented Jun 7, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@review-notebook-app

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


@8bitmp3

8bitmp3 commented Jun 30, 2025

Contributor

@IvyZX @jburnim @cgarciae WDYT

@pccoeit

pccoeit commented Jul 18, 2025

Thanks for the update

@sanepunk

sanepunk commented Aug 2, 2025

Contributor Author

Can you merge this PR? @8bitmp3 @IvyZX @jburnim @cgarciae

…ning stability and performance. Update training and evaluation steps to switch model modes appropriately. Clarify documentation on the use of .eval() for inference.
@sanepunk

sanepunk commented Aug 2, 2025

Contributor Author

Can you merge this PR? @8bitmp3 @IvyZX @jburnim @cgarciae

I’ve made the necessary changes to fix the GitHub Actions failure; everything checks out now.

@cgarciae

cgarciae commented Aug 3, 2025

Collaborator

@sanepunk looks good! You can fix pre-commit with:

pip install pre-commit
pre-commit run --all-files

- Removed explicit `ipython3` specification from code cells for uniformity.
- Improved formatting and readability of code snippets throughout the document.
copybara-service Bot merged commit dfaaa48 into google:main on Aug 4, 2025
17 checks passed
@sanepunk

sanepunk commented Aug 5, 2025

Contributor Author

Thank you @cgarciae @8bitmp3 @IvyZX @jburnim for helping me with my first PR to Flax.
