Skip to content

Commit 9f7fa33

Browse files
committed
Fix hang in nested fixpoint iteration
1 parent 2d4321e commit 9f7fa33

File tree

3 files changed

+120
-6
lines changed

3 files changed

+120
-6
lines changed

src/function/memo.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,11 @@ impl<V> Memo<V> {
163163
cycle_heads: &CycleHeads,
164164
) -> bool {
165165
let mut retry = false;
166+
let mut hit_cycle = false;
166167

167168
for head in cycle_heads {
168169
let head_index = head.database_key_index;
169170

170-
if head_index == database_key_index {
171-
continue;
172-
}
173-
174171
let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index());
175172
let cycle_head_kind = ingredient.cycle_head_kind(zalsa, head_index.key_index());
176173
if matches!(
@@ -180,7 +177,9 @@ impl<V> Memo<V> {
180177
// This cycle is already finalized, so we don't need to wait on it;
181178
// keep looping through cycle heads.
182179
retry = true;
180+
tracing::trace!("Dependent cycle head {head_index:?} has been finalized.");
183181
} else if ingredient.wait_for(zalsa, head_index.key_index()) {
182+
tracing::trace!("Dependent cycle head {head_index:?} has been released (there's a new memo)");
184183
// There's a new memo available for the cycle head; fetch our own
185184
// updated memo and see if it's still provisional or if the cycle
186185
// has resolved.
@@ -189,7 +188,10 @@ impl<V> Memo<V> {
189188
// We hit a cycle blocking on the cycle head; this means it's in
190189
// our own active query stack and we are responsible to resolve the
191190
// cycle, so go ahead and return the provisional memo.
192-
return false;
191+
tracing::debug!(
192+
"Waiting for {head_index:?} results in a cycle, return {database_key_index:?} once all other cycle heads completed to allow the outer cycle to make progress."
193+
);
194+
hit_cycle = true;
193195
}
194196
}
195197

@@ -199,7 +201,14 @@ impl<V> Memo<V> {
199201
// the cycle head (either initial value, or from a later iteration) and should be
200202
// returned to caller to allow fixpoint iteration to proceed. (All cases in the loop
201203
// above other than "cycle head is self" are either terminal or set `retry`.)
202-
retry
204+
if hit_cycle {
205+
false
206+
} else if retry {
207+
tracing::debug!("Retrying {database_key_index:?}");
208+
true
209+
} else {
210+
false
211+
}
203212
}
204213
}
205214

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//! Test a specific cycle scenario:
2+
//!
3+
//! 1. Thread T1 calls `a` which calls `b`
4+
//! 2. Thread T2 calls `c` which calls `b` (blocks on T1 for `b`). The ordering here is important!
5+
//! 3. Thread T1: `b` calls `c` and `a`, both trigger a cycle and Salsa returns fixpoint initial values (with `c` and `a` as cycle heads).
6+
//! 4. Thread T1: `b` is released (it's not in its own cycle heads), `Memo::provisional_retry` blocks on `T2` because `c` is in its cycle heads
7+
//! 5. Thread T2: Iterates `c`, blocks on T1 when reading `a`.
8+
//! 6. Thread T1: Completes the first iteration of `a`, inserting a provisional that depends on `c` and itself (`a`).
9+
//! Starts a new iteration where it executes `b`. Calling `query_a` hits a cycle:
10+
//!
11+
//! 1. `fetch_cold` returns the current provisional for `a` that depends both on `a` (owned by itself) and `c` (has no cycle heads).
12+
//! 2. `Memo::provisional_retry`: Awaits `c` (which has no cycle heads anymore).
13+
//! - Before: it skipped over the dependency key `a` that it is holding itself. It sees that `c` is final, so it retries (which gets us back to 6.1)
14+
//! - Now: Return the provisional memo and allow the outer cycle to resolve.
15+
//!
16+
//! The desired behavior here is that:
17+
//! 1. `t1`: completes the first iteration of b
18+
//! 2. `t2`: completes the cycle `c`, up to where it only depends on `a`, now blocks on `a`
19+
//! 3. `t1`: Iterates on `a`, finalizes the memo
20+
21+
use salsa::CycleRecoveryAction;
22+
23+
use crate::setup::{Knobs, KnobsDatabase};
24+
25+
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)]
26+
struct CycleValue(u32);
27+
28+
const MIN: CycleValue = CycleValue(0);
29+
const MAX: CycleValue = CycleValue(1);
30+
31+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
32+
fn query_a(db: &dyn KnobsDatabase) -> CycleValue {
33+
query_b(db)
34+
}
35+
36+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
37+
fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
38+
// Wait for thread 2 to have entered `query_c`.
39+
tracing::debug!("Wait for signal 1 from thread 2");
40+
db.wait_for(1);
41+
42+
// Unblock query_c on thread 2
43+
db.signal(2);
44+
tracing::debug!("Signal 2 for thread 2");
45+
46+
let c_value = query_c(db);
47+
48+
tracing::debug!("query_b: c = {:?}", c_value);
49+
50+
let a_value = query_a(db);
51+
52+
tracing::debug!("query_b: a = {:?}", a_value);
53+
54+
CycleValue(a_value.0 + 1).min(MAX)
55+
}
56+
57+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
58+
fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
59+
tracing::debug!("query_c: signaling thread1 to call c");
60+
db.signal(1);
61+
62+
tracing::debug!("query_c: waiting for signal");
63+
// Wait for thread 1 to acquire the lock on query_b
64+
db.wait_for(1);
65+
let b = query_b(db);
66+
tracing::debug!("query_c: b = {:?}", b);
67+
b
68+
}
69+
70+
fn cycle_fn(
71+
_db: &dyn KnobsDatabase,
72+
_value: &CycleValue,
73+
_count: u32,
74+
) -> CycleRecoveryAction<CycleValue> {
75+
CycleRecoveryAction::Iterate
76+
}
77+
78+
fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue {
79+
MIN
80+
}
81+
82+
#[test_log::test]
83+
fn the_test() {
84+
std::thread::scope(|scope| {
85+
let db_t1 = Knobs::default();
86+
87+
let db_t2 = db_t1.clone();
88+
89+
let t1 = scope.spawn(move || {
90+
let _span =
91+
tracing::debug_span!("t1", thread_id = ?std::thread::current().id()).entered();
92+
query_a(&db_t1)
93+
});
94+
let t2 = scope.spawn(move || {
95+
let _span =
96+
tracing::debug_span!("t2", thread_id = ?std::thread::current().id()).entered();
97+
query_c(&db_t2)
98+
});
99+
100+
let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap());
101+
102+
assert_eq!((r_t1, r_t2), (MAX, MAX));
103+
});
104+
}

tests/parallel/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod cycle_a_t1_b_t2_fallback;
55
mod cycle_ab_peeping_c;
66
mod cycle_nested_three_threads;
77
mod cycle_panic;
8+
mod cycle_provisional_depending_on_itself;
89
mod parallel_cancellation;
910
mod parallel_join;
1011
mod parallel_map;

0 commit comments

Comments
 (0)