Skip to content

Added steps=n logic for skipping steps #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

alessandrofasse
Copy link

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution! I like the look of this. I've left a first round of review comments. :)

@alessandrofasse
Copy link
Author

alessandrofasse commented May 14, 2025

Dear @patrick-kidger ,

I continued on the pull request and I could fix most of the cases. However, I am now stuck in an issue were are a couple of hours I could not find the problem.

Im shorts its about _save_t1 function. I am not able to get this right and now several tests were failing. I added the following tests, as you suggested, and I think my assert would be correct (can you cross check?)

Can you maybe assist me here a bit, I think I am misunderstanding about the logic on how states are stored and logged anlong the way.

You help would be much appreciated! 🙏

@alessandrofasse
Copy link
Author

Hey, its me again. I also added now another test where I think that the results should be the same, i.e. when using diffrax.SaveAt(ts=ts[::n]) should be the same to diffrax.SaveAt(steps=n) when running a diffrax.StepTo(ts=ts). Is that correct or a mistake in my logic?

@alessandrofasse alessandrofasse marked this pull request as draft May 21, 2025 07:33
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pushing this so far!

For _save_t1, then you're right, it gets a bit hairy! I think the idea for reasoning this through is as follows. We had some logic around whether to save the final value, which is a complicated combination of .t1, .steps, and the root-finding machinery. The important part of it, however, is when we check subsaveat.steps: at this point we were checking whether the value at t1 is already going to be saved via steps=True. This is the logic we need to preserve: whether the value at t1 is going to be saved anyway via steps=True.

In this PR, that corresponds to the boolean value num_accepted_steps % subsaveat.steps == 0. If this is True then we will have already saved the t1 value via .steps. If it is False then we will not.


Armed with this intuition, let's see what that looks like in code. Now sadly the following will of course crash with a tracer error, but it highlights what I think the change in logic should be:

 def _save_t1(subsaveat, save_state):
     if event is None or event.root_finder is None:
-        if subsaveat.t1 and not subsaveat.steps:
+        if subsaveat.t1 and not (final_state.num_accepted_steps % subsaveat.steps == 0):
             # If subsaveat.steps then the final value is already saved.
             save_state = _save(
                 tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
             )
     else:
-        if subsaveat.t1 or subsaveat.steps:
+        if subsaveat.t1 or (final_state.num_accepted_steps % subsaveat.steps == 0):
             # In this branch we need to replace the last value with tfinal
             # and yfinal returned by the root finder also if subsaveat.steps
             # because we deleted the last value after the event time above.
             save_state = _save(
                 tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
             )
     return save_state

This hits a tracer error, as the value of the modulus might not be known at compile time, and we have to wait until runtime.

So at this point we then need to chase through some annoying JAX details, because our saving has now moved from 'definitely known at compile time' to 'maybe known at compile time but maybe only known at runtime'.

Untested, but I think that means rewriting the _save_t1 function as follows:

def _save_t1(subsaveat, save_state):
    if subsaveat.steps == 0:
        # We're not saving the final value via `steps`, so we might need to save it via `t1`.
        t1_saved_via_steps = False
    elif subsaveat.steps == 1:
        # We're definitely saving the final value via `steps`, so we can skip saving it via `t1`.
        t1_saved_via_steps = True
    else:
        # We might be saving the final value via `steps`, so we might need to save it via `t1`.
        t1_saved_via_steps = final_state.num_accepted_steps % subsaveat.steps == 0
    if event is None or event.root_finder is None:
        if type(t1_saved_via_steps) is bool:
            t1_not_saved_via_steps = not t1_saved_via_steps
        else:
            t1_not_saved_via_steps = jnp.logical_not(t1_saved_via_steps)
        pred = subsaveat.t1 & t1_not_saved_via_steps
    else:
        # If we're using an event with a root finder, and are saving steps, then we need to write
        # the final value here because we deleted the last value after the event time above.
        pred = subsaveat.t1 | t1_saved_via_steps
    if pred is not False:
        save_state = _save(
            tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1, pred=pred
        )
    return save_state

along with the following small diff to _save:

 def _save(
     t: FloatScalarLike,
     y: PyTree[Array],
     args: PyTree,
     fn: Callable,
     save_state: SaveState,
     repeat: int,
     pred: bool,
 ) -> SaveState:
     ts = save_state.ts
     ys = save_state.ys
     save_index = save_state.save_index
 
     ts = lax.dynamic_update_slice_in_dim(
         ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0
     )
+    y_to_save = lax.cond(pred, lambda: fn(t, y, args), lambda: jtu.tree_map(lambda ys_: ys_[save_index], ys))
     ys = jtu.tree_map(
         lambda ys_, y_: lax.dynamic_update_slice_in_dim(
             ys_, jnp.broadcast_to(y_, (repeat, *y_.shape)), save_index, axis=0
         ),
         ys,
-        fn(t, y, args),
+        y_to_save,
     )
     save_index = save_index + jnp.where(pred, repeat, 0)
 
     return eqx.tree_at(
         lambda s: [s.ts, s.ys, s.save_index], save_state, [ts, ys, save_index]
     )

But please do check what I've written here looks correct, I haven't tested this at all! 😁

@alessandrofasse
Copy link
Author

Dear Patrick,

thanks for comments and explanations. This helped me a lot and I continued on the PR. Most of it worked and is good input, however, I again hit now two barriers:

  1. test_saveat_solution_skip_steps here I get different sizes than the ones I expect (in the test). I am 80% sure that my results are correct but can you please confirm that my test should return what I assert?
  2. I changed back the above save_fn to just return subsaveat.fn(tprev, y, args), the problem is that I got errors with the lax.cond returning different shapes. I was also wondering, how do we know the shape of the output without evaluating it? Should we do somewhere a dummy call to extract the shape?

@patrick-kidger patrick-kidger force-pushed the 625-feature_skip_step_saving branch from d96795f to 91c245a Compare June 7, 2025 12:52
@patrick-kidger
Copy link
Owner

patrick-kidger commented Jun 7, 2025

Awesome! I've just taken a look at this and I think it's nearly there. On test_saveat_solution_skip_steps, I identified a few bugfixes: during the loop we were still using maybe_inplace which hardcodes the condition as keep_step rather than keep_step & save_step, and at t1 we were unconditionally saving t but not not y. So I've just pushed a commit that tackles these.

As I'm doing so, I've also squashed your commits and rebased the branch on latest main.

Anyway, test_saveat_solution_skip_steps now passes! Tell me what you think about these changes? I've not yet ran the full test suite, we'll see what the CI says.

@alessandrofasse
Copy link
Author

Thanks for checking and finding this issue! Very nice!

On the ci test_saveat_solution_skip_vs_saveat still fails. Did you push the lastest version that passed on your machine?

@patrick-kidger
Copy link
Owner

I simply did not update that test as well! :) In particular if you take a look at the other test_saveat_solution_skip_steps test that I did update, then you can see that the logic I've selected results in picking different time points (due to an off-by-one change) than we had before. (This was what I found cleanest to implement.)

Can you take a try at updating the test and seeing if things work for your use-case as you expect them to?

@alessandrofasse alessandrofasse marked this pull request as ready for review June 10, 2025 08:52
@alessandrofasse
Copy link
Author

Thats a good idea! I changed now the final test from my side that now passes as expected. I did not run the full test suite, but I hope this will work now. Thanks again for your time and effort

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants