Skip to content

Commit e1fe369

Browse files
authored
Fix hang in nested fixpoint iteration (#871)
1 parent db4c4df commit e1fe369

File tree

3 files changed

+119
-6
lines changed

3 files changed

+119
-6
lines changed

src/function/memo.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,11 @@ impl<V> Memo<V> {
150150
cycle_heads: &CycleHeads,
151151
) -> bool {
152152
let mut retry = false;
153+
let mut hit_cycle = false;
153154

154155
for head in cycle_heads {
155156
let head_index = head.database_key_index;
156157

157-
if head_index == database_key_index {
158-
continue;
159-
}
160-
161158
let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index());
162159
let cycle_head_kind = ingredient.cycle_head_kind(zalsa, head_index.key_index());
163160
if matches!(
@@ -167,7 +164,9 @@ impl<V> Memo<V> {
167164
// This cycle is already finalized, so we don't need to wait on it;
168165
// keep looping through cycle heads.
169166
retry = true;
167+
tracing::trace!("Dependent cycle head {head_index:?} has been finalized.");
170168
} else if ingredient.wait_for(zalsa, head_index.key_index()) {
169+
tracing::trace!("Dependent cycle head {head_index:?} has been released (there's a new memo)");
171170
// There's a new memo available for the cycle head; fetch our own
172171
// updated memo and see if it's still provisional or if the cycle
173172
// has resolved.
@@ -176,7 +175,10 @@ impl<V> Memo<V> {
176175
// We hit a cycle blocking on the cycle head; this means it's in
177176
// our own active query stack and we are responsible to resolve the
178177
// cycle, so go ahead and return the provisional memo.
179-
return false;
178+
tracing::debug!(
179+
"Waiting for {head_index:?} results in a cycle, return {database_key_index:?} once all other cycle heads completed to allow the outer cycle to make progress."
180+
);
181+
hit_cycle = true;
180182
}
181183
}
182184

@@ -186,7 +188,14 @@ impl<V> Memo<V> {
186188
// the cycle head (either initial value, or from a later iteration) and should be
187189
// returned to caller to allow fixpoint iteration to proceed. (All cases in the loop
188190
// above either set `retry` or, when waiting on the head hits a cycle, set `hit_cycle`.)
189-
retry
191+
if hit_cycle {
192+
false
193+
} else if retry {
194+
tracing::debug!("Retrying {database_key_index:?}");
195+
true
196+
} else {
197+
false
198+
}
190199
}
191200
}
192201

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
//! Test a specific cycle scenario:
2+
//!
3+
//! 1. Thread T1 calls `a` which calls `b`
4+
//! 2. Thread T2 calls `c` which calls `b` (blocks on T1 for `b`). The ordering here is important!
5+
//! 3. Thread T1: `b` calls `c` and `a`; both trigger a cycle and Salsa returns the fixpoint initial values (with `c` and `a` as cycle heads).
6+
//! 4. Thread T1: `b` is released (it's not in its own cycle heads), `Memo::provisional_retry` blocks on `T2` because `c` is in its cycle heads
7+
//! 5. Thread T2: Iterates `c`, blocks on T1 when reading `a`.
8+
//! 6. Thread T1: Completes the first iteration of `a`, inserting a provisional that depends on `c` and itself (`a`).
9+
//! Starts a new iteration where it executes `b`. Calling `query_a` hits a cycle:
10+
//!
11+
//! 1. `fetch_cold` returns the current provisional for `a` that depends both on `a` (owned by itself) and `c` (has no cycle heads).
12+
//! 2. `Memo::provisional_retry`: Awaits `c` (which has no cycle heads anymore).
13+
//! - Before: it skipped over the dependency key `a` that it is holding itself. It sees that `c` is final, so it retries (which gets us back to 6.1)
14+
//! - Now: Return the provisional memo and allow the outer cycle to resolve.
15+
//!
16+
//! The desired behavior here is that:
17+
//! 1. `t1`: completes the first iteration of b
18+
//! 2. `t2`: completes the cycle `c`, up to where it only depends on `a`, now blocks on `a`
19+
//! 3. `t1`: Iterates on `a`, finalizes the memo
20+
21+
use crate::sync::thread;
22+
use salsa::CycleRecoveryAction;
23+
24+
use crate::setup::{Knobs, KnobsDatabase};
25+
26+
/// Result type of the three test queries: a newtype over `u32`.
///
/// It derives `Ord` so values can be capped with `.min(..)`, and `PartialEq`/`Debug`
/// so results can be compared with `assert_eq!`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)]
27+
struct CycleValue(u32);
28+
29+
/// Smallest cycle value; this is what `cycle_initial` returns.
const MIN: CycleValue = CycleValue(0);
30+
/// Upper bound that `query_b` caps its result at, and the value the test expects
/// both threads to end up with.
const MAX: CycleValue = CycleValue(1);
31+
32+
/// Tracked query `a`: simply forwards to `query_b` and returns its value.
///
/// Registered with `cycle_fn` / `cycle_initial` so it can take part in cycle
/// recovery; in the test it is the entry point executed on thread T1.
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
33+
fn query_a(db: &dyn KnobsDatabase) -> CycleValue {
34+
query_b(db)
35+
}
36+
37+
/// Tracked query `b`: synchronizes with the other thread, then reads `query_c`
/// and `query_a` (in that order) and returns `a + 1`, capped at `MAX`.
///
/// The value of `query_c` is only logged; it does not feed into the result.
/// Presumably `db.wait_for(n)` blocks until stage `n` has been signaled via
/// `db.signal(n)` -- confirm against `crate::setup::Knobs`.
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
38+
fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
39+
// Wait for thread 2 to have entered `query_c`.
40+
tracing::debug!("Wait for signal 1 from thread 2");
41+
db.wait_for(1);
42+
43+
// Unblock query_c on thread 2
44+
db.signal(2);
45+
tracing::debug!("Signal 2 for thread 2");
46+
47+
// Read `c` first, then `a`: this is the call order described in step 3 of the
// module docs ("`b` calls `c` and `a`").
let c_value = query_c(db);
48+
49+
tracing::debug!("query_b: c = {:?}", c_value);
50+
51+
let a_value = query_a(db);
52+
53+
tracing::debug!("query_b: a = {:?}", a_value);
54+
55+
// Increment the value read from `a`, but never exceed `MAX`.
CycleValue(a_value.0 + 1).min(MAX)
56+
}
57+
58+
/// Tracked query `c`: signals stage 1, waits, then forwards to `query_b` and
/// returns its value. In the test it is the entry point executed on thread T2.
///
/// NOTE(review): this function waits for stage 1 immediately after signaling
/// stage 1 itself, whereas `query_b` signals stage 2 with the comment
/// "Unblock query_c on thread 2". If `wait_for(1)` returns as soon as stage 1
/// has been signaled, the wait below does not actually guarantee that thread 1
/// already holds `query_b` -- confirm against `crate::setup::Knobs` whether
/// `wait_for(2)` was intended.
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
59+
fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
60+
tracing::debug!("query_c: signaling thread1 to call c");
61+
// Stage 1 is what `query_b` waits for before it calls `query_c`.
db.signal(1);
62+
63+
tracing::debug!("query_c: waiting for signal");
64+
// Wait for thread 1 to acquire the lock on query_b
65+
db.wait_for(1);
66+
let b = query_b(db);
67+
tracing::debug!("query_c: b = {:?}", b);
68+
b
69+
}
70+
71+
/// Cycle recovery callback shared by all three queries.
///
/// Ignores the database, the current value and the iteration count, and always
/// answers `CycleRecoveryAction::Iterate`.
fn cycle_fn(
72+
_db: &dyn KnobsDatabase,
73+
_value: &CycleValue,
74+
_count: u32,
75+
) -> CycleRecoveryAction<CycleValue> {
76+
CycleRecoveryAction::Iterate
77+
}
78+
79+
/// Initial value callback shared by all three queries: always `MIN`.
fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue {
79+
MIN
80+
}
82+
83+
/// Runs `query_a` on one thread and `query_c` on another, each against its own
/// clone of the same `Knobs` database, and asserts that both return `MAX`.
///
/// The whole scenario is executed inside `crate::sync::check`.
#[test_log::test]
84+
fn the_test() {
85+
crate::sync::check(|| {
86+
let db_t1 = Knobs::default();
87+
88+
// Each thread gets its own handle; `db_t2` is a clone of `db_t1`.
let db_t2 = db_t1.clone();
89+
90+
let t1 = thread::spawn(move || {
91+
let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered();
92+
query_a(&db_t1)
93+
});
94+
let t2 = thread::spawn(move || {
95+
let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered();
96+
query_c(&db_t2)
97+
});
98+
99+
// `unwrap` propagates a panic from either worker thread into the test.
let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap());
100+
101+
assert_eq!((r_t1, r_t2), (MAX, MAX));
102+
});
103+
}

tests/parallel/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod cycle_a_t1_b_t2_fallback;
66
mod cycle_ab_peeping_c;
77
mod cycle_nested_three_threads;
88
mod cycle_panic;
9+
mod cycle_provisional_depending_on_itself;
910
mod parallel_cancellation;
1011
mod parallel_join;
1112
mod parallel_map;

0 commit comments

Comments
 (0)