Miles Cranmer, Sam Greydanus, Stephan Hoyer, Peter Battaglia, David Spergel, Shirley Ho
- Paper
- Blog
- Self-Contained Tutorial
- Paper example notebook: double pendulum
- Paper example notebook: special relativity
- Paper example notebook: wave equation
Warning
This project was developed with JAX 0.1.55 (2020). For compatibility, use pixi to install the original environment (see guide). Alternatively, adapt the core equation snippets below for modern JAX.
In this project we propose Lagrangian Neural Networks (LNNs), which can parameterize arbitrary Lagrangians using neural networks. In contrast to Hamiltonian Neural Networks, these models do not require canonical coordinates and perform well in situations where generalized momentum is difficult to compute (e.g., the double pendulum). This is particularly appealing for use with a learned latent representation, a case where HNNs struggle. Unlike previous work on learning Lagrangians, LNNs are fully general and extend to non-holonomic systems such as the 1D wave equation.
| Neural Networks | Neural ODEs | HNN | DLN (ICLR'19) | LNN (this work) | |
|---|---|---|---|---|---|
| Learns dynamics | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ |
| Learns continuous-time dynamics | ✔️ | ✔️ | ✔️ | ✔️ | |
| Learns exact conservation laws | ✔️ | ✔️ | ✔️ | ||
| Learns from arbitrary coordinates | ✔️ | ✔️ | ✔️ | ✔️ | |
| Learns arbitrary Lagrangians | ✔️ |
The key innovation of LNNs is the automatic derivation of equations of motion from learned Lagrangians. The core equation of motion is version-independent and can be adapted to any JAX version:
import jax
import jax.numpy as jnp
def lagrangian_eom(lagrangian, state, t=None):
"""Compute Euler-Lagrange equation of motion.
Args:
lagrangian: A function L(q, q_dot) representing the Lagrangian
state: Concatenated position q and velocity q_dot
Returns:
Concatenated velocity and acceleration
"""
q, q_dot = jnp.split(state, 2)
# Euler-Lagrange equation: d/dt(∂L/∂q̇) - ∂L/∂q = 0
# Rearranged to solve for acceleration q_ddot
q_ddot = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_dot))
@ (jax.grad(lagrangian, 0)(q, q_dot)
- jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_dot) @ q_dot))
return jnp.concatenate([q_dot, q_ddot])
# Custom initialization for MLPs learning Lagrangians
def custom_init(layer_sizes, seed=0):
"""Initialize MLP to learn Lagrangians more effectively."""
# See lnn/core.py for full implementation
# Uses optimized scale for each layerThis equation of motion works with any JAX version. The rest of the codebase is maintained with JAX 0.1.55 for reproducibility of the paper's environment.
This project requires specific versions of JAX and other dependencies from early 2020 to ensure reproducibility. We use pixi for reproducible environment management.
-
Install pixi (if not already installed):
curl -fsSL https://pixi.sh/install.sh | bashOr via Homebrew:
brew install pixi
-
Clone this repository and install dependencies:
git clone https://github.com/MilesCranmer/lagrangian_nns.git cd lagrangian_nns pixi install -
Run the notebooks:
pixi run jupyter notebook
The environment includes:
- Python 3.7
- JAX 0.1.55 & jaxlib 0.1.37 (January 2020 versions)
- NumPy 1.19.5
- Matplotlib 3.1.2 (visualization)
- MoviePy 1.0.0 (visualization)
- celluloid 0.2.0 (visualization)
All dependencies are pinned to versions from early 2020 to ensure the code behaves exactly as it did when the paper was published.