
Commit 36a6b00

LuggiStruggi authored and patrick-kidger committed
Fix for making vmap over diffeqsolve possible (#578)
* _integrate.py
* Added new test checking gradient of vmapped diffeqsolve
* Import optimistix
* Fixed issue
* added .any()
* diffrax root finder
1 parent 0217f92 commit 36a6b00
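
For context, the user-facing pattern this commit enables looks roughly like the sketch below. It mirrors the new test_vmap_backprop test added further down and is not itself part of the commit: differentiating through a vmapped diffeqsolve that uses an event with a root finder.

    import diffrax
    import jax
    import jax.numpy as jnp

    def loss(param):
        # One ODE solve with an event; `param` is the per-example parameter.
        term = diffrax.ODETerm(lambda t, y, args: args - y)
        event = diffrax.Event(
            lambda t, y, args, **kwargs: y - 1.5,
            diffrax.VeryChord(rtol=1e-3, atol=1e-6),
        )
        sol = diffrax.diffeqsolve(
            term, diffrax.Euler(), t0=0.0, t1=2.0, dt0=0.1,
            y0=0.0, args=param, event=event, max_steps=1000,
        )
        return param**2 + sol.ys[-1] ** 2

    # Batch the solve with vmap, then take reverse-mode gradients.
    params = jnp.array([1.0, 2.0, 3.0])
    grads = jax.grad(lambda p: jnp.sum(jax.vmap(loss)(p)))(params)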

File tree

2 files changed: +49 −1 lines changed

diffrax/_integrate.py

Lines changed: 2 additions & 1 deletion
@@ -649,7 +649,8 @@ def body_fun(state):
     event_mask = final_state.event_mask
     flat_mask = jtu.tree_leaves(event_mask)
     assert all(jnp.shape(x) == () for x in flat_mask)
-    event_happened = jnp.any(jnp.stack(flat_mask))
+    float_mask = jnp.array(flat_mask).astype(jnp.float32)
+    event_happened = jnp.max(float_mask) > 0.0
 
     def _root_find():
         _interpolator = solver.interpolation_cls(
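
As a side note (an illustrative sketch, not part of the commit): for the scalar boolean leaves asserted just above, the replacement computes the same "did any event trigger?" answer as the old jnp.any(jnp.stack(...)) reduction, only via a float cast and jnp.max.

    import jax.numpy as jnp

    # Illustrative comparison of the old and new reductions over scalar boolean leaves.
    flat_mask = [jnp.array(False), jnp.array(True), jnp.array(False)]

    old_result = jnp.any(jnp.stack(flat_mask))                 # previous expression
    float_mask = jnp.array(flat_mask).astype(jnp.float32)      # new expression
    new_result = jnp.max(float_mask) > 0.0

    assert bool(old_result) == bool(new_result)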

test/test_integrate.py

Lines changed: 47 additions & 0 deletions
@@ -792,3 +792,50 @@ def func(self, terms, t0, y0, args):
         ValueError, match=r"Terms are not compatible with solver!"
     ):
         diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, y0)
+
+
+def test_vmap_backprop():
+    def dynamics(t, y, args):
+        param = args
+        return param - y
+
+    def event_fn(t, y, args, **kwargs):
+        return y - 1.5
+
+    def single_loss_fn(param):
+        solver = diffrax.Euler()
+        root_finder = diffrax.VeryChord(rtol=1e-3, atol=1e-6)
+        event = diffrax.Event(event_fn, root_finder)
+        term = diffrax.ODETerm(dynamics)
+        sol = diffrax.diffeqsolve(
+            term,
+            solver=solver,
+            t0=0.0,
+            t1=2.0,
+            dt0=0.1,
+            y0=0.0,
+            args=param,
+            event=event,
+            max_steps=1000,
+        )
+        assert sol.ys is not None
+        final_y = sol.ys[-1]
+        return param**2 + final_y**2
+
+    def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
+        return jax.vmap(single_loss_fn)(params)
+
+    def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
+        return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)
+
+    batch = jnp.array([1.0, 2.0, 3.0])
+
+    try:
+        grad = grad_fn(batch)
+    except NotImplementedError as e:
+        pytest.fail(f"NotImplementedError was raised: {e}")
+    except Exception as e:
+        pytest.fail(f"An unexpected exception was raised: {e}")
+
+    assert not jnp.isnan(grad).any(), "Gradient should not be NaN."
+    assert not jnp.isinf(grad).any(), "Gradient should not be infinite."
