PyTorch reference implementation of our method for selectively resetting interim states at any step of a linear recurrence, while computing all states in parallel via a prefix scan.
- Clone this repository.
- Install the Python dependencies in `requirements.txt`.
- There is no third step.
Our selective-resetting method applies to any linear recurrence (diagonal or not, time-variant or not, over real numbers or other fields) computed in parallel via a prefix scan. This repository provides a sample implementation for only one case: non-diagonal linear recurrences over real numbers, either time-variant or time-invariant. The implementation, for PyTorch, is in a single file: `sample_implementation.py`.
We will walk through an example to show how to use our sample implementation. Launch a Python interpreter (e.g., in a notebook) and execute the following code to compute a non-diagonal linear recurrence:
```python
import torch
import torch_parallel_scan as tps

# Create a sequence of n random d-by-d transition matrices A:
n, d = (50, 3)
A = torch.randn(n, d, d)

# Compute the linear recurrence X in parallel via a prefix scan:
X_without_resets = tps.prefix_scan(A, torch.matmul, dim=-3)
```
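Each state in `X_without_resets` is the cumulative left-to-right product of the transition matrices up to that step. As an optional sanity check (our own addition, not part of the repository), we can compare the scanned states to a sequential loop:

```python
# Optional sanity check: the prefix scan should match a sequential
# left-to-right cumulative product of the transition matrices.
X_seq = torch.empty_like(A)
X_seq[0] = A[0]
for t in range(1, n):
    X_seq[t] = X_seq[t - 1] @ A[t]

# Any difference is only floating-point error from the different
# order of multiplications:
rel_diff = (X_without_resets - X_seq).norm() / X_seq.norm()
print(f"Relative difference: {rel_diff.item():.2e}")
```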
The random vectors in each transition matrix of `A` tend to have an L-2 norm greater than 1, so the L-2 norms of the state vectors in `X_without_resets` will tend to increase with each additional matrix multiplication. If we plot the max vector norm within each state matrix in `X_without_resets`, using the following code,
```python
import matplotlib.pyplot as plt

# Plot the max L-2 norm of each state's vectors on a log scale:
fig, axis = plt.subplots(layout='constrained', figsize=(5, 3.5))
axis.set_title("Max L-2 Norm of Each State's Trailing Dimension")
axis.bar(range(n), X_without_resets.norm(dim=-1).max(dim=-1).values)
axis.set(yscale='log', ylim=(1e-1, X_without_resets.max() * 10))
axis.grid(axis='y')
```
we obtain a plot similar to this one (it won't be the same because the matrices in `A` are random):
Let's say we don't want the L-2 norms of state vectors to spiral out of control. The solution is to rescale state vectors whenever their L-2 norm starts getting too large; say, whenever the L-2 norm exceeds 10, to keep things simple. Unfortunately, we can't do that in parallel, can we?

Actually, yes, we can. Our selective-resetting method allows us to reset interim states that meet a selection criterion we specify, with a reset function we specify, in parallel, as we compute all states via a prefix scan:
```python
from sample_implementation import ParallelizedLeftToRightRecurrenceWithSelectiveResetting

# Define a non-diagonal linear recurrence with selective resetting:
parallelized_recurrence_with_sr = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
    d=d,
    select_func=lambda mats: (mats.norm(dim=-1) > 10).any(dim=-1)[..., None, None],
    reset_func=lambda mats: torch.nn.functional.normalize(mats, dim=-1),
)

# Compute the recurrence with selective resets via a parallel prefix scan:
X_with_resets = parallelized_recurrence_with_sr(A)
```
If we compare the max vector norms of `X_without_resets` and `X_with_resets` with the following code,
```python
# Plot max vector norms side by side, without and with selective resetting:
fig, axes = plt.subplots(ncols=2, sharey=True, layout='constrained', figsize=(10, 3.5))
fig.suptitle("Max L-2 Norm of Each State's Trailing Dimension")

axis = axes[0]
axis.set_title("Without Selective Resetting")
axis.bar(range(n), X_without_resets.norm(dim=-1).max(dim=-1).values)
axis.set(yscale='log', ylim=(1e-1, X_without_resets.max() * 10))
axis.grid(axis='y')

axis = axes[1]
axis.set_title("With Selective Resetting")
axis.bar(range(n), X_with_resets.norm(dim=-1).max(dim=-1).values)
axis.set(yscale='log', ylim=(1e-1, X_without_resets.max() * 10))
axis.grid(axis='y')
```
we obtain a plot similar to this one (it won't be the same because the matrices in `A` are random):
During the parallel prefix scan, whenever the norm of any vector in an interim state exceeds 10, the vectors in that state are reset to unit norm. Note that interim states can reach a max vector norm just below 10 without being reset, and they may be multiplied with other interim states that also reach a max vector norm just below 10, so some final compounded states may have a max vector norm that exceeds 10.
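To make that caveat concrete, we can count how many final compounded states exceed the threshold:

```python
# Count final states whose max vector norm exceeds the reset threshold
# of 10; some may exceed it, because resets apply only to interim states:
max_norms = X_with_resets.norm(dim=-1).max(dim=-1).values
print((max_norms > 10).sum().item(), "of", n, "final states exceed 10")
print("Largest max vector norm:", max_norms.max().item())
```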
If you're interested in understanding the intuition behind our selective-resetting method, Appendix C of our paper has an informal explanation of it with step-by-step examples.
We can use our sample implementation with recurrences of the form

$$x_t = A_t x_{t-1} + b_t,$$

by reformulating each step as a single matrix product over an augmented state,

$$\tilde{x}_t = \tilde{A}_t \tilde{x}_{t-1}, \qquad \tilde{x}_t = \begin{bmatrix} x_t \\ 1 \end{bmatrix}, \qquad \tilde{A}_t = \begin{bmatrix} A_t & b_t \\ 0 & 1 \end{bmatrix},$$

and applying `ParallelizedLeftToRightRecurrenceWithSelectiveResetting` over the chain of reformulated matrices. Our code computes the recurrence left-to-right, as is typical in PyTorch applications, so we would actually implement:
```python
parallelized_recurrence_with_sr = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
    d=d,
    select_func=my_select_func,
    reset_func=my_reset_func,
)
cumul_tilde_A_with_resets = parallelized_recurrence_with_sr(tilde_A)
tilde_x = torch.matmul(tilde_x_0, cumul_tilde_A_with_resets)
```
where `my_select_func` and `my_reset_func` are functions you must define, and `tilde_x_0` and `tilde_A` are the reformulated initial condition and transition matrices, respectively, which you must compute in advance.
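For concreteness, here is a minimal sketch, ours and not part of the repository, of one way to construct `tilde_A` and `tilde_x_0` in the left-to-right, row-vector form $x_t = x_{t-1} A_t + b_t$ (the transpose of the formulation above); note that the `d` passed to the module would then be the augmented dimension:

```python
import torch

n, d = 50, 3
A = torch.randn(n, d, d)   # transition matrices A_t
b = torch.randn(n, d)      # bias vectors b_t
x_0 = torch.randn(d)       # initial state, as a row vector

# Augment each step so that [x, 1] @ tilde_A_t = [x @ A_t + b_t, 1]:
tilde_A = torch.zeros(n, d + 1, d + 1)
tilde_A[:, :d, :d] = A     # top-left block: A_t
tilde_A[:, d, :d] = b      # bottom row: b_t
tilde_A[:, d, d] = 1.0     # bottom-right corner: 1

tilde_x_0 = torch.cat([x_0, torch.ones(1)])  # shape: (d + 1,)
```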
`ParallelizedLeftToRightRecurrenceWithSelectiveResetting` is a standard PyTorch `nn.Module`, so you can use it as a component of any PyTorch model, trainable via SGD with conventional techniques.
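As a hypothetical illustration, ours and not part of the repository, one might embed the module in a small model that maps a sequence of inputs to per-step transition matrices:

```python
import torch
from sample_implementation import ParallelizedLeftToRightRecurrenceWithSelectiveResetting

class TinyModelWithSelectiveResetting(torch.nn.Module):
    """Hypothetical example: maps inputs to per-step transition
    matrices, then scans them with selective resetting."""

    def __init__(self, d):
        super().__init__()
        self.proj = torch.nn.Linear(d, d * d)
        self.recurrence = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
            d=d,
            select_func=lambda mats: (mats.norm(dim=-1) > 10).any(dim=-1)[..., None, None],
            reset_func=lambda mats: torch.nn.functional.normalize(mats, dim=-1),
        )

    def forward(self, x):                 # x: [n, d]
        n, d = x.shape
        A = self.proj(x).view(n, d, d)    # per-step transition matrices
        return self.recurrence(A)         # cumulative states: [n, d, d]
```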
Our algorithm for parallel estimation of the spectrum of Lyapunov exponents of a dynamical system applies our selective-resetting method to prevent state vectors from becoming collinear as we apply a parallel prefix scan to a sequence of Jacobian matrices, working over generalized orders of magnitude (GOOMs) represented as complex tensors. Our reference implementation of that parallel algorithm is at https://github.com/glassroom/parallel_lyapunov_exponents.
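As a loose illustration of that idea, here is a sketch of ours, not the actual implementation (which works over GOOMs as complex tensors), using the sample implementation in this repository with a QR-based reset function to re-orthonormalize nearly collinear interim states:

```python
import torch
from sample_implementation import ParallelizedLeftToRightRecurrenceWithSelectiveResetting

def nearly_collinear(mats, threshold=1e-3):
    # Select states whose vectors are nearly linearly dependent, as
    # indicated by a small ratio of smallest to largest singular value:
    s = torch.linalg.svdvals(mats)
    return (s[..., -1] / s[..., 0] < threshold)[..., None, None]

def orthonormalize(mats):
    # Reset a state to the orthonormal factor of its QR decomposition:
    q, _ = torch.linalg.qr(mats)
    return q

recurrence_with_qr_resets = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
    d=3,
    select_func=nearly_collinear,
    reset_func=orthonormalize,
)
```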
```
@article{heinsen2025generalized,
    title={Generalized Orders of Magnitude for Scalable, Parallel, High-Dynamic-Range Computation},
    author={Franz A. Heinsen and Leo Kozachkov},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2025},
    url={https://openreview.net/forum?id=SUuzb0SOGu},
}
```
The work here originated with casual conversations over email between us, the authors, in which we wondered if it might be possible to find a succinct expression for computing non-diagonal linear recurrences in parallel, by mapping them to the complex plane. Our casual conversations gradually evolved into the development of generalized orders of magnitude, along with an algorithm for estimating Lyapunov exponents in parallel, and a novel method for selectively resetting interim states in a parallel prefix scan.
We hope others find our work and our code useful.