-
-
Notifications
You must be signed in to change notification settings - Fork 152
Introduction of bidirectional vs. unidirectional triggering of events #643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very minor comments only, this LGTM!
diffrax/_event.py
Outdated
if isinstance(bidirect, bool): | ||
bidir_list = [bidirect] * n | ||
else: | ||
bidir_list = list(bidirect) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the logic here does not reflect the pytree structure of cond_fn
. It seems to operate on the flattened representation.
Can we add a test for this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a look and will add a test.
diffrax/_event.py
Outdated
vals, treedef = flatten(cond_fn) | ||
n = len(vals) | ||
|
||
if isinstance(bidirect, bool): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm inclined not to support this case and require exactly matching pytrees. I expect it might be a bit unusual to have multiple events that all have the same crossing requirements, and this seems like an easy footgun. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel inclined to keep it since at least in my case (if not using a single cond_fn for a whole vector as discussed in #637 (comment)), all cond_fns have the same crossing direction. I feel like it could be helpful but you definitely have a bigger overview for what people might use this library and therefore what might be confusing vs helpful. I don't have really strong feelings here.
diffrax/_event.py
Outdated
self, | ||
cond_fn, | ||
root_finder: Optional[optx.AbstractRootFinder] = None, | ||
bidirect: Union[bool, PyTree[bool]] = True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps this should be a PyTree[None | bool]
: handling the cases of upcrossing-only, downcrossing-only, or either crossing. WDYT? (No strong feelings from me, since one crossing can easily be constructed from the other.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mhh in my opinion it is easier to understand if its only unidirectional vs. bidirectional and then just negating the cond_fn if we want the opposite type of crossing. Just because i call the variable bidirect and for me that instinctively sounds like a flag variable. I wouldn't without reading the docs guess that None would stand for bidirectional, while True/False stands for the direction of the crossing. So That said maybe we call the variable down_crossing to emphasize the default type so the user knows to change the cond_fn accordingly or call it triggering_direction (or trigger_dir) and then use the PyTree[None | bool]. Actually writing it out I think my main problem was with the name. Lets do The PyTree[None | bool] then and let me adjust the name accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mhh just bc it seems rather annoying comparing PyTree structures that include some None values (for example PyTreeDef([, None, [, ]]) and PyTreeDef([, , [, *]]) should for our case be the same but aren't) I will do the simple bool case instead and name it smth like downcrossing. Edit: Nevermind i can just use is_leaf
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay I think this LGTM! (Apart from two little nits.)
Can you rebase on the latest dev
branch, and we'll get this in!
diffrax/_event.py
Outdated
treedef_trig = treedef_cond | ||
else: | ||
vals_trig, treedef_trig = flatten(trig_dir, is_leaf=lambda x: x is None) | ||
print(vals_trig, treedef_trig) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is spurious :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh sorry ^^! will fix
diffrax/_event.py
Outdated
): | ||
vals_cond, treedef_cond = flatten(cond_fn) | ||
|
||
if isinstance(trig_dir, bool): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this also include the case of trig_dir = None
?
I think i can close this, since it has been merged through #655. |
Changed Event to support flags for unidirectional vs. bidirectional triggering.