Skip to content

Added steps=n logic for skipping steps #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 69 additions & 30 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,25 @@ def _save(
fn: Callable,
save_state: SaveState,
repeat: int,
pred=True,
) -> SaveState:
ts = save_state.ts
ys = save_state.ys
save_index = save_state.save_index

ts = lax.dynamic_update_slice_in_dim(
ts, jnp.broadcast_to(t, (repeat,)), save_index, axis=0
t_to_save = jnp.broadcast_to(static_select(pred, t, ts[save_index]), (repeat,))
ts = lax.dynamic_update_slice_in_dim(ts, t_to_save, save_index, axis=0)
y_to_save = lax.cond(
pred,
lambda: fn(t, y, args),
lambda: jtu.tree_map(lambda ys_: ys_[save_index], ys),
)
ys = jtu.tree_map(
lambda ys_, y_: lax.dynamic_update_slice_in_dim(
ys_, jnp.broadcast_to(y_, (repeat, *y_.shape)), save_index, axis=0
),
ys,
fn(t, y, args),
y_to_save,
)
save_index = save_index + repeat

Expand Down Expand Up @@ -478,18 +483,32 @@ def _body_fun(_save_state):
save_ts, saveat.subs, save_state, is_leaf=_is_subsaveat
)

def maybe_inplace(i, u, x):
return eqxi.buffer_at_set(x, i, u, pred=keep_step)

def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.steps:
ts = maybe_inplace(save_state.save_index, tprev, save_state.ts)
if subsaveat.steps != 0:
save_step = (num_accepted_steps % subsaveat.steps) == 0
should_save = keep_step & save_step

if subsaveat.steps == 1:
y_to_save = subsaveat.fn(tprev, y, args)
else:
struct = eqx.filter_eval_shape(subsaveat.fn, tprev, y, args)
y_to_save = lax.cond(
eqxi.unvmap_any(should_save),
lambda: subsaveat.fn(tprev, y, args),
lambda: jtu.tree_map(jnp.zeros_like, struct),
)

ts = eqxi.buffer_at_set(
save_state.ts, save_state.save_index, tprev, pred=should_save
)
ys = jtu.tree_map(
ft.partial(maybe_inplace, save_state.save_index),
subsaveat.fn(tprev, y, args),
lambda _y, _ys: eqxi.buffer_at_set(
_ys, save_state.save_index, _y, pred=should_save
),
y_to_save,
save_state.ys,
)
save_index = save_state.save_index + jnp.where(keep_step, 1, 0)
save_index = save_state.save_index + jnp.where(should_save, 1, 0)
save_state = eqx.tree_at(
lambda s: [s.ts, s.ys, s.save_index],
save_state,
Expand All @@ -500,11 +519,14 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
save_state = jtu.tree_map(
save_steps, saveat.subs, save_state, is_leaf=_is_subsaveat
)

if saveat.dense:
dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts)
dense_ts = eqxi.buffer_at_set(
dense_ts, dense_save_index + 1, tprev, pred=keep_step
)
dense_infos = jtu.tree_map(
ft.partial(maybe_inplace, dense_save_index),
lambda _i, _is: eqxi.buffer_at_set(
_is, dense_save_index, _i, pred=keep_step
),
dense_info,
dense_infos,
)
Expand Down Expand Up @@ -800,20 +822,33 @@ def _save_if_t0_equals_t1(subsaveat: SubSaveAt, save_state: SaveState) -> SaveSt
)

def _save_t1(subsaveat, save_state):
if subsaveat.steps == 0:
# We're not saving the final value via `steps`,
# so we might need to save it via `t1`.
t1_saved_via_steps = False
elif subsaveat.steps == 1:
# We're definitely saving the final value via `steps`,
# so we can skip saving it via `t1`.
t1_saved_via_steps = True
else:
# We might be saving the final value via `steps`,
# so we might need to save it via `t1`.
t1_saved_via_steps = final_state.num_accepted_steps % subsaveat.steps == 0
if event is None or event.root_finder is None:
if subsaveat.t1 and not subsaveat.steps:
# If subsaveat.steps then the final value is already saved.
save_state = _save(
tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
)
if type(t1_saved_via_steps) is bool:
t1_not_saved_via_steps = not t1_saved_via_steps
else:
t1_not_saved_via_steps = jnp.logical_not(t1_saved_via_steps)
pred = subsaveat.t1 & t1_not_saved_via_steps
else:
if subsaveat.t1 or subsaveat.steps:
# In this branch we need to replace the last value with tfinal
# and yfinal returned by the root finder also if subsaveat.steps
# because we deleted the last value after the event time above.
save_state = _save(
tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1
)
# If we're using an event with a root finder, and are saving steps,
# then we need to write the final value here because we deleted the
# last value after the event time above.
pred = subsaveat.t1 | t1_saved_via_steps
if pred is not False:
save_state = _save(
tfinal, yfinal, args, subsaveat.fn, save_state, repeat=1, pred=pred
)
return save_state

save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
Expand Down Expand Up @@ -1215,16 +1250,20 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
out_size += 1
if subsaveat.ts is not None:
out_size += len(subsaveat.ts)
if subsaveat.steps:
if subsaveat.steps != 0:
# We have no way of knowing how many steps we'll actually end up taking, and
# XLA doesn't support dynamic shapes. So we just have to allocate the
# maximum amount of steps we can possibly take.
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with saving at `steps=True`"
"`max_steps=None` is incompatible with saving at `steps=n`"
)
out_size += max_steps
if subsaveat.t1 and not subsaveat.steps:
out_size += max_steps // subsaveat.steps
if subsaveat.t1 and (
(max_steps is None)
or (subsaveat.steps == 0)
or (max_steps % subsaveat.steps != 0)
):
out_size += 1
saveat_ts_index = 0
save_index = 0
Expand Down
45 changes: 28 additions & 17 deletions diffrax/_saveat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ def save_y(t, y, args):
return y


def _convert_ts(
ts: None | Sequence[RealScalarLike] | Real[Array, " times"],
) -> Real[Array, " times"] | None:
if ts is None or len(ts) == 0:
return None
else:
return jnp.asarray(ts)


class SubSaveAt(eqx.Module):
"""Used for finer-grained control over what is saved. A PyTree of these should be
passed to `SaveAt(subs=...)`.
Expand All @@ -28,11 +19,29 @@ class SubSaveAt(eqx.Module):
relatively niche feature and most users will probably not need to use `SubSaveAt`.)
"""

t0: bool = False
t1: bool = False
ts: Real[Array, " times"] | None = eqx.field(default=None, converter=_convert_ts)
steps: bool = False
fn: Callable = save_y
t0: bool
t1: bool
ts: Real[Array, " times"] | None
steps: int
fn: Callable

def __init__(
self,
*,
t0: bool = False,
t1: bool = False,
ts: None | Sequence[RealScalarLike] | Real[Array, " times"] = None,
steps: bool | int = 0,
fn: Callable = save_y,
):
self.t0 = t0
self.t1 = t1
self.ts = jnp.asarray(ts) if ts is not None and len(ts) > 0 else None
if isinstance(steps, bool):
self.steps = 1 if steps else 0
else:
self.steps = steps
self.fn = fn

def __check_init__(self):
if not self.t0 and not self.t1 and self.ts is None and not self.steps:
Expand All @@ -44,7 +53,8 @@ def __check_init__(self):
- `t0`: If `True`, save the initial input `y0`.
- `t1`: If `True`, save the output at `t1`.
- `ts`: Some array of times at which to save the output.
- `steps`: If `True`, save the output at every step of the numerical solver.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
`0` means no saving.
- `fn`: A function `fn(t, y, args)` which specifies what to save into `sol.ys` when
using `t0`, `t1`, `ts` or `steps`. Defaults to `fn(t, y, args) -> y`, so that the
evolving solution is saved. This can be useful to save only statistics of your
Expand All @@ -71,7 +81,7 @@ def __init__(
t0: bool = False,
t1: bool = False,
ts: None | Sequence[RealScalarLike] | Real[Array, " times"] = None,
steps: bool = False,
steps: bool | int = False,
fn: Callable = save_y,
subs: PyTree[SubSaveAt] = None,
dense: bool = False,
Expand Down Expand Up @@ -100,7 +110,8 @@ def __init__(
- `t0`: If `True`, save the initial input `y0`.
- `t1`: If `True`, save the output at `t1`.
- `ts`: Some array of times at which to save the output.
- `steps`: If `True`, save the output at every step of the numerical solver.
- `steps`: If `n>0`, save the output at every `n`th step of the numerical solver.
`0` means no saving.
- `dense`: If `True`, save dense output, that can later be evaluated at any part of
the interval $[t_0, t_1]$ via `sol = diffeqsolve(...); sol.evaluate(...)`.

Expand Down
4 changes: 2 additions & 2 deletions test/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def cond_fn_2(t, y, args, **kwargs):


@pytest.mark.parametrize("steps", (1, 2, 3, 4, 5))
def test_event_save_steps(steps):
def test_event_save_all_steps(steps):
term = diffrax.ODETerm(lambda t, y, args: (1.0, 1.0))
solver = diffrax.Tsit5()
t0 = 0
Expand Down Expand Up @@ -604,7 +604,7 @@ def run(saveat):
num_steps = [steps, steps, steps + 1, steps]
yevents = [(thr, 0), (thr, 0), (thr, 0), (thr, thr)]

for saveat, n, yevent in zip(saveats, num_steps, yevents):
for saveat, n, yevent in zip(saveats, num_steps, yevents, strict=True):
ts, ys = run(saveat)
xs, zs = ys
xevent, zevent = yevent
Expand Down
1 change: 1 addition & 0 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def get_dt_and_controller(level):
diffrax.SaveAt(t1=True),
diffrax.SaveAt(ts=[3.5, 0.7]),
diffrax.SaveAt(steps=True),
diffrax.SaveAt(steps=2),
diffrax.SaveAt(dense=True),
),
)
Expand Down
Loading