diff --git a/src/function/memo.rs b/src/function/memo.rs index 0b1ca7036..9b8e2bb85 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -150,14 +150,11 @@ impl Memo { cycle_heads: &CycleHeads, ) -> bool { let mut retry = false; + let mut hit_cycle = false; for head in cycle_heads { let head_index = head.database_key_index; - if head_index == database_key_index { - continue; - } - let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index()); let cycle_head_kind = ingredient.cycle_head_kind(zalsa, head_index.key_index()); if matches!( @@ -167,7 +164,9 @@ impl Memo { // This cycle is already finalized, so we don't need to wait on it; // keep looping through cycle heads. retry = true; + tracing::trace!("Dependent cycle head {head_index:?} has been finalized."); } else if ingredient.wait_for(zalsa, head_index.key_index()) { + tracing::trace!("Dependent cycle head {head_index:?} has been released (there's a new memo)"); // There's a new memo available for the cycle head; fetch our own // updated memo and see if it's still provisional or if the cycle // has resolved. @@ -176,7 +175,10 @@ impl Memo { // We hit a cycle blocking on the cycle head; this means it's in // our own active query stack and we are responsible to resolve the // cycle, so go ahead and return the provisional memo. - return false; + tracing::debug!( + "Waiting for {head_index:?} results in a cycle, return {database_key_index:?} once all other cycle heads completed to allow the outer cycle to make progress." + ); + hit_cycle = true; } } @@ -186,7 +188,14 @@ impl Memo { // the cycle head (either initial value, or from a later iteration) and should be // returned to caller to allow fixpoint iteration to proceed. (All cases in the loop // above other than "cycle head is self" are either terminal or set `retry`.) - retry + if hit_cycle { + false + } else if retry { + tracing::debug!("Retrying {database_key_index:?}"); + true + } else { + false + } } } diff --git a/tests/parallel/cycle_provisional_depending_on_itself.rs b/tests/parallel/cycle_provisional_depending_on_itself.rs new file mode 100644 index 000000000..ba3645fd5 --- /dev/null +++ b/tests/parallel/cycle_provisional_depending_on_itself.rs @@ -0,0 +1,103 @@ +//! Test a specific cycle scenario: +//! +//! 1. Thread T1 calls `a` which calls `b` +//! 2. Thread T2 calls `c` which calls `b` (blocks on T1 for `b`). The ordering here is important! +//! 3. Thread T1: `b` calls `c` and `a`, both trigger a cycle and Salsa returns a fixpoint initial values (with `c` and `a` as cycle heads). +//! 4. Thread T1: `b` is released (its not in its own cycle heads), `Memo::provisional_retry` blocks blocks on `T2` because `c` is in its cycle heads +//! 5. Thread T2: Iterates `c`, blocks on T1 when reading `a`. +//! 6. Thread T1: Completes the first itaration of `a`, inserting a provisional that depends on `c` and itself (`a`). +//! Starts a new iteration where it executes `b`. Calling `query_a` hits a cycle: +//! +//! 1. `fetch_cold` returns the current provisional for `a` that depends both on `a` (owned by itself) and `c` (has no cycle heads). +//! 2. `Memo::provisional_retry`: Awaits `c` (which has no cycle heads anymore). +//! - Before: it skipped over the dependency key `a` that it is holding itself. It sees that `c` is final, so it retries (which gets us back to 6.1) +//! - Now: Return the provisional memo and allow the outer cycle to resolve. +//! +//! The desired behavior here is that: +//! 1. `t1`: completes the first iteration of b +//! 2. `t2`: completes the cycle `c`, up to where it only depends on `a`, now blocks on `a` +//! 3. `t1`: Iterates on `a`, finalizes the memo + +use crate::sync::thread; +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(1); + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + // Wait for thread 2 to have entered `query_c`. + tracing::debug!("Wait for signal 1 from thread 2"); + db.wait_for(1); + + // Unblock query_c on thread 2 + db.signal(2); + tracing::debug!("Signal 2 for thread 2"); + + let c_value = query_c(db); + + tracing::debug!("query_b: c = {:?}", c_value); + + let a_value = query_a(db); + + tracing::debug!("query_b: a = {:?}", a_value); + + CycleValue(a_value.0 + 1).min(MAX) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + tracing::debug!("query_c: signaling thread1 to call c"); + db.signal(1); + + tracing::debug!("query_c: waiting for signal"); + // Wait for thread 1 to acquire the lock on query_b + db.wait_for(1); + let b = query_b(db); + tracing::debug!("query_c: b = {:?}", b); + b +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[test_log::test] +fn the_test() { + crate::sync::check(|| { + let db_t1 = Knobs::default(); + + let db_t2 = db_t1.clone(); + + let t1 = thread::spawn(move || { + let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered(); + query_a(&db_t1) + }); + let t2 = thread::spawn(move || { + let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered(); + query_c(&db_t2) + }); + + let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap()); + + assert_eq!((r_t1, r_t2), (MAX, MAX)); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 2309ad270..bd9d56580 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -6,6 +6,7 @@ mod cycle_a_t1_b_t2_fallback; mod cycle_ab_peeping_c; mod cycle_nested_three_threads; mod cycle_panic; +mod cycle_provisional_depending_on_itself; mod parallel_cancellation; mod parallel_join; mod parallel_map;