@@ -792,3 +792,50 @@ def func(self, terms, t0, y0, args):
792
792
ValueError , match = r"Terms are not compatible with solver!"
793
793
):
794
794
diffrax .diffeqsolve (term , solver , 0.0 , 1.0 , 0.1 , y0 )
795
+
796
+
797
+ def test_vmap_backprop ():
798
+ def dynamics (t , y , args ):
799
+ param = args
800
+ return param - y
801
+
802
+ def event_fn (t , y , args , ** kwargs ):
803
+ return y - 1.5
804
+
805
+ def single_loss_fn (param ):
806
+ solver = diffrax .Euler ()
807
+ root_finder = diffrax .VeryChord (rtol = 1e-3 , atol = 1e-6 )
808
+ event = diffrax .Event (event_fn , root_finder )
809
+ term = diffrax .ODETerm (dynamics )
810
+ sol = diffrax .diffeqsolve (
811
+ term ,
812
+ solver = solver ,
813
+ t0 = 0.0 ,
814
+ t1 = 2.0 ,
815
+ dt0 = 0.1 ,
816
+ y0 = 0.0 ,
817
+ args = param ,
818
+ event = event ,
819
+ max_steps = 1000 ,
820
+ )
821
+ assert sol .ys is not None
822
+ final_y = sol .ys [- 1 ]
823
+ return param ** 2 + final_y ** 2
824
+
825
+ def batched_loss_fn (params : jnp .ndarray ) -> jnp .ndarray :
826
+ return jax .vmap (single_loss_fn )(params )
827
+
828
+ def grad_fn (params : jnp .ndarray ) -> jnp .ndarray :
829
+ return jax .grad (lambda p : jnp .sum (batched_loss_fn (p )))(params )
830
+
831
+ batch = jnp .array ([1.0 , 2.0 , 3.0 ])
832
+
833
+ try :
834
+ grad = grad_fn (batch )
835
+ except NotImplementedError as e :
836
+ pytest .fail (f"NotImplementedError was raised: { e } " )
837
+ except Exception as e :
838
+ pytest .fail (f"An unexpected exception was raised: { e } " )
839
+
840
+ assert not jnp .isnan (grad ).any (), "Gradient should not be NaN."
841
+ assert not jnp .isinf (grad ).any (), "Gradient should not be infinite."
0 commit comments