AbstractReversibleSolver + ReversibleAdjoint #603
Conversation
* _integrate.py
* Added new test checking gradient of vmapped diffeqsolve
* Import optimistix
* Fixed issue
* added .any()
* diffrax root finder
In python-poetry, `~=3.9` is interpreted as `>=3.9,<3.10`, though it should be `>=3.9,<4.0` (https://python-poetry.org/docs/dependency-specification/).
merge changes from AbstractReversibleSolver
I've also added the …
Okay, gosh, this one took far too long for me to get around to. Thank you for your patience! If I can, I'd like this to be the next big thing I focus on getting into Diffrax.
diffrax/__init__.py
Outdated
@@ -105,6 +106,7 @@
     MultiButcherTableau as MultiButcherTableau,
     QUICSORT as QUICSORT,
     Ralston as Ralston,
+    Reversible as Reversible,
Let's call this something more specific to help distinguish what it is! I think here it's not clear that it's a solver, and even if it was `ReversibleSolver` then that still wouldn't disambiguate amongst the various kinds of reversibility it's possible to cook up.
What do you call this yourself / in your paper? Its structure is Hamiltonian/coupled/etc so maybe a word of that sort is suitable.
The call will look like this: `solver = diffrax.Reversible(diffrax.Tsit5())`, with the idea being that you are "reversifying" Tsit5. So it isn't a solver in itself, but a wrapper. The boring name could be something like `ReversibleWrapper` and the fun name could be something like `Reversify`. Thoughts?
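For context, a minimal sketch of how that wrapper would compose with `diffeqsolve`, using the names in this thread (`ReversibleAdjoint` is the adjoint from this PR; the wrapper name was still under discussion, so treat it as provisional):

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # simple linear ODE
solver = diffrax.Reversible(diffrax.Tsit5())  # "reversify" an explicit RK solver
sol = diffrax.diffeqsolve(
    term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0),
    adjoint=diffrax.ReversibleAdjoint(),  # reversible backpropagation
)
```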
Sure, but in the future we may cook up some other way of reversifying a solver! We should pick a name for this one that leaves that possibility open.
James has gone for U-Reversible (after the Uno reverse card :). The analogy is that we take a step forward from `z0` to `y1`, then reverse and pull back from `y1` onto `z1`.
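Piecing this together with the diff further down the thread, the scheme looks roughly like the following (my notation, not from the PR: $\lambda$ is `coupling_parameter` and $\Phi$ is a step of the wrapped base solver; the $z_1$ update in particular is my reconstruction):

$$y_1 = \lambda\,(y_0 - z_0) + \Phi_{t_0 \to t_1}(z_0), \qquad z_1 = z_0 + y_1 - \Phi_{t_1 \to t_0}(y_1).$$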
    reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)
A very minor bug here: if it just so happens that we run with `t0 == t1` then we'll end up with `reversible_ts = [t0 inf inf inf ...]`, which will not produce desired results in the backward solve.
We have a special branch to handle the saving in the `t0 == t1` case; we should add a line handling the `state.reversible_ts is not None` case there.
diffrax/_adjoint.py
Outdated
# Pull solver_state gradients back onto y0, args, terms.

_, init_vjp = eqx.filter_vjp(solver.init, terms, ts[0], ts[1], y0, args)
It's not super clear to me that `ts[0]`, `ts[1]` are the correct values here. It looks like the saving routine is storing `tprev`, which is not necessarily the same as `state.tnext`, and the latter is what `solver.init` was originally called with.
In principle the step size controller could return anything at all; in practice it is possible for `tprev` to be 2 ULPs greater than `state.tnext` when passing to the other side of a discontinuity.
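To make the "2 ULPs" concrete, here is a tiny illustration (not Diffrax's actual code) of nudging a time value two ULPs across a discontinuity:

```python
import jax.numpy as jnp

t = jnp.float32(1.0)
# Two nextafter calls move t two ULPs towards +inf:
t_plus = jnp.nextafter(jnp.nextafter(t, jnp.inf), jnp.inf)
print(t_plus - t)  # ~2.4e-07: two float32 ULPs at 1.0
```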
The `ts` here are `reversible_ts`, which follows the same logic as `SaveAt(steps=True)`.
Is it not the case that the `state.tnext` identified in `diffeqsolve` (and used for `solver.init`) has to be the first step that the solver took? I appreciate that they can be different at later points in the solve, but my understanding was that the first step was set in `diffeqsolve`?
So I think in this case you're saving the `tprev` of the second step, not the `tnext` of the first.
Yep, you're right.
We now return the `tprev` and `tnext` passed to `solver.init` as a residual in the reversible loop and use these to get the `vjp`.
diffrax/_integrate.py
Outdated
# Reversible info
if max_steps is None:
    reversible_ts = None
    reversible_save_index = None
else:
    reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
    reversible_save_index = 0
I've thought of an alternative for this extra buffer, btw: `ReversibleAdjoint.loop` could intercept `saveat` and add a `SubSaveAt(steps=True, save_fn=lambda t, y, args: None)` to record the extra times, then peel it off again when returning the final state.
I think that (a) might be doable without making any changes to `_integrate.py`, (b) would allow for also supporting `SaveAt(steps=True)` (as in that case we can just skip adding the extra `SubSaveAt`), and (c) would avoid a few of the subtle issues I've commented on above about exactly which tprev/tnext-like value is actually being saved, because you can trust the rest of the existing `diffeqsolve` to do that for you.
It's not a strong suggestion though.
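A sketch of what that interception could look like, assuming `SaveAt.subs` holds the `SubSaveAt` pytree and that the save-function keyword is `fn` (both assumptions on my part):

```python
import diffrax
import equinox as eqx

def add_steps_sub(saveat: diffrax.SaveAt) -> diffrax.SaveAt:
    # Extra SubSaveAt that records step times while saving no values.
    extra = diffrax.SubSaveAt(steps=True, fn=lambda t, y, args: None)
    # Append it alongside the user's existing sub-saves; ReversibleAdjoint.loop
    # would peel this entry off again before returning the final state.
    return eqx.tree_at(lambda s: s.subs, saveat, (saveat.subs, extra))
```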
Yeah, this was the original idea I tried but I couldn't get around a leaked tracer error! I'm willing to give it another go if you start feeling strongly about it though ;)
Maybe let's nail everything else down and then consider this. Reflecting on this, I do suspect it will make the code much easier to maintain in the long run.
diffrax/_solver/reversible.py
Outdated
y1 = (self.coupling_parameter * (ω(y0) - ω(z0)) + ω(step_z0)).ω

step_y1, y_error, _, _, result2 = self.solver.step(
    terms, t1, t0, y1, args, original_solver_state, True
I've just spotted this has t0 and t1 back-to-front, which I think may in general mess with our solving logic as described previously. Is this intended / is it possible not to?
Yeah, this is intended and is essential to the solver!
0cfd4ec to 3a26ac3
Not sure if you were ready for a review on this yet, but I took a look over anyway 😁 We're making really good progress! In particular, now that we're settled on just `AbstractERK`, I think all our complicated state-reconstruction concerns go away, so the chance of footgunning ourselves has gone way down 😁
# (i.e. the state used on the forward). Otherwise, in `ReversibleAdjoint`,
# we would take a local forward step from an incorrect `solver_state`.
solver_state = jax.lax.cond(
    tm1 > 0, lambda _: (tm1, ym1, dt), lambda _: (t0, y0, dt), None
I think this predicate is assuming we're solving over a time interval of the form `[0, T]`? But in practice the left endpoint might be nonzero.
This aside, it's (a) possible to use `jnp.where`, which is lower-overhead for small stuff like this, but also (b) I think `lax.cond(pred, lambda: foo, lambda: bar)` should work, without arguments.
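A sketch of both suggestions, with a hypothetical `t_left` standing in for the solve's actual left endpoint (the other variable names mirror the diff above; the values are arbitrary stand-ins):

```python
import jax
import jax.numpy as jnp

# Stand-ins for the variables in the diff above:
tm1, ym1, dt, t0, y0 = 0.5, 2.0, 0.1, 0.0, 1.0
t_left = 0.0  # hypothetical: the solve's actual left endpoint, not a hard-coded 0

pred = tm1 > t_left
# (a) jnp.where, mapped over the tuple-valued state:
solver_state = jax.tree_util.tree_map(
    lambda a, b: jnp.where(pred, a, b), (tm1, ym1, dt), (t0, y0, dt)
)
# (b) lax.cond with argument-free branches:
solver_state = jax.lax.cond(pred, lambda: (tm1, ym1, dt), lambda: (t0, y0, dt))
```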
# We pre-compute the step size to avoid numerical instability during the
# backward_step. This is okay (albeit slightly ugly) as `LeapfrogMidpoint` can't
# be used with adaptive step sizes.
dt = t1 - t0
I think this will mean we silently get the wrong results if used with `diffrax.StepTo` with non-constant step size.
I think we might be able to save this by adjusting the backward step logic, so that it computes `y0` in the step that returns it, rather than in the step before:

```python
def backward_step(...):
    t2, y2 = solver_state
    control = terms.contr(t0, t2)
    y0 = (y2**ω - terms.vf_prod(t0, y1, args, control) ** ω).ω
    solver_state = (t1, y1)
    return y0, ...
```

Note that this is also essentially the same idea as what we currently do in the forward step. I think this might require some detail around how to handle the first and last step, still.
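For reference, `LeapfrogMidpoint`'s update (per the Diffrax docs) is

$$y_{n+1} = y_{n-1} + 2h\,f(t_n, y_n),$$

so the reconstruction above is just the exact algebraic inverse, $y_{n-1} = y_{n+1} - 2h\,f(t_n, y_n)$, written with `terms.vf_prod` over the control $(t_0, t_2)$ so that non-constant steps are handled too.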
That's much nicer! This should work, but as you say, requires special handling of the first and last step. I've had a think about how to do this (and the point above about assuming `[0, T]`) and I believe it's not possible without knowing where the start and end point of the solve is (or at least when we're at the first/last step)!
This feels like it's questioning the attempt to use a single-step API to represent multi-step solvers - hinting at this line:
`# TODO: support arbitrary linear multistep methods`
I have a few ideas for how to expand the leapfrog midpoint API to support this (e.g. have `first_step`, `last_step` flags or hold `t0`, `t1` as attributes) but I'm not sure these are elegant enough for diffrax lol... wdyt?
    ```
    """

solver: AbstractERK
Now that we're special-casing to just this, then I am much more comfortable with the logic below!
diffrax/_solver/reversible.py
Outdated
raise ValueError(
    "`UReversible` is only compatible with `AbstractERK` base solvers."
)
original_solver_init = self.solver.init(terms, t0, t1, y0, args)
So IIUC we're always going to be using the non-FSAL version of `AbstractERK` here?
If so then we get to sidestep all the painful `solver_state` difficulties that we've been debating back-and-forth, because this will just always be `None`?
In which case can we have an `assert original_solver_init is None` here, and likewise elide it throughout the rest of the logic below? (If need be we can also add a flag to `AbstractRungeKutta` to force non-FSAL-ness, and set that here.)
Yes, this is correct. We are always using an `AbstractERK` with `made_jump=True`, but we haven't explicitly turned off the `fsal` flag. Before we switched to `made_jump` we were turning off `fsal` by:
`object.__setattr__(self.solver.tableau, "fsal", False)`
But this modifies `solver` in place, which is not ideal. This is to say that `original_solver_init` is not `None` in the current implementation. It would be nice to set `fsal=False` so that we can pass around the `None` throughout.
I think something like this should do the trick then:

```python
def __init__(self, solver: AbstractERK):
    self.solver = eqx.tree_at(lambda s: s.disable_fsal, solver, True)
```

And then add this flag to `AbstractRungeKutta` here (`diffrax/_solver/runge_kutta.py`, line 356 at 482de90):

`scan_kind: None | Literal["lax", "checkpointed", "bounded"] = None`

Which takes effect here (`diffrax/_solver/runge_kutta.py`, line 398 at 482de90):

`fsal = fsal and not vf_expensive`
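Putting those two pieces together, a sketch of how the flag might sit on `AbstractRungeKutta` (`disable_fsal` is the name from the suggestion above; its placement and default are my assumptions):

```python
from typing import Literal
import equinox as eqx

class AbstractRungeKutta(eqx.Module):  # sketch: only the fields relevant here
    scan_kind: None | Literal["lax", "checkpointed", "bounded"] = None
    disable_fsal: bool = False  # hypothetical new flag

    def _decide_fsal(self, fsal: bool, vf_expensive: bool) -> bool:
        # cf. the "takes effect here" line referenced above
        return fsal and not vf_expensive and not self.disable_fsal
```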
Re-opening #593.

Implements `AbstractReversibleSolver` base class and `ReversibleAdjoint` for reversible backpropagation. This updates `SemiImplicitEuler`, `LeapfrogMidpoint` and `ReversibleHeun` to subclass `AbstractReversibleSolver`.

Implementation

`AbstractReversibleSolver` subclasses `AbstractSolver` and adds a `backward_step` method. This method should reconstruct `y0`, `solver_state` at `t0` from `y1`, `solver_state` at `t1`. See the aforementioned solvers for examples.

When backpropagating, `ReversibleAdjoint` uses this `backward_step` to reconstruct state. We then take a `vjp` through a local forward step and accumulate gradients. `ReversibleAdjoint` now also pulls back gradients from any interpolated values, so we can use `SaveAt(ts=...)`!

We allow arbitrary `solver_state` (provided it can be reconstructed reversibly) and calculate gradients w.r.t. `solver_state`. Finally, we pull back these gradients onto `y0`, `args`, `terms` using the `solver.init` method.
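A minimal end-to-end sketch of the API described above (solver and adjoint names are from this PR; the concrete ODE, tolerances and step size are illustrative assumptions):

```python
import diffrax
import jax
import jax.numpy as jnp

def loss(y0):
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.ReversibleHeun(),  # an AbstractReversibleSolver per this PR
        t0=0.0, t1=1.0, dt0=0.01, y0=y0,
        saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 5)),  # interpolated saves
        adjoint=diffrax.ReversibleAdjoint(),  # reversible backpropagation
    )
    return jnp.sum(sol.ys**2)

grads = jax.grad(loss)(jnp.array(1.0))  # gradients via the reversible adjoint
```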