Skip to content

Commit b27e392

Browse files
authored
fix: Access to tracked-struct that was freed during fixpoint (#817)
* Add test for untracked read on tracked struct created in previous cycle * Initial fix * Restrict seeding to memos from the same revision * Reduce changes * seed_outputs * Cleanup test * Add assertion * Try * Try merging outputs after query executed * Assert logs from first execution * Enable trace level logging * Use `FxIndexSet` in `diff_outputs` * Log more events * Cleanup * Append outputs only once
1 parent 0cbe7f8 commit b27e392

File tree

7 files changed

+206
-19
lines changed

7 files changed

+206
-19
lines changed

src/function/diff_outputs.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::function::memo::Memo;
22
use crate::function::{Configuration, IngredientImpl};
3-
use crate::hash::FxHashSet;
3+
use crate::hash::FxIndexSet;
44
use crate::zalsa::Zalsa;
55
use crate::zalsa_local::QueryRevisions;
66
use crate::{AsDynDatabase as _, Database, DatabaseKeyIndex, Event, EventKind};
@@ -27,22 +27,28 @@ where
2727
provisional: bool,
2828
) {
2929
// Iterate over the outputs of the `old_memo` and put them into a hashset
30-
let mut old_outputs: FxHashSet<_> = old_memo.revisions.origin.outputs().collect();
30+
let mut old_outputs: FxIndexSet<_> = old_memo.revisions.origin.outputs().collect();
31+
32+
if old_outputs.is_empty() {
33+
return;
34+
}
3135

3236
// Iterate over the outputs of the current query
3337
// and remove elements from `old_outputs` when we find them
3438
for new_output in revisions.origin.outputs() {
35-
old_outputs.remove(&new_output);
39+
old_outputs.swap_remove(&new_output);
3640
}
3741

38-
if !old_outputs.is_empty() {
39-
// Remove the outputs that are no longer present in the current revision
40-
// to prevent that the next revision is seeded with a id mapping that no longer exists.
41-
revisions.tracked_struct_ids.retain(|&k, &mut value| {
42-
!old_outputs.contains(&DatabaseKeyIndex::new(k.ingredient_index(), value))
43-
});
42+
if old_outputs.is_empty() {
43+
return;
4444
}
4545

46+
// Remove the outputs that are no longer present in the current revision
47+
// to prevent that the next revision is seeded with an id mapping that no longer exists.
48+
revisions.tracked_struct_ids.retain(|&k, &mut value| {
49+
!old_outputs.contains(&DatabaseKeyIndex::new(k.ingredient_index(), value))
50+
});
51+
4652
for old_output in old_outputs {
4753
Self::report_stale_output(zalsa, db, key, old_output, provisional);
4854
}

src/function/execute.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ where
5959
// Query was not previously executed, or value is potentially
6060
// stale, or value is absent. Let's execute!
6161
let mut new_value = C::execute(db, C::id_to_input(db, id));
62+
63+
if let Some(old_memo) = opt_old_memo {
64+
// Copy over all outputs from a previous iteration.
65+
// This is necessary to ensure that tracked struct created during the previous iteration
66+
// (and are owned by the query) are alive even if the query in this iteration no longer creates them.
67+
// The query not re-creating the tracked struct doesn't guarantee that there
68+
// aren't any other queries depending on it.
69+
if old_memo.may_be_provisional() && old_memo.verified_at.load() == revision_now {
70+
active_query.append_outputs(old_memo.revisions.origin.outputs());
71+
}
72+
}
73+
6274
let mut revisions = active_query.pop();
6375

6476
// Did the new result we got depend on our own provisional value, in a cycle?

src/tracked_struct.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,10 @@ where
406406
let current_revision = zalsa.current_revision();
407407
match zalsa_local.tracked_struct_id(&identity) {
408408
Some(id) => {
409+
let index = self.database_key_index(id);
410+
tracing::trace!("Reuse tracked struct {id:?}", id = index);
409411
// The struct already exists in the intern map.
410-
zalsa_local.add_output(self.database_key_index(id));
412+
zalsa_local.add_output(index);
411413
self.update(zalsa, current_revision, id, &current_deps, fields);
412414
FromId::from_id(id)
413415
}
@@ -416,6 +418,7 @@ where
416418
// This is a new tracked struct, so create an entry in the struct map.
417419
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
418420
let key = self.database_key_index(id);
421+
tracing::trace!("Allocated new tracked struct {id:?}", id = key);
419422
zalsa_local.add_output(key);
420423
zalsa_local.store_tracked_struct_id(identity, id);
421424
FromId::from_id(id)

src/zalsa_local.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,22 @@ impl ActiveQueryGuard<'_> {
490490
})
491491
}
492492

493+
/// Append the given `outputs` to the query's output list.
494+
pub(crate) fn append_outputs<I>(&self, outputs: I)
495+
where
496+
I: IntoIterator<Item = DatabaseKeyIndex> + UnwindSafe,
497+
{
498+
self.local_state.with_query_stack(|stack| {
499+
#[cfg(debug_assertions)]
500+
assert_eq!(stack.len(), self.push_len);
501+
let frame = stack.last_mut().unwrap();
502+
503+
for output in outputs {
504+
frame.add_output(output);
505+
}
506+
})
507+
}
508+
493509
/// Invoked when the query has successfully completed execution.
494510
fn complete(self) -> QueryRevisions {
495511
let query = self.local_state.with_query_stack(|stack| {

tests/common/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub trait LogDatabase: HasLogger + Database {
4141
/// it is meant to be run from outside any tracked functions.
4242
fn assert_logs_len(&self, expected: usize) {
4343
let logs = std::mem::take(&mut *self.logger().logs.lock().unwrap());
44-
assert_eq!(logs.len(), expected);
44+
assert_eq!(logs.len(), expected, "Actual logs: {logs:#?}");
4545
}
4646
}
4747

tests/cycle_output.rs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ impl salsa::Database for Database {
102102
| salsa::EventKind::DidValidateMemoizedValue { .. } => {
103103
self.push_log(format!("salsa_event({:?})", event.kind));
104104
}
105-
_ => {}
105+
salsa::EventKind::WillCheckCancellation => {}
106+
_ => {
107+
self.push_log(format!("salsa_event({:?})", event.kind));
108+
}
106109
}
107110
}
108111
}
@@ -127,7 +130,7 @@ fn revalidate_no_changes() {
127130
assert_eq!(query_c(&db, c_input), 10);
128131
assert_eq!(query_b(&db, ab_input), 3);
129132

130-
db.assert_logs_len(11);
133+
db.assert_logs_len(15);
131134

132135
// trigger a new revision, but one that doesn't touch the query_a/query_b cycle
133136
c_input.set_value(&mut db).to(20);
@@ -136,9 +139,12 @@ fn revalidate_no_changes() {
136139

137140
db.assert_logs(expect![[r#"
138141
[
139-
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(401)) })",
142+
"salsa_event(DidSetCancellationFlag)",
143+
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })",
144+
"salsa_event(DidReinternValue { key: Configuration(Id(800)), revision: R2 })",
140145
"salsa_event(DidValidateMemoizedValue { database_key: query_d(Id(800)) })",
141146
"salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })",
147+
"salsa_event(DidReinternValue { key: Configuration(Id(800)), revision: R2 })",
142148
"salsa_event(DidValidateMemoizedValue { database_key: query_a(Id(0)) })",
143149
"salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })",
144150
]"#]]);
@@ -154,7 +160,7 @@ fn revalidate_with_change_after_output_read() {
154160

155161
assert_eq!(query_b(&db, ab_input), 3);
156162

157-
db.assert_logs_len(10);
163+
db.assert_logs_len(14);
158164

159165
// trigger a new revision that changes the output of query_d
160166
d_input.set_value(&mut db).to(20);
@@ -163,15 +169,29 @@ fn revalidate_with_change_after_output_read() {
163169

164170
db.assert_logs(expect![[r#"
165171
[
166-
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(401)) })",
172+
"salsa_event(DidSetCancellationFlag)",
173+
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(403)) })",
174+
"salsa_event(DidReinternValue { key: Configuration(Id(800)), revision: R2 })",
167175
"salsa_event(WillExecute { database_key: query_d(Id(800)) })",
168176
"salsa_event(WillExecute { database_key: query_a(Id(0)) })",
169-
"salsa_event(WillExecute { database_key: read_value(Id(400)) })",
177+
"salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(400)) })",
178+
"salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(403)) })",
179+
"salsa_event(DidDiscard { key: Output(Id(403)) })",
180+
"salsa_event(DidDiscard { key: read_value(Id(403)) })",
181+
"salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(402)) })",
182+
"salsa_event(DidDiscard { key: Output(Id(402)) })",
183+
"salsa_event(DidDiscard { key: read_value(Id(402)) })",
184+
"salsa_event(WillDiscardStaleOutput { execute_key: query_a(Id(0)), output_key: Output(Id(401)) })",
185+
"salsa_event(DidDiscard { key: Output(Id(401)) })",
186+
"salsa_event(DidDiscard { key: read_value(Id(401)) })",
170187
"salsa_event(WillExecute { database_key: query_b(Id(0)) })",
188+
"salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: 1, fell_back: false })",
171189
"salsa_event(WillExecute { database_key: query_a(Id(0)) })",
172-
"salsa_event(WillExecute { database_key: read_value(Id(401)) })",
190+
"salsa_event(WillExecute { database_key: read_value(Id(403)) })",
191+
"salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: 2, fell_back: false })",
173192
"salsa_event(WillExecute { database_key: query_a(Id(0)) })",
174-
"salsa_event(WillExecute { database_key: read_value(Id(400)) })",
193+
"salsa_event(WillExecute { database_key: read_value(Id(402)) })",
194+
"salsa_event(WillIterateCycle { database_key: query_b(Id(0)), iteration_count: 3, fell_back: false })",
175195
"salsa_event(WillExecute { database_key: query_a(Id(0)) })",
176196
"salsa_event(WillExecute { database_key: read_value(Id(401)) })",
177197
]"#]]);

tests/cycle_tracked_own_input.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//! Test for cycle handling where a tracked struct created in the first revision
2+
//! is stored in the final value of the cycle but isn't recreated in the second
3+
//! iteration of the creating query.
4+
//!
5+
//! It's important that the creating query in the last iteration keeps *owning* the
6+
//! tracked struct from the previous iteration, otherwise Salsa will discard it
7+
//! and dereferencing the value panics.
8+
mod common;
9+
10+
use crate::common::{EventLoggerDatabase, LogDatabase};
11+
use expect_test::expect;
12+
use salsa::{CycleRecoveryAction, Database, Setter};
13+
14+
#[salsa::input(debug)]
15+
struct ClassNode {
16+
name: String,
17+
type_params: Option<TypeParamNode>,
18+
}
19+
20+
#[salsa::input(debug)]
21+
struct TypeParamNode {
22+
name: String,
23+
constraint: Option<ClassNode>,
24+
}
25+
26+
#[salsa::interned(debug)]
27+
struct Class<'db> {
28+
name: String,
29+
type_params: Option<TypeParam<'db>>,
30+
}
31+
32+
#[salsa::tracked(debug)]
33+
struct TypeParam<'db> {
34+
name: String,
35+
constraint: Option<Type<'db>>,
36+
}
37+
38+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update)]
39+
enum Type<'db> {
40+
Class(Class<'db>),
41+
Unknown,
42+
}
43+
44+
impl Type<'_> {
45+
fn class(&self) -> Option<Class> {
46+
match self {
47+
Type::Class(class) => Some(*class),
48+
Type::Unknown => None,
49+
}
50+
}
51+
}
52+
53+
#[salsa::tracked(cycle_fn=infer_class_recover, cycle_initial=infer_class_initial)]
54+
fn infer_class<'db>(db: &'db dyn salsa::Database, node: ClassNode) -> Type<'db> {
55+
Type::Class(Class::new(
56+
db,
57+
node.name(db),
58+
node.type_params(db).map(|tp| infer_type_param(db, tp)),
59+
))
60+
}
61+
62+
#[salsa::tracked]
63+
fn infer_type_param<'db>(db: &'db dyn salsa::Database, node: TypeParamNode) -> TypeParam<'db> {
64+
if let Some(constraint) = node.constraint(db) {
65+
// Reuse the type param from the class if any.
66+
// The example is a bit silly, because it's a reduction of what we have in Astral's type checker
67+
// but including all the details doesn't make sense. What's important for the test is
68+
// that this query doesn't re-create the `TypeParam` tracked struct in the second iteration
69+
// and instead returns the one from the first iteration which
70+
// then is returned in the overall result (Class).
71+
match infer_class(db, constraint) {
72+
Type::Class(class) => class
73+
.type_params(db)
74+
.unwrap_or_else(|| TypeParam::new(db, node.name(db), Some(Type::Unknown))),
75+
Type::Unknown => TypeParam::new(db, node.name(db), Some(Type::Unknown)),
76+
}
77+
} else {
78+
TypeParam::new(db, node.name(db), None)
79+
}
80+
}
81+
82+
fn infer_class_initial(_db: &dyn Database, _node: ClassNode) -> Type {
83+
Type::Unknown
84+
}
85+
86+
fn infer_class_recover<'db>(
87+
_db: &'db dyn Database,
88+
_type: &Type<'db>,
89+
_count: u32,
90+
_inputs: ClassNode,
91+
) -> CycleRecoveryAction<Type<'db>> {
92+
CycleRecoveryAction::Iterate
93+
}
94+
95+
#[test]
96+
fn main() {
97+
let mut db = EventLoggerDatabase::default();
98+
99+
// Class with a type parameter that's constrained to itself.
100+
// class Test[T: Test]: ...
101+
let class_node = ClassNode::new(&db, "Test".to_string(), None);
102+
let type_param_node = TypeParamNode::new(&db, "T".to_string(), Some(class_node));
103+
class_node
104+
.set_type_params(&mut db)
105+
.to(Some(type_param_node));
106+
107+
let ty = infer_class(&db, class_node);
108+
109+
db.assert_logs(expect![[r#"
110+
[
111+
"DidSetCancellationFlag",
112+
"WillCheckCancellation",
113+
"WillExecute { database_key: infer_class(Id(0)) }",
114+
"WillCheckCancellation",
115+
"WillExecute { database_key: infer_type_param(Id(400)) }",
116+
"WillCheckCancellation",
117+
"DidInternValue { key: Class(Id(c00)), revision: R2 }",
118+
"WillIterateCycle { database_key: infer_class(Id(0)), iteration_count: 1, fell_back: false }",
119+
"WillCheckCancellation",
120+
"WillExecute { database_key: infer_type_param(Id(400)) }",
121+
"WillCheckCancellation",
122+
]"#]]);
123+
124+
let class = ty.class().unwrap();
125+
let type_param = class.type_params(&db).unwrap();
126+
127+
// Now read the name from the type param struct that was created in the first iteration of
128+
// `infer_type_param`. This should not panic!
129+
assert_eq!(type_param.name(&db), "T");
130+
}

0 commit comments

Comments
 (0)