Skip to content

Do not re-verify already verified memoized value in cycle verification #851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ pub enum VerifyResult {
///
/// The inner value tracks whether the memo or any of its dependencies have an
/// accumulated value.
///
/// Don't mark memos verified until we've iterated the full cycle to ensure no inputs changed
/// when encountering this variant.
Unchanged(InputAccumulatedValues),
}

Expand All @@ -37,10 +34,6 @@ impl VerifyResult {
pub(crate) fn unchanged() -> Self {
Self::Unchanged(InputAccumulatedValues::Empty)
}

pub(crate) const fn is_unchanged(&self) -> bool {
matches!(self, Self::Unchanged(_))
}
}

impl<C> IngredientImpl<C>
Expand Down Expand Up @@ -146,11 +139,11 @@ where
// Check if the inputs are still valid. We can just compare `changed_at`.
let deep_verify =
self.deep_verify_memo(db, zalsa, old_memo, database_key_index, cycle_heads);
if deep_verify.is_unchanged() {
if let VerifyResult::Unchanged(accumulated_inputs) = deep_verify {
return Some(if old_memo.revisions.changed_at > revision {
VerifyResult::Changed
} else {
VerifyResult::Unchanged(old_memo.revisions.accumulated_inputs.load())
VerifyResult::Unchanged(accumulated_inputs)
});
}

Expand Down Expand Up @@ -316,18 +309,18 @@ where
memo = memo.tracing_debug()
);

if memo.revisions.cycle_heads.is_empty() {
let cycle_heads = &memo.revisions.cycle_heads;
if cycle_heads.is_empty() {
return true;
}

let cycle_heads = &memo.revisions.cycle_heads;

zalsa_local.with_query_stack(|stack| {
cycle_heads.iter().all(|cycle_head| {
stack.iter().rev().any(|query| {
query.database_key_index == cycle_head.database_key_index
&& query.iteration_count() == cycle_head.iteration_count
})
stack
.iter()
.rev()
.find(|query| query.database_key_index == cycle_head.database_key_index)
.is_some_and(|query| query.iteration_count() == cycle_head.iteration_count)
})
})
}
Expand Down Expand Up @@ -402,16 +395,18 @@ where
return VerifyResult::Changed;
}

let dyn_db = db.as_dyn_database();

let mut last_verified_at = old_memo.verified_at.load();
let mut first_iteration = true;
'cycle: loop {
let mut inputs = InputAccumulatedValues::Empty;
// Fully tracked inputs? Iterate over the inputs and check them, one by one.
//
// NB: It's important here that we are iterating the inputs in the order that
// they executed. It's possible that if the value of some input I0 is no longer
// valid, then some later input I1 might never have executed at all, so verifying
// it is still up to date is meaningless.
let last_verified_at = old_memo.verified_at.load();
let mut inputs = InputAccumulatedValues::Empty;
let dyn_db = db.as_dyn_database();
for &edge in edges.input_outputs.iter() {
match edge {
QueryEdge::Input(dependency_index) => {
Expand All @@ -421,9 +416,7 @@ where
last_verified_at,
cycle_heads,
) {
VerifyResult::Changed => {
break 'cycle VerifyResult::Changed;
}
VerifyResult::Changed => break 'cycle VerifyResult::Changed,
VerifyResult::Unchanged(input_accumulated) => {
inputs |= input_accumulated;
}
Expand Down Expand Up @@ -477,9 +470,17 @@ where
// from cycle heads. We will handle our own memo (and the rest of our cycle) on a
// future iteration; first the outer cycle head needs to verify itself.

let in_heads = cycle_heads.remove(&database_key_index);
let was_in_heads = cycle_heads.remove(&database_key_index);
let heads_non_empty = !cycle_heads.is_empty();
if heads_non_empty {
// case 2 / 4
break 'cycle VerifyResult::Unchanged(inputs);
} else if !first_iteration {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the elif !first_iteration a semantic change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is what causes the test changes you annotated as well below. I made that commit separately precisely to discuss (forgot to make a self review to start that discussion sorry). The change was made given the comment above declaring the 4 cases. Judging by case 3's description (the only one that actually continues the loop, we want to iterate it once more but we always re-add ourselves as the cylce head on this second iteration, meaning we will re-validate us once more (that is the cause for the duplicate DidValidateMemoizedValue events) in the second iteration and then continue again for a third time before finally bailing out.

Given the wording of the comment I interpreted this as that we want to merely run this a second time but no more. This is probably something that @carljm should have a look at.

I can split this PR into its two commits if the first one looks good as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pretty big speed up on our benchmarks https://codspeed.io/astral-sh/ruff/branches/micha%2Ffixpoint-changes

It would be nice if the code reflected the 1...4 ordering a little more. Especially if 1. could be first but this is a nit (which should also be the most likely branch taken at runtime)

// 3 (second loop turn)
break 'cycle VerifyResult::Unchanged(inputs);
} else {
last_verified_at = zalsa.current_revision();

if cycle_heads.is_empty() {
old_memo.mark_as_verified(zalsa, database_key_index);
old_memo.revisions.accumulated_inputs.store(inputs);

Expand All @@ -490,11 +491,15 @@ where
.store(true, Ordering::Relaxed);
}

if in_heads {
if was_in_heads {
first_iteration = false;
// case 3
continue 'cycle;
} else {
// case 1
break 'cycle VerifyResult::Unchanged(inputs);
}
}
break 'cycle VerifyResult::Unchanged(inputs);
}
}
}
Expand Down
17 changes: 6 additions & 11 deletions tests/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,6 @@ fn cycle_unchanged() {
[
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why we see behavior changes here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is removing a redundant verification (min_iterate(1) was already verified above), which is the point of the change as I understand it

]"#]]);

a.assert_value(&db, 45);
Expand Down Expand Up @@ -929,9 +928,7 @@ fn cycle_unchanged_nested() {
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(4)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
]"#]]);

a.assert_value(&db, 45);
Expand Down Expand Up @@ -992,14 +989,12 @@ fn cycle_unchanged_nested_intertwined() {
b.assert_value(&db, 60);

db.assert_logs(expect![[r#"
[
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
]"#]]);
[
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })",
"salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })",
]"#]]);

a.assert_value(&db, 45);
}
Expand Down
1 change: 0 additions & 1 deletion tests/cycle_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ fn revalidate_no_changes() {
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(402)) })",
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })",
"salsa_event(DidValidateMemoizedValue { database_key: query_a(Id(0)) })",
"salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })",
]"#]]);
}

Expand Down