Skip to content

Commit 3a26ac3

Browse files
committed
restrict Reversible to AbstractERK and check result in adjoint
1 parent 9024d2f commit 3a26ac3

File tree

3 files changed

+11
-32
lines changed

3 files changed

+11
-32
lines changed

diffrax/_adjoint.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from ._heuristics import is_sde, is_unsafe_sde
1818
from ._saveat import save_y, SaveAt, SubSaveAt
19+
from ._solution import RESULTS
1920
from ._solver import (
2021
AbstractItoSolver,
2122
AbstractRungeKutta,
@@ -994,9 +995,10 @@ def _loop_reversible_bwd(
994995

995996
def grad_step(state):
996997
def forward_step(y0, solver_state, args, terms):
997-
y1, _, dense_info, new_solver_state, _ = solver.step(
998+
y1, _, dense_info, new_solver_state, result = solver.step(
998999
terms, t0, t1, y0, args, solver_state, False
9991000
)
1001+
assert result == RESULTS.successful
10001002
return y1, dense_info, new_solver_state
10011003

10021004
(
@@ -1016,6 +1018,7 @@ def forward_step(y0, solver_state, args, terms):
10161018
y0, dense_info, solver_state, result = solver.backward_step(
10171019
terms, t0, t1, y1, args, solver_state, False
10181020
)
1021+
assert result == RESULTS.successful
10191022

10201023
# Pull gradients back through interpolation
10211024

diffrax/_solver/reversible.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from .._solution import RESULTS, update_result
99
from .._solver.base import (
1010
AbstractReversibleSolver,
11-
AbstractSolver,
1211
AbstractWrappedSolver,
1312
)
1413
from .._term import AbstractTerm
14+
from .runge_kutta import AbstractERK
1515

1616

1717
ω = cast(Callable, ω)
@@ -26,7 +26,7 @@ class Reversible(
2626
"""
2727
Reversible solver method.
2828
29-
Allows any solver ([`diffrax.AbstractSolver`][]) to be made
29+
Allows any explicit Runge-Kutta solver ([`diffrax.AbstractERK`][]) to be made
3030
algebraically reversible.
3131
3232
**Arguments:**
@@ -77,7 +77,7 @@ class Reversible(
7777
```
7878
"""
7979

80-
solver: AbstractSolver
80+
solver: AbstractERK
8181
coupling_parameter: float = 0.999
8282

8383
@property
@@ -114,6 +114,10 @@ def init(
114114
y0: Y,
115115
args: Args,
116116
) -> _SolverState:
117+
if not isinstance(self.solver, AbstractERK):
118+
raise ValueError(
119+
"`Reversible` is only compatible with `AbstractERK` base solvers."
120+
)
117121
original_solver_init = self.solver.init(terms, t0, t1, y0, args)
118122
return (original_solver_init, y0)
119123

test/test_reversible.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -198,34 +198,6 @@ def test_reversible_explicit(stepsize_controller, saveat):
198198
)
199199

200200

201-
@pytest.mark.parametrize(
202-
"saveat",
203-
[
204-
diffrax.SaveAt(t0=True, t1=True),
205-
diffrax.SaveAt(t0=True, ts=jnp.linspace(0, 5, 10), t1=True),
206-
],
207-
)
208-
def test_reversible_implicit(saveat):
209-
n = 10
210-
y0 = jnp.linspace(1, 10, num=n)
211-
key = jr.PRNGKey(10)
212-
f = VectorField(n, n, n, depth=4, key=key)
213-
terms = diffrax.ODETerm(f)
214-
args = jnp.array([0.5])
215-
base_solver = diffrax.Kvaerno5()
216-
solver = diffrax.Reversible(base_solver)
217-
stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
218-
219-
_compare_grads(
220-
(y0, args, terms),
221-
solver,
222-
solver,
223-
saveat,
224-
stepsize_controller,
225-
dual_y0=False,
226-
)
227-
228-
229201
@pytest.mark.parametrize(
230202
"saveat",
231203
[

0 commit comments

Comments
 (0)