Added steps=n logic for skipping steps #626
Thank you for the contribution! I like the look of this. I've left a first round of review comments. :)
Dear @patrick-kidger, I continued on the pull request and could fix most of the cases. However, I am now stuck on an issue where, even after a couple of hours, I could not find the problem. In short, it's about Can you maybe assist me here a bit? I think I am misunderstanding the logic of how states are stored and logged along the way. Your help would be much appreciated! 🙏
Hey, it's me again. I have now also added another test where I think the results should be the same, i.e. when using
Thanks for pushing this so far!
For `_save_t1`, then you're right, it gets a bit hairy! I think the idea for reasoning this through is as follows. We had some logic around whether to save the final value, which is a complicated combination of `.t1`, `.steps`, and the root-finding machinery. The important part of it, however, is when we check `subsaveat.steps`: at this point we were checking whether the value at `t1` is already going to be saved via `steps=True`. This is the logic we need to preserve: whether the value at `t1` is going to be saved anyway via `steps=True`.
In this PR, that corresponds to the boolean value `num_accepted_steps % subsaveat.steps == 0`. If this is `True` then we will have already saved the `t1` value via `.steps`. If it is `False` then we will not.
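To make the modulus check concrete, here is a minimal pure-Python sketch. It assumes that `steps=n` saves a value after every n-th accepted step and that `n >= 1`; the helper name `t1_saved_via_steps` is hypothetical and is not part of diffrax.

```python
def t1_saved_via_steps(num_accepted_steps: int, steps: int) -> bool:
    """Return True if the final accepted step coincides with a `steps=n` save.

    Hypothetical helper for illustration only. Assumes `steps >= 1` and that
    a value is saved after every `steps`-th accepted step.
    """
    return num_accepted_steps % steps == 0


# With steps=3, saves are assumed to happen after accepted steps 3, 6, 9, ...
already_saved = t1_saved_via_steps(9, 3)   # the final (9th) step is a save point
needs_t1_save = not t1_saved_via_steps(10, 3)  # the final (10th) step is not

print(already_saved, needs_t1_save)
```

In the second case the final value would only be recorded if `t1` saving is also requested, which is the case the review comment is concerned with.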
Armed with this intuition, let's see what that looks like in code. Now sadly the following will of course crash with a tracer error, but it highlights what I think the change in logic should be:
  def _save_t1(subsaveat, save_state):
      if event is None or event.root_finder is None:
-         if subsaveat.t1 and not subsaveat.steps:
+         if subsaveat.t1 and not (final_state.num_accepted_steps % subsaveat.steps == 0):
              # If subsaveat.steps then the final value is already saved.
              save_state = _save(
                  tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
              )
      else:
-         if subsaveat.t1 or subsaveat.steps:
+         if subsaveat.t1 or (final_state.num_accepted_steps % subsaveat.steps == 0):
              # In this branch we need to replace the last value with tfinal
              # and yfinal returned by the root finder also if subsaveat.steps
              # because we deleted the last value after the event time above.
              save_state = _save(
                  tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
              )
      return save_state
This hits a tracer error, as the value of the modulus might not be known at compile time, and we have to wait until runtime.
So at this point we then need to chase through some annoying JAX details, because our saving has now moved from 'definitely known at compile time' to 'maybe known at compile time but maybe only known at runtime'.
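As a small standalone illustration of that distinction (not taken from the PR), the sketch below shows a Python `if` on a traced value failing under `jax.jit`, while the same predicate works when the branch is deferred to runtime with `lax.cond`. The function names and the hard-coded modulus of 3 are assumptions made for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def python_branch(num_accepted_steps):
    # Under jit, `num_accepted_steps` is a tracer, so Python's `if` cannot
    # convert the comparison to a concrete bool at trace time.
    if num_accepted_steps % 3 == 0:
        return jnp.asarray(1.0)
    return jnp.asarray(0.0)


@jax.jit
def runtime_branch(num_accepted_steps):
    # The predicate stays a traced boolean; the branch is chosen at runtime.
    pred = num_accepted_steps % 3 == 0
    return lax.cond(pred, lambda: jnp.asarray(1.0), lambda: jnp.asarray(0.0))


raised = False
try:
    python_branch(jnp.asarray(6))
except jax.errors.ConcretizationTypeError:
    # Newer JAX versions raise a subclass of this error for bool conversion.
    raised = True

print("tracer error raised:", raised)
print(runtime_branch(jnp.asarray(6)), runtime_branch(jnp.asarray(7)))
```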
Untested, but I think that means rewriting the `_save_t1` function as follows:
def _save_t1(subsaveat, save_state):
    if subsaveat.steps == 0:
        # We're not saving the final value via `steps`, so we might need to save it via `t1`.
        t1_saved_via_steps = False
    elif subsaveat.steps == 1:
        # We're definitely saving the final value via `steps`, so we can skip saving it via `t1`.
        t1_saved_via_steps = True
    else:
        # We might be saving the final value via `steps`, so we might need to save it via `t1`.
        t1_saved_via_steps = final_state.num_accepted_steps % subsaveat.steps == 0
    if event is None or event.root_finder is None:
        if type(t1_saved_via_steps) is bool:
            t1_not_saved_via_steps = not t1_saved_via_steps
        else:
            t1_not_saved_via_steps = jnp.logical_not(t1_saved_via_steps)
        pred = subsaveat.t1 & t1_not_saved_via_steps
    else:
        # If we're using an event with a root finder, and are saving steps, then we need to write
        # the final value here because we deleted the last value after the event time above.
        pred = subsaveat.t1 | t1_saved_via_steps
    if pred is not False:
        save_state = _save(
            tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1, pred=pred
        )
    return save_state
along with the following small diff to `_save`:
  def _save(
      t: FloatScalarLike,
      y: PyTree[Array],
      args: PyTree,
      fn: Callable,
      save_state: SaveState,
      repeat: int,
      pred: bool,
  ) -> SaveState:
      ts = save_state.ts
      ys = save_state.ys
      save_index = save_state.save_index
      ts = lax.dynamic_update_slice_in_dim(
          ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0
      )
+     y_to_save = lax.cond(pred, lambda: fn(t, y, args), lambda: jtu.tree_map(lambda ys_: ys_[save_index], ys))
      ys = jtu.tree_map(
          lambda ys_, y_: lax.dynamic_update_slice_in_dim(
              ys_, jnp.broadcast_to(y_, (repeat, *y_.shape)), save_index, axis=0
          ),
          ys,
-         fn(t, y, args),
+         y_to_save,
      )
      save_index = save_index + jnp.where(pred, repeat, 0)
      return eqx.tree_at(
          lambda s: [s.ts, s.ys, s.save_index], save_state, [ts, ys, save_index]
      )
But please do check what I've written here looks correct, I haven't tested this at all! 😁
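The predicated-save idea can be tried in isolation with a stripped-down sketch. This is not diffrax's `_save`: the function `save_if` is hypothetical, it handles a single 1-D buffer rather than a PyTree, and it always writes exactly one entry. When `pred` is false it rewrites the slot with its current contents and leaves the index unchanged, so the next save reuses the same slot.

```python
import jax
import jax.numpy as jnp
from jax import lax


def save_if(ys, save_index, y, pred):
    # Hypothetical, simplified stand-in for a predicated save into a
    # preallocated buffer; not the diffrax implementation.
    current = lax.dynamic_index_in_dim(ys, save_index, axis=0, keepdims=False)
    # Pick the new value if `pred` holds, otherwise the value already stored.
    y_to_save = lax.cond(pred, lambda: y, lambda: current)
    ys = lax.dynamic_update_slice_in_dim(ys, y_to_save[None], save_index, axis=0)
    # Only advance the write position when something was actually saved.
    save_index = save_index + jnp.where(pred, 1, 0)
    return ys, save_index


save_if_jit = jax.jit(save_if)

ys = jnp.zeros(4)
idx = jnp.asarray(0)
ys, idx = save_if_jit(ys, idx, jnp.asarray(2.5), jnp.asarray(True))   # saved
ys, idx = save_if_jit(ys, idx, jnp.asarray(9.0), jnp.asarray(False))  # skipped
print(ys, idx)
```

Because both the buffer shape and the write are unconditional, the function traces cleanly under `jax.jit` even though `pred` is only known at runtime.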
Dear Patrick, thanks for the comments and explanations. This helped me a lot and I continued on the PR. Most of it worked and was good input; however, I have now hit two more barriers:
d96795f to 91c245a
Awesome! I've just taken a look at this and I think it's nearly there. On As I'm doing so, I've also squashed your commits and rebased the branch on latest Anyway,
Thanks for checking and finding this issue! Very nice! On the ci
I simply did not update that test as well! :) In particular if you take a look at the other Can you have a go at updating the test and seeing if things work for your use-case as you expect them to?
That's a good idea! I have now changed the final test on my side, and it passes as expected. I did not run the full test suite, but I hope this will work now. Thanks again for your time and effort.
See https://github.com/patrick-kidger/diffrax/issues/