
@PTNobel
Collaborator

PTNobel commented Dec 11, 2025

@ZedongPeng I couldn't get the backward pass working with MPAX:

  1. The MPAX solver uses jax.lax.while_loop with a data-dependent termination condition for its iterations.
  2. JAX cannot differentiate through while_loop in reverse mode; only forward mode (JVP) is supported. This is a JAX limitation, not a bug.
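A minimal JAX sketch of the limitation in point 2, using a hypothetical toy iteration (solve_while is illustrative only, not MPAX code): forward mode (jax.jvp) goes through a data-dependent jax.lax.while_loop, while reverse mode (jax.grad) raises a ValueError.

```python
import jax
import jax.numpy as jnp


def solve_while(theta):
    # Hypothetical toy fixed-point iteration x <- x - 0.5 * (x - theta),
    # stopped once |x - theta| < 1e-6. The trip count depends on the data,
    # so a while_loop is needed. The fixed point is x* = theta, so the
    # exact derivative dx*/dtheta is 1.
    def cond(carry):
        x, i = carry
        return (jnp.abs(x - theta) > 1e-6) & (i < 1000)

    def body(carry):
        x, i = carry
        return x - 0.5 * (x - theta), i + 1

    init = (jnp.zeros_like(theta), jnp.array(0))
    x_final, _ = jax.lax.while_loop(cond, body, init)
    return x_final


theta = jnp.float32(3.0)

# Forward mode (JVP) through while_loop is supported.
_, fwd_derivative = jax.jvp(solve_while, (theta,), (jnp.float32(1.0),))
print("forward-mode derivative:", fwd_derivative)  # close to 1.0

# Reverse mode (grad / VJP) through while_loop is not supported.
try:
    jax.grad(solve_while)(theta)
    reverse_failed = False
except ValueError as err:
    reverse_failed = True
    print("reverse mode failed:", err)
```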

How are we supposed to differentiate through MPAX?

@ZedongPeng

Hi @PTNobel. To obtain derivatives through the unrolled iterations, you need to set unroll=True. When jit=True, MPAX uses jax.lax.scan under the hood.
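For contrast, a hedged sketch of why a scan-based loop helps in general: with a fixed number of iterations, jax.lax.scan supports reverse-mode differentiation. The toy solve_scan function and its num_iters parameter are hypothetical and are not MPAX's actual implementation.

```python
import jax
import jax.numpy as jnp


def solve_scan(theta, num_iters=50):
    # Same hypothetical toy iteration as a while_loop version would use,
    # but with a fixed trip count, so lax.scan can express it. Reverse mode
    # works because scan stores the per-iteration values the backward pass needs.
    def body(x, _):
        return x - 0.5 * (x - theta), None

    x_final, _ = jax.lax.scan(body, jnp.zeros_like(theta), xs=None, length=num_iters)
    return x_final


theta = jnp.float32(3.0)
grad_theta = jax.grad(solve_scan)(theta)
print("reverse-mode derivative:", grad_theta)  # close to 1.0
```

The trade-off is that the loop always runs num_iters steps, and reverse mode keeps per-iteration state alive until the backward pass.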

@PTNobel
Collaborator Author

PTNobel commented Dec 12, 2025

Sorry, where do I set unroll=True? In the solve call?

@ZedongPeng

When you define the solver, e.g. solver = r2HPDHG(eps_abs=1e-4, eps_rel=1e-4, verbose=True, unroll=True). Relevant source:
https://github.com/MIT-Lu-Lab/MPAX/blob/ca1c669fea422c2509fae4bc30e1d79e6ca8977c/mpax/rapdhg.py#L306
https://github.com/MIT-Lu-Lab/MPAX/blob/ca1c669fea422c2509fae4bc30e1d79e6ca8977c/mpax/r2hpdhg.py#L59

@PTNobel
Collaborator Author

PTNobel commented Dec 17, 2025

@ZedongPeng Sorry to bother you again, but I tried adding unroll=True (see the commit), and it started trying to allocate a terabyte or two of memory and then crashed. Any idea why? Do I need to relax the tolerances?
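One plausible contributor, stated as an assumption rather than a diagnosis from this thread: if the unrolled scan runs for a fixed iteration count, reverse mode must keep each iteration's saved values until the backward pass, so memory grows linearly with iterations times problem size. The back-of-the-envelope helper below (scan_residual_bytes) and all of its inputs are hypothetical, not measured from MPAX.

```python
def scan_residual_bytes(num_vars, num_iters, vectors_saved_per_iter, bytes_per_entry=8):
    # Rough lower bound on memory for reverse mode through a scan:
    # every iteration's saved vectors stay alive until the backward pass.
    return num_vars * num_iters * vectors_saved_per_iter * bytes_per_entry


# Hypothetical numbers, not measured from MPAX:
# 1e6 variables, 1e5 iterations, 2 float64 vectors saved per iteration.
estimate = scan_residual_bytes(num_vars=1_000_000, num_iters=100_000, vectors_saved_per_iter=2)
print(f"{estimate / 1e12:.1f} TB")  # 1.6 TB
```

If this is the cause, jax.checkpoint (rematerialization) is the standard JAX tool for trading recomputation for stored residuals, though whether it applies cleanly here depends on MPAX's internals.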

@ZedongPeng

I don't think it is related to tolerance. Do you have a sample script or data file handy that I could test with?

@PTNobel
Collaborator Author

PTNobel commented Dec 18, 2025
