
AbstractReversibleSolver + ReversibleAdjoint #603


Open · wants to merge 43 commits into main

Conversation

sammccallum:

Re-opening #593.

Implements the AbstractReversibleSolver base class and ReversibleAdjoint for reversible backpropagation.

This updates SemiImplicitEuler, LeapfrogMidpoint and ReversibleHeun to subclass AbstractReversibleSolver.

Implementation

AbstractReversibleSolver subclasses AbstractSolver and adds a backward_step method:

@abc.abstractmethod
def backward_step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y1: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, DenseInfo, _SolverState]:
    ...

This method should reconstruct y0 and the solver state at t0 from y1 and the solver state at t1. See the solvers listed above for examples.

When backpropagating, ReversibleAdjoint uses this backward_step to reconstruct state. We then take a vjp through a local forward step and accumulate gradients.

ReversibleAdjoint now also pulls back gradients from any interpolated values, so we can use SaveAt(ts=...)!

We allow arbitrary solver_state (provided it can be reconstructed reversibly) and calculate gradients w.r.t. solver_state. Finally, we pull back these gradients onto y0, args, terms using the solver.init method.
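Schematically, one step of the backward pass looks something like this (a minimal sketch with hypothetical names, not the exact implementation; the surrounding loop supplies the step endpoints, states and incoming cotangent):

```python
import equinox as eqx

# Sketch of one step of the reversible backward pass (names hypothetical).
def backward_pass_step(solver, terms, t0, t1, y1, args, solver_state1, grad_y1):
    # Reconstruct y0 and the solver state at t0 from the state at t1.
    y0, _, solver_state0 = solver.backward_step(
        terms, t0, t1, y1, args, solver_state1, False
    )

    # Take a vjp through a local forward step from the reconstructed state.
    def forward(y0, args, terms):
        y1, _, _, _, _ = solver.step(terms, t0, t1, y0, args, solver_state0, False)
        return y1

    _, step_vjp = eqx.filter_vjp(forward, y0, args, terms)

    # Pull the incoming cotangent back onto y0, args and terms.
    grad_y0, grad_args, grad_terms = step_vjp(grad_y1)
    return y0, solver_state0, grad_y0, grad_args, grad_terms
```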

sammccallum:

I've also added the Reversible RK solvers here, which just subclass AbstractReversibleSolver. Let me know what you think of this and I can add some documentation when it's good to go!

patrick-kidger (Owner) left a comment:

Okay, gosh, this one took me far too long to get around to. Thank you for your patience! If I can, I'd like this to be the next big thing I focus on getting into Diffrax.

@@ -105,6 +106,7 @@
MultiButcherTableau as MultiButcherTableau,
QUICSORT as QUICSORT,
Ralston as Ralston,
Reversible as Reversible,

patrick-kidger (Owner):

Let's call this something more specific to help distinguish what it is! I think here it's not clear that it's a solver, and even if it were ReversibleSolver, that still wouldn't disambiguate amongst the various kinds of reversibility it's possible to cook up.

What do you call this yourself / in your paper? Its structure is Hamiltonian/coupled/etc., so maybe a word of that sort is suitable.

sammccallum:

The call will look like this:

solver = diffrax.Reversible(diffrax.Tsit5())

with the idea being that you are "reversifying" Tsit5. So it isn't a solver in itself, but a wrapper. The boring name could be something like ReversibleWrapper and the fun name could be something like Reversify. Thoughts?

patrick-kidger (Owner):

Sure, but in the future we may cook up some other way of reversifying a solver! We should pick a name for this one that leaves that possibility open.

sammccallum:

James has gone for U-Reversible (after the Uno reverse card :). The analogy is that we take a step forward from z0 to y1, then reverse and pull back from y1 onto z1.

reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

patrick-kidger (Owner):

A very minor bug here: if it just so happens that we run with t0 == t1 then we'll end up with reversible_ts = [t0 inf inf inf ...], which will not produce the desired results in the backward solve.

We have a special branch to handle the saving in the t0 == t1 case; we should add a line handling the state.reversible_ts is not None case there.
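Presumably something along these lines in that branch (a hypothetical sketch using this PR's names):

```python
# Sketch: in the t0 == t1 branch, also record the single time point so the
# backward solve doesn't see an all-inf reversible_ts buffer.
if state.reversible_ts is not None:
    reversible_ts = state.reversible_ts.at[0].set(t0)
    reversible_save_index = 0
```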


# Pull solver_state gradients back onto y0, args, terms.

_, init_vjp = eqx.filter_vjp(solver.init, terms, ts[0], ts[1], y0, args)

patrick-kidger (Owner):

It's not super clear to me that ts[0], ts[1] are the correct values here. It looks like the saving routine is storing tprev, which is not necessarily the same as state.tnext, and the latter is what solver.init was originally called with.

In principle the step size controller could return anything at all; in practice it is possible for tprev to be 2 ULPs greater than state.tnext when passing to the other side of a discontinuity.

sammccallum:

The ts here are reversible_ts, which follow the same logic as SaveAt(steps=True).

Is it not the case that the state.tnext identified in diffeqsolve (and used for solver.init) has to be the first step that the solver took? I appreciate that they can be different at later points in the solve, but my understanding was that the first step was set in diffeqsolve?

patrick-kidger (Owner):

So I think in this case you're saving the tprev of the second step, not the tnext of the first.

sammccallum:

Yep, you're right.

We now return the tprev and tnext passed to solver.init as residuals in the reversible loop and use these to get the vjp.
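i.e. roughly this (a hypothetical sketch; the residual names are assumptions):

```python
# Sketch: use the exact (tprev, tnext) that solver.init originally saw,
# carried through the reversible loop as residuals, rather than ts[0], ts[1].
init_tprev, init_tnext = init_residuals  # saved on the forward pass
_, init_vjp = eqx.filter_vjp(solver.init, terms, init_tprev, init_tnext, y0, args)
grad_terms, _, _, grad_y0, grad_args = init_vjp(grad_solver_state)
```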

Comment on lines 1409 to 1415
# Reversible info
if max_steps is None:
    reversible_ts = None
    reversible_save_index = None
else:
    reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
    reversible_save_index = 0

patrick-kidger (Owner):

I've thought of an alternative for this extra buffer, btw: ReversibleAdjoint.loop could intercept saveat and add a SubSaveAt(steps=True, fn=lambda t, y, args: None) to record the extra times, then peel it off again when returning the final state.

I think that (a) might be doable without making any changes to _integrate.py and (b) would allow for also supporting SaveAt(steps=True). (As in that case we can just skip adding the extra SubSaveAt.) And (c) would avoid a few of the subtle issues I've commented on above about exactly which tprev/tnext-like value is actually being saved, because you can trust the rest of the existing diffeqsolve to do that for you.

It's not a strong suggestion though.
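For concreteness, the wrapping half of that suggestion might look like this (a minimal sketch, assuming saveat.subs can simply be tupled with the extra SubSaveAt):

```python
import diffrax

# Sketch: add an extra SubSaveAt that records every accepted step time but
# saves no values (fn returns None); peel its output off again afterwards.
def with_step_times(saveat: diffrax.SaveAt) -> diffrax.SaveAt:
    extra = diffrax.SubSaveAt(steps=True, fn=lambda t, y, args: None)
    return diffrax.SaveAt(subs=(saveat.subs, extra))
```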

sammccallum:

Yeah, this was the original idea I tried but I couldn't get around a leaked tracer error! I'm willing to give it another go if you start feeling strongly about it though ;)

patrick-kidger (Owner):

Maybe let's nail everything else down and then consider this. Reflecting on this, I do suspect it will make the code much easier to maintain in the long run.

y1 = (self.coupling_parameter * (ω(y0) - ω(z0)) + ω(step_z0)).ω

step_y1, y_error, _, _, result2 = self.solver.step(
    terms, t1, t0, y1, args, original_solver_state, True
)

patrick-kidger (Owner):

I've just spotted this has t0 and t1 back-to-front, which I think may in general mess with our solving logic as described previously. Is this intended / is it possible not to?

sammccallum:

Yeah, this is intended and is essential to the solver!

@sammccallum force-pushed the AbstractReversibleSolver branch from 0cfd4ec to 3a26ac3 on May 14, 2025.

patrick-kidger (Owner) left a comment:

Not sure if you were ready for a review on this yet, but I took a look over anyway 😁 We're making really good progress! In particular, now that we've settled on just AbstractERK, I think all our complicated state-reconstruction concerns go away, so the chance of footgunning ourselves has gone way down 😁

# (i.e. the state used on the forward). Otherwise, in `ReversibleAdjoint`,
# we would take a local forward step from an incorrect `solver_state`.
solver_state = jax.lax.cond(
    tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None
)

patrick-kidger (Owner):

I think this predicate is assuming we're solving over a time interval of the form [0, T]? But in practice the left endpoint may be nonzero.

This aside, it's (a) possible to use jnp.where, which is lower-overhead for small stuff like this, but also (b) I think lax.cond(pred, lambda: foo, lambda: bar) should work, without arguments.
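Both variants, sketched (with t_left standing for the actual left endpoint rather than 0; the helper is hypothetical):

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

# Sketch of the two suggested forms for selecting the solver state.
def select_solver_state(tm1, ym1, t0, y0, dt, t_left):
    # (b) argument-free lax.cond:
    state_cond = jax.lax.cond(
        tm1 > t_left, lambda: (tm1, ym1, dt), lambda: (t0, y0, dt)
    )
    # (a) jnp.where, lower-overhead for small pytrees like this one:
    state_where = jtu.tree_map(
        lambda a, b: jnp.where(tm1 > t_left, a, b), (tm1, ym1, dt), (t0, y0, dt)
    )
    return state_cond, state_where
```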

Comment on lines +70 to +73
# We pre-compute the step size to avoid numerical instability during the
# backward_step. This is okay (albeit slightly ugly) as `LeapfrogMidpoint` can't
# be used with adaptive step sizes.
dt = t1 - t0

patrick-kidger (Owner):

I think this will mean we silently get the wrong results if used with diffrax.StepTo with non-constant step size.

I think we might be able to save this by adjusting the backward step logic, so that it computes y0 in the step that returns it, rather than in the step before:

def backward_step(...):
    t2, y2 = solver_state
    control = terms.contr(t0, t2)
    y0 = (y2**ω - terms.vf_prod(t0, y1, args, control) ** ω).ω
    solver_state = (t1, y1)
    return y0, ...

Note that this is also essentially the same idea as what we currently do in the forward step. I think this might still require some care around how to handle the first and last step.

sammccallum:

That's much nicer! This should work, but as you say, requires special handling of the first and last step. I've had a think about how to do this (and the point above about assuming [0, T]) and I believe it's not possible without knowing where the start and end points of the solve are (or at least when we're at the first/last step)!

This feels like it's questioning the attempt to use a single-step API to represent multi-step solvers, hinting at this line:

# TODO: support arbitrary linear multistep methods

I have a few ideas for how to expand the leapfrog midpoint api to support this (e.g. have first_step, last_step flags or hold t0, t1 as attributes) but I'm not sure these are elegant enough for diffrax lol... wdyt?

solver: AbstractERK

patrick-kidger (Owner):

Now that we're special-casing to just this, then I am much more comfortable with the logic below!

raise ValueError(
    "`UReversible` is only compatible with `AbstractERK` base solvers."
)
original_solver_init = self.solver.init(terms, t0, t1, y0, args)

patrick-kidger (Owner):

So IIUC we're always going to be using the non-FSAL version of AbstractERK here?

If so then we get to sidestep all the painful solver_state difficulties that we've been debating back-and-forth, because this will just always be None?

In which case can we have an assert original_solver_init is None here, and likewise elide it throughout the rest of the logic below? (If need be we can also add a flag to AbstractRungeKutta to force non-FSAL-ness, and set that here.)

sammccallum:

Yes, this is correct. We are always using an AbstractERK with made_jump=True but we haven't explicitly turned off the fsal flag. Before we switched to made_jump we were turning off fsal by:

object.__setattr__(self.solver.tableau, "fsal", False)

But this modifies solver in place, which is not ideal. That is to say, original_solver_init is not None in the current implementation. It would be nice to set fsal=False so that we can pass the None around throughout.

patrick-kidger (Owner):

I think something like this should do the trick then:

def __init__(self, solver: AbstractERK):
    self.solver = eqx.tree_at(lambda s: s.disable_fsal, solver, True)

And then add this flag to AbstractRungeKutta here:

scan_kind: None | Literal["lax", "checkpointed", "bounded"] = None

Which takes effect here:

fsal = fsal and not vf_expensive
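So the flag would presumably thread through as something like this (hypothetical, adapting the line quoted above):

```python
# Sketch: a disable_fsal field on AbstractRungeKutta, consulted alongside
# the existing vf_expensive check when deciding whether to use FSAL.
fsal = fsal and not vf_expensive and not self.disable_fsal
```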
