diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py
index cacc1070..5f6d05d5 100644
--- a/diffrax/_integrate.py
+++ b/diffrax/_integrate.py
@@ -649,7 +649,8 @@ def body_fun(state):
         event_mask = final_state.event_mask
         flat_mask = jtu.tree_leaves(event_mask)
         assert all(jnp.shape(x) == () for x in flat_mask)
-        event_happened = jnp.any(jnp.stack(flat_mask))
+        float_mask = jnp.array(flat_mask).astype(jnp.float32)
+        event_happened = jnp.max(float_mask) > 0.0
 
         def _root_find():
             _interpolator = solver.interpolation_cls(
diff --git a/test/test_integrate.py b/test/test_integrate.py
index 555d6ade..af7e1b8d 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -792,3 +792,50 @@ def func(self, terms, t0, y0, args):
         ValueError, match=r"Terms are not compatible with solver!"
     ):
         diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, y0)
+
+
+def test_vmap_backprop():
+    def dynamics(t, y, args):
+        param = args
+        return param - y
+
+    def event_fn(t, y, args, **kwargs):
+        return y - 1.5
+
+    def single_loss_fn(param):
+        solver = diffrax.Euler()
+        root_finder = diffrax.VeryChord(rtol=1e-3, atol=1e-6)
+        event = diffrax.Event(event_fn, root_finder)
+        term = diffrax.ODETerm(dynamics)
+        sol = diffrax.diffeqsolve(
+            term,
+            solver=solver,
+            t0=0.0,
+            t1=2.0,
+            dt0=0.1,
+            y0=0.0,
+            args=param,
+            event=event,
+            max_steps=1000,
+        )
+        assert sol.ys is not None
+        final_y = sol.ys[-1]
+        return param**2 + final_y**2
+
+    def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
+        return jax.vmap(single_loss_fn)(params)
+
+    def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
+        return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)
+
+    batch = jnp.array([1.0, 2.0, 3.0])
+
+    try:
+        grad = grad_fn(batch)
+    except NotImplementedError as e:
+        pytest.fail(f"NotImplementedError was raised: {e}")
+    except Exception as e:
+        pytest.fail(f"An unexpected exception was raised: {e}")
+
+    assert not jnp.isnan(grad).any(), "Gradient should not be NaN."
+    assert not jnp.isinf(grad).any(), "Gradient should not be infinite."
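
Note (not part of the patch): a minimal sketch of the behavioural equivalence between the old and new event-flag reductions, assuming scalar boolean leaves as asserted in `_integrate.py` above. The float-based reduction is the formulation that the new `test_vmap_backprop` then exercises under `jax.vmap` of `jax.grad`.

    import jax.numpy as jnp

    # Scalar boolean event-mask leaves, as asserted in the hunk above.
    flat_mask = [jnp.array(False), jnp.array(True), jnp.array(False)]

    # Old formulation: boolean reduction over the stacked mask.
    old = jnp.any(jnp.stack(flat_mask))
    # New formulation: cast to float32 and reduce with max.
    float_mask = jnp.array(flat_mask).astype(jnp.float32)
    new = jnp.max(float_mask) > 0.0

    # Both formulations report the same "did any event fire" flag.
    assert bool(old) == bool(new)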