
selective_resetting

Reference implementation, for PyTorch, of our method for selectively resetting interim states at any step of a linear recurrence as we compute all states of the recurrence in parallel via a prefix scan.

Installing

  1. Clone this repository.

  2. Install the Python dependencies in requirements.txt (example commands follow this list).

  3. There is no third step.
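
For example, assuming git and pip are available:

git clone https://github.com/glassroom/selective_resetting.git
cd selective_resetting
pip install -r requirements.txt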

Sample Use

Our selective-resetting method applies to any linear recurrence (diagonal or not, time-variant or not, over real numbers or other fields) computed in parallel via a prefix scan. This repository provides a sample implementation for only one case: non-diagonal linear recurrences over real numbers, either time-variant or time-invariant. The implementation, for PyTorch, is in a single file: sample_implementation.py.

We will walk through an example to show how to use our sample implementation. Launch a Python interpreter (e.g., in a notebook), and execute the following code to compute a non-diagonal linear recurrence $X_t = X_{t-1} A_t$ (composed left-to-right, as our code computes it), with initial state $X_0 = I$, in parallel via a prefix scan:

import torch
import torch_parallel_scan as tps

# Create sequence of random transition matrices A:
n, d = (50, 3)
A = torch.randn(n, d, d)

# Compute linear recurrence X in parallel via a prefix scan:
X_without_resets = tps.prefix_scan(A, torch.matmul, dim=-3)
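
As an optional sanity check, the scan's output should match a sequential left-to-right chain of matrix products, up to floating-point error (a minimal sketch; the parallel scan re-associates the products, and floating-point matmul is not associative, so small discrepancies are expected):

# Sequential left-to-right chain of matrix products, for comparison:
X_sequential = torch.empty_like(A)
X_sequential[0] = A[0]
for t in range(1, n):
    X_sequential[t] = X_sequential[t - 1] @ A[t]

# Relative discrepancy should be small:
print((X_without_resets - X_sequential).norm() / X_sequential.norm())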

The random row vectors in each transition matrix of A tend to have an L-2 norm greater than 1, so the L-2 norms of the state vectors in X_without_resets tend to grow with each additional matrix multiplication. If we plot the max vector norm within each state matrix of X_without_resets, using the following code,

import matplotlib.pyplot as plt

fig, axis = plt.subplots(layout='constrained', figsize=(5, 3.5))
axis.set_title("Max L-2 Norm of Each State's Trailing Dimension")
axis.bar(range(n), X_without_resets.norm(dim=-1).max(dim=-1).values)
axis.set(yscale='log', ylim=(1e-1, X_without_resets.max() * 10))
axis.grid(axis='y')

we obtain a plot similar to this one (it won't be the same because the matrices in A are random):

[Figure: max vector norm per state, without selective resetting]

Let's say we don't want the L-2 norms of state vectors to spiral out of control. The solution is to rescale state vectors whenever their L-2 norm starts getting too large -- say, whenever the L-2 norm exceeds 10, to keep things simple. Unfortunately, we can't do that in parallel, can we?

Actually, yes, we can. Our selective-resetting method allows us to reset interim states that meet a selection criterion we specify, with a reset function we specify, in parallel, as we compute all states via a prefix scan:

from sample_implementation import ParallelizedLeftToRightRecurrenceWithSelectiveResetting

# Define non-diagonal linear recurrence with selective resetting:
parallelized_recurrence_with_sr = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
    d=d,
    select_func=lambda mats: (mats.norm(dim=-1) > 10).any(dim=-1)[..., None, None],
    reset_func=lambda mats: torch.nn.functional.normalize(mats, dim=-1),
)

# Compute recurrence with selective resets via parallel prefix scan:
X_with_resets = parallelized_recurrence_with_sr(A)
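
To build intuition for the two functions, we can apply them to a few stand-alone matrices (a small sketch; what matters is that select_func returns one boolean flag per matrix, broadcastable against it, and reset_func returns matrices of the same shape):

sample = torch.randn(4, d, d) * 100.0  # matrices with large row norms
mask = (sample.norm(dim=-1) > 10).any(dim=-1)[..., None, None]
print(mask.shape)  # torch.Size([4, 1, 1]): one flag per matrix
print(torch.nn.functional.normalize(sample, dim=-1).norm(dim=-1))  # all row norms ~1.0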

If we compare the max vector norms of X_without_resets and X_with_resets with the following code,

fig, axes = plt.subplots(ncols=2, sharey=True, layout='constrained', figsize=(10, 3.5))
fig.suptitle("Max L-2 Norm of Each State's Trailing Dimension")

axis = axes[0]
axis.set_title("Without Selective Resetting")
axis.bar(range(n), X_without_resets.norm(dim=-1).max(dim=-1).values)
axis.set(yscale='log', ylim=(1e-1, X_without_resets.max() * 10))
axis.grid(axis='y')

axis = axes[1]
axis.set_title("With Selective Resetting")
axis.bar(range(n), X_with_resets.norm(dim=-1).max(dim=-1).values)
axis.set(yscale='log', ylim=(1e-1, X_without_resets.max() * 10))
axis.grid(axis='y')

we obtain a plot similar to this one (it won't be the same because the matrices in A are random):

[Figure: side-by-side comparison of max vector norms per state, without and with selective resetting]

During the parallel prefix scan, whenever the norm of any vector in an interim state exceeds 10, the vectors in that state are reset to unit norm. Note that interim states can reach a max vector norm just below 10 without being reset, and they may be multiplied with other interim states that also reach a max vector norm just below 10, so some final compounded states may have a max vector norm that exceeds 10.
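
You can confirm this at the end of the scan; for example:

# Can exceed the selection threshold of 10, as explained above:
print(X_with_resets.norm(dim=-1).max())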

If you're interested in understanding the intuition behind our selective-resetting method, Appendix C of our paper has an informal explanation of it with step-by-step examples.

Extending to Non-Diagonal Linear Recurrences with Biases

We can use our sample implementation with recurrences of the form $x_t = A_t x_{t-1} + b_t$, given initial condition $x_0 \in \mathbb{R}^d$, by reformulating the recurrence as a sequence of matrix products,

$$\tilde{x}_t = \tilde{A}_t \tilde{x}_{t-1}, \quad \text{where} \quad \tilde{x}_t = \begin{bmatrix} x_t \\ 1 \end{bmatrix} \quad \text{and} \quad \tilde{A}_t = \begin{bmatrix} A_t & b_t \\ 0 & 1 \end{bmatrix},$$

and applying ParallelizedLeftToRightRecurrenceWithSelectiveResetting over the chain of reformulated matrices. Our code computes the recurrence left-to-right, as is typical in PyTorch applications, so we would actually implement $\tilde{x}^T_t = \tilde{x}^T_0 \tilde{A}^T_1 \tilde{A}^T_2 \dots \tilde{A}^T_t$, for example, with the following code:

parallelized_recurrence_with_sr = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
    d=d + 1,  # the reformulated matrices are (d + 1)-by-(d + 1)
    select_func=my_select_func,
    reset_func=my_reset_func,
)

cumul_tilde_A_with_resets = parallelized_recurrence_with_sr(tilde_A)
tilde_x = torch.matmul(tilde_x_0, cumul_tilde_A_with_resets)

where my_select_func and my_reset_func are functions you must define, and tilde_x_0 and tilde_A are the reformulated initial condition and transition matrices, respectively, which you must compute in advance.
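
For concreteness, here is one way to build those reformulated tensors, using the standard trick of augmenting the state with a constant 1 (a minimal sketch, assuming A has shape (n, d, d), b has shape (n, d), and x_0 has shape (d,); the matrices are transposed because the recurrence is computed left-to-right over row vectors):

n, d = A.shape[0], A.shape[-1]

# Transposed augmented transition matrices, each of shape (d + 1, d + 1):
tilde_A = torch.zeros(n, d + 1, d + 1)
tilde_A[:, :d, :d] = A.transpose(-2, -1)  # top-left block: A_t^T
tilde_A[:, d, :d] = b                     # bottom row: b_t^T
tilde_A[:, d, d] = 1.0                    # bottom-right corner: 1

# Augmented initial state as a row vector [x_0^T, 1], shape (1, d + 1):
tilde_x_0 = torch.cat([x_0, torch.ones(1)]).unsqueeze(0)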

Using as a Component of PyTorch Models

ParallelizedLeftToRightRecurrenceWithSelectiveResetting is a standard PyTorch nn.Module, so you can use it as a component of any PyTorch model, trainable via SGD with conventional techniques.
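
For example, here is a minimal sketch of a model that generates one transition matrix per step from its inputs and scans them with selective resetting (the linear projection and the threshold of 10 are hypothetical choices, for illustration only):

import torch
from sample_implementation import ParallelizedLeftToRightRecurrenceWithSelectiveResetting

class SketchModel(torch.nn.Module):

    def __init__(self, d):
        super().__init__()
        self.to_matrices = torch.nn.Linear(d, d * d)  # hypothetical projection to transition matrices
        self.recurrence = ParallelizedLeftToRightRecurrenceWithSelectiveResetting(
            d=d,
            select_func=lambda mats: (mats.norm(dim=-1) > 10).any(dim=-1)[..., None, None],
            reset_func=lambda mats: torch.nn.functional.normalize(mats, dim=-1),
        )

    def forward(self, x):                      # x: (n, d) sequence of inputs
        n, d = x.shape
        A = self.to_matrices(x).view(n, d, d)  # one transition matrix per step
        return self.recurrence(A)              # (n, d, d) cumulative states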

Other Implementations

Our algorithm for parallel estimation of the spectrum of Lyapunov exponents of dynamical systems applies our selective-resetting method to prevent vector states from becoming collinear as we apply a parallel prefix scan to a sequence of Jacobian matrices, over generalized orders of magnitude (GOOMs), represented as complex tensors. Our reference implementation of that parallel algorithm is at https://github.com/glassroom/parallel_lyapunov_exponents.

Citing

@article{heinsen2025generalized,
    title={Generalized Orders of Magnitude for Scalable, Parallel, High-Dynamic-Range Computation},
    author={Franz A. Heinsen and Leo Kozachkov},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2025},
    url={https://openreview.net/forum?id=SUuzb0SOGu},
}

Notes

The work here originated with casual conversations over email between us, the authors, in which we wondered if it might be possible to find a succinct expression for computing non-diagonal linear recurrences in parallel, by mapping them to the complex plane. Our casual conversations gradually evolved into the development of generalized orders of magnitude, along with an algorithm for estimating Lyapunov exponents in parallel, and a novel method for selectively resetting interim states in a parallel prefix scan.

We hope others find our work and our code useful.
