Commit 9de529c

Browse files
Fix query that hangs after do-intervention (#173)
* trying to reproduce the hanging error
* first iteration to handle a graph split by do-intervention; tests needed
* reverted to develop as only commented
* added function docstrings and typing
* first attempt at tests
* fixing tests
* flake
* Speed up _create_node_functions by taking the first element using next()
* Use next(iter(x)) to get the first element
* first iteration to address PR comments and discussion: adds default marginals and returns upstream marginals from the default ones rather than NaNs
* removing nan import
* setting default marginals with query()
* lint changes
* removed Jupyter notebook file from git
* lint changes
* latest modifications
* added my info and fix updates
* first attempt at PR comment to avoid duplicate call to obtain parents of node
* PR comment: avoid duplicate call to get node parents
* fixing lint
* Refactor _remove_disconnected_nodes() and tidy up code and docstrings
* Add edges one by one (instead of constructing an edge list) to make graph construction faster
* Linting
* last PR comments
* Shift add_node() inside the loop for _remove_disconnected_nodes

Co-authored-by: oentaryorj <[email protected]>
1 parent 3efb04a commit 9de529c
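Two of the log entries above swap `list(x)[0]` for `next(iter(x))` when only the first element of a dict view is needed. A minimal sketch (not part of the commit itself) of the idiom:

```python
# next(iter(...)) fetches the first element of a dict view lazily,
# without materializing the whole view into a throwaway list.
states = {"on": 0.2, "off": 0.8}

first_slow = list(states.values())[0]     # O(n): builds a full list first
first_fast = next(iter(states.values()))  # O(1): reads a single element

assert first_slow == first_fast == 0.2
```

Both expressions return the same value; the iterator form simply avoids the intermediate list, which matters when `_create_node_functions` runs over many large CPDs.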

File tree

4 files changed: +193 −29 lines

RELEASE.md (5 additions, 1 deletion)

@@ -5,6 +5,8 @@
 * Fix broken URLs in FAQ documentation, as per #113 and #125
 * Add a link to `PyGraphviz` installation guide under the installation prerequisites
 * Fix integer index type checking for timeseries data, as per #74 and #86
+* Fix infinite loop at `.query()` after a `.do_intervention()` that splits
+  the graph into two or more subgraphs, as per #100, #45
 
 # Release 0.10.0
 * Add supervised discretisation strategies using Decision Tree and MDLP algorithms.
@@ -104,6 +106,8 @@ The initial release of CausalNex.
 
 ## Thanks for supporting contributions
 CausalNex was originally designed by [Paul Beaumont](https://www.linkedin.com/in/pbeaumont/) and [Ben Horsburgh](https://www.linkedin.com/in/benhorsburgh/) to solve challenges they faced in inferencing causality in their project work. This work was later turned into a product thanks to the following contributors:
-[Yetunde Dada](https://github.com/yetudada), [Wesley Leong](https://www.linkedin.com/in/wesleyleong/), [Steve Ler](https://www.linkedin.com/in/song-lim-steve-ler-380366106/), [Viktoriia Oliinyk](https://www.linkedin.com/in/victoria-oleynik/), [Roxana Pamfil](https://www.linkedin.com/in/roxana-pamfil-1192053b/), [Nisara Sriwattanaworachai](https://www.linkedin.com/in/nisara-sriwattanaworachai-795b357/), [Nikolaos Tsaousis](https://www.linkedin.com/in/ntsaousis/), [Angel Droth](https://www.linkedin.com/in/angeldroth/), [Zain Patel](https://www.linkedin.com/in/zain-patel/), [Richard Oentaryo](https://www.linkedin.com/in/oentaryo/), and [Shuhei Ishida](https://www.linkedin.com/in/shuhei-i/).
+[Yetunde Dada](https://github.com/yetudada), [Wesley Leong](https://www.linkedin.com/in/wesleyleong/), [Steve Ler](https://www.linkedin.com/in/song-lim-steve-ler-380366106/), [Viktoriia Oliinyk](https://www.linkedin.com/in/victoria-oleynik/), [Roxana Pamfil](https://www.linkedin.com/in/roxana-pamfil-1192053b/), [Nisara Sriwattanaworachai](https://www.linkedin.com/in/nisara-sriwattanaworachai-795b357/), [Nikolaos Tsaousis](https://www.linkedin.com/in/ntsaousis/), [Angel Droth](https://www.linkedin.com/in/angeldroth/), [Zain Patel](https://www.linkedin.com/in/zain-patel/), [Richard Oentaryo](https://www.linkedin.com/in/oentaryo/),
+[Shuhei Ishida](https://www.linkedin.com/in/shuhei-i/), and [Francesca
+Sogaro](https://www.linkedin.com/in/francesca-sogaro/).
 
 CausalNex would also not be possible without the generous sharing from leading researches in the field of causal inference and we are grateful to everyone who advised and supported us, filed issues or helped resolve them, asked and answered questions or simply be part of inspiring discussions.

causalnex/inference/inference.py (85 additions, 28 deletions)

@@ -30,13 +30,13 @@
 
 ``InferenceEngine`` provides tools to make inferences based on interventions and observations.
 """
-
 import copy
 import inspect
 import re
 import types
 from typing import Callable, Dict, Hashable, List, Tuple, Union
 
+import networkx as nx
 import pandas as pd
 from pathos import multiprocessing
 
@@ -91,7 +91,7 @@ class InferenceEngine:
 
     def __init__(self, bn: BayesianNetwork):
         """
-        Create a new ``InferenceEngine`` from an existing ``BayesianNetwork``.
+        Creates a new ``InferenceEngine`` from an existing ``BayesianNetwork``.
 
         It is expected that structure and probability distribution has already been learned
         for the ``BayesianNetwork`` that is to be used for inference.
@@ -104,8 +104,8 @@ def __init__(self, bn: BayesianNetwork):
            ValueError: if the Bayesian Network contains isolates, or if a variable name is invalid,
                or if the CPDs have not been learned yet.
        """
-
        bad_nodes = [node for node in bn.nodes if not re.match("^[0-9a-zA-Z_]+$", node)]
+
        if bad_nodes:
            raise ValueError(
                "Variable names must match ^[0-9a-zA-Z_]+$ - please fix the "
@@ -119,16 +119,20 @@ def __init__(self, bn: BayesianNetwork):
            )
 
        self._cpds = None
+        self._upstream_cpds = {}
 
        self._create_cpds_dict_bn(bn)
        self._generate_domains_bn(bn)
        self._generate_bbn()
 
+        # TODO: can we do it without a query() call?  # pylint: disable=fixme
+        self._default_marginals = self.query()
+
    def _single_query(
        self, observations: Dict[str, Hashable] = None
    ) -> Dict[str, Dict[Hashable, float]]:
        """
-        Query the ``BayesianNetwork`` for marginals given some observations.
+        Queries the ``BayesianNetwork`` for marginals given some observations.
 
        Args:
            observations: observed states of nodes in the Bayesian Network.
@@ -139,15 +143,19 @@ def _single_query(
            A dictionary of marginal probabilities of the network.
            For instance, :math:`P(a=1) = 0.3, P(a=2) = 0.7` -> {a: {1: 0.3, 2: 0.7}}
        """
-
        bbn_results = (
            self._bbn.query(**observations) if observations else self._bbn.query()
        )
-
        results = {node: dict() for node in self._cpds}
+
        for (node, state), prob in bbn_results.items():
            results[node][state] = prob
 
+        # the upstream nodes are set to the default marginals based on the
+        # original cpds of the bn
+        for detached_node in self._upstream_cpds:
+            results[detached_node] = self._default_marginals[detached_node]
+
        return results
 
    def query(
@@ -159,7 +167,7 @@ def query(
        Dict[str, Dict[Hashable, float]], List[Dict[str, Dict[Hashable, float]]]
    ]:
        """
-        Query the ``BayesianNetwork`` for marginals given one or more observations.
+        Queries the ``BayesianNetwork`` for marginals given one or more observations.
 
        Args:
            observations: one or more observations of states of nodes in the Bayesian Network.
@@ -170,21 +178,21 @@ def query(
        Returns:
            A dictionary or a list of dictionaries of marginal probabilities of the network.
        """
-
        if isinstance(observations, dict) or observations is None:
            return self._single_query(observations)
+
        result = []
+
        if parallel:
            with multiprocessing.Pool(num_cores) as p:
                result = p.map(self._single_query, observations)
-
        else:
            for obs in observations:
                result.append(self._single_query(obs))
 
        return result
 
-    def _do(self, observation: str, state: Dict[Hashable, float]) -> None:
+    def _do(self, observation: str, state: Dict[Hashable, float]):
        """
        Makes an intervention on the Bayesian Network.
 
@@ -215,10 +223,12 @@ def _do(self, observation: str, state: Dict[Hashable, float]) -> None:
        self._cpds[observation] = {s: {(): p} for s, p in state.items()}
 
    def do_intervention(
-        self, node: str, state: Union[Hashable, Dict[Hashable, float]] = None
-    ) -> None:
+        self,
+        node: str,
+        state: Union[Hashable, Dict[Hashable, float]] = None,
+    ):
        """
-        Make an intervention on the Bayesian Network.
+        Makes an intervention on the Bayesian Network.
 
        For instance,
        `do_intervention('X', 'x')` will set :math:`P(X=x)` to 1, and :math:`P(X=y)` to 0
@@ -245,36 +255,45 @@ def do_intervention(
            state = {s: float(s == state) for s in self._cpds[node]}
 
        self._do(node, state)
+
+        # check for presence of separate subgraph after do-intervention
+        self._remove_disconnected_nodes(node)
        self._generate_bbn()
 
-    def reset_do(self, observation: str) -> None:
+    def reset_do(self, observation: str):
        """
        Resets any do_interventions that have been applied to the observation.
 
        Args:
            observation: observation that will be reset.
        """
-
        self._cpds[observation] = self._cpds_original[observation]
+
+        for upstream_node, original_cpds in self._upstream_cpds.items():
+            self._cpds[upstream_node] = original_cpds
+
+        self._upstream_cpds = {}
        self._generate_bbn()
 
    def _generate_bbn(self):
-        """Re-create the _bbn."""
+        """Re-creates the _bbn."""
        self._node_functions = self._create_node_functions()
-
        self._bbn = build_bbn(
            list(self._node_functions.values()), domains=self._domains
        )
 
-    def _generate_domains_bn(self, bn):
-
+    def _generate_domains_bn(self, bn: BayesianNetwork):
+        """Generates domains from Bayesian network"""
        self._domains = {
            variable: list(cpd.index.values) for variable, cpd in bn.cpds.items()
        }
 
-    def _create_cpds_dict_bn(self, bn: BayesianNetwork) -> None:
+    def _create_cpds_dict_bn(self, bn: BayesianNetwork):
        """
-        Map CPDs in the ``BayesianNetwork`` to required format:
+        Maps CPDs in the ``BayesianNetwork`` to required format:
+
+        Args:
+            bn: Bayesian network
 
        >>> {"observation":
        >>>     {"state":
@@ -292,7 +311,6 @@ def _create_cpds_dict_bn(self, bn: BayesianNetwork) -> None:
        >>>     }
        >>> }
        """
-
        lookup = {
            variable: {
                state: {
@@ -305,7 +323,6 @@ def _create_cpds_dict_bn(self, bn: BayesianNetwork) -> None:
            }
            for variable, cpd in bn.cpds.items()
        }
-
        self._cpds = lookup
        self._cpds_original = copy.deepcopy(self._cpds)
 
@@ -349,26 +366,66 @@ def template() -> float:
            code.co_cellvars,
        )
        template.__name__ = name
-
        return template
 
    def _create_node_functions(self) -> Dict[str, Callable]:
-        """Creates all functions required to create a ``BayesianNetwork``."""
+        """
+        Creates all functions required to create a ``BayesianNetwork``.
 
+        Returns:
+            Dictionary of node functions
+        """
        node_functions = dict()
 
        for node, states in self._cpds.items():
            # since we only need condition names, which are consistent across all states,
            # then we can inspect the 0th element
-            states_conditions = list(states.values())[0]
+            states_conditions = next(iter(states.values()))
 
            # take any state, and get its conditions
-            state_conditions = list(states_conditions.items())[0]
-            condition_nodes = [n for n, v in state_conditions[0]]
+            state_conditions = next(iter(states_conditions.keys()))
+            condition_nodes = [n for n, v in state_conditions]
 
            node_args = tuple([node] + condition_nodes)  # type: Tuple[str]
            function_name = "f_{node}".format(node=node)
            node_function = self._create_node_function(function_name, node_args)
            node_functions[node] = node_function
 
        return node_functions
+
+    def _remove_disconnected_nodes(self, var: str):
+        """
+        Identifies and removes from the _cpds the nodes of the bbn which are
+        part of one or more upstream subgraphs that could have been formed
+        after a do-intervention.
+
+        Uses the attribute _cpds to determine the parents of each node.
+        Leverages the networkx `weakly_connected_components` method to
+        identify the subgraphs.
+
+        For instance, the network A -> B -> C -> D -> E would be split into
+        two sub networks (A -> B) and (C -> D -> E) if we intervene on
+        node C.
+
+        Args:
+            var: variable we have intervened on
+        """
+        # construct graph from CPDs
+        g = nx.DiGraph()
+
+        # add nodes as there could be isolates (e.g. A->B->C intervening on B
+        # makes A an isolate)
+        for node, states in self._cpds.items():
+            sample_state = next(iter(states.values()))
+            parents = next(iter(sample_state.keys()))
+            g.add_node(node)
+
+            for parent, _ in parents:
+                g.add_edge(parent, node)
+
+        # remove nodes in subgraphs which do not contain the intervention node
+        for sub_graph in nx.weakly_connected_components(g):
+            if var not in sub_graph:
+                for node in sub_graph:
+                    self._upstream_cpds[node] = self._cpds[node]
+                    self._cpds.pop(node)
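The subgraph detection that `_remove_disconnected_nodes` performs can be exercised standalone. A minimal sketch (assuming only `networkx`, not the CausalNex internals) of how a do-intervention on `c` splits the chain `a -> b -> c -> d -> e`:

```python
import networkx as nx

# A do-intervention on "c" makes its CPD unconditional, which severs the
# edge b -> c and leaves two weakly connected components: (a -> b) and
# (c -> d -> e).
g = nx.DiGraph()
for node in "abcde":
    g.add_node(node)  # added explicitly, since an intervention can create isolates
for parent, child in [("a", "b"), ("c", "d"), ("d", "e")]:  # b -> c is gone
    g.add_edge(parent, child)

intervention_node = "c"
upstream = set()
for component in nx.weakly_connected_components(g):
    if intervention_node not in component:
        upstream |= component  # these nodes' CPDs get stashed in _upstream_cpds

print(sorted(upstream))  # prints ['a', 'b']
```

Weak connectivity (ignoring edge direction) is the right notion here: the upstream component `(a -> b)` carries no information to the intervened subgraph in either direction, which is why its nodes can be dropped from `_cpds` and answered from the default marginals instead.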

tests/conftest.py (34 additions, 0 deletions)

@@ -1070,3 +1070,37 @@ def iris_edge_list():
     ]
 
     return edge_list
+
+
+@pytest.fixture
+def chain_network() -> BayesianNetwork:
+    """
+    This Bayesian Model structure to test do interventions that split graph
+    into subgraphs.
+
+    a → b → c → d → e
+
+    """
+
+    n = 50
+    nodes_names = list("abcde")
+    random_binary_matrix = (
+        np.random.randint(10, size=(n, len(nodes_names))) > 6
+    ).astype(int)
+    df = pd.DataFrame(data=random_binary_matrix, columns=nodes_names)
+
+    model = StructureModel()
+    model.add_edges_from(
+        [
+            ("a", "b"),
+            ("b", "c"),
+            ("c", "d"),
+            ("d", "e"),
+        ]
+    )
+
+    chain_bn = BayesianNetwork(model)
+    chain_bn = chain_bn.fit_node_states(df)
+    chain_bn = chain_bn.fit_cpds(df, method="BayesianEstimator", bayes_prior="K2")
+
+    return chain_bn

tests/test_inference.py (69 additions, 0 deletions)

@@ -408,3 +408,72 @@ def test_multi_query(self, bn):
         assert single_0 == results_parallel[0]
         assert single_1 == results_parallel[1]
         assert single_2 == results_parallel[2]
+
+    def test_query_after_do_intervention_has_split_graph(self, chain_network):
+        """
+        chain network: a → b → c → d → e
+
+        test 1.
+        - do intervention on node c generates 2 graphs (a → b) and (c → d → e)
+        - assert the query can be run (it used to hang before)
+        - assert reset_do works
+        """
+        ie = InferenceEngine(chain_network)
+        original_margs = ie.query()
+
+        var = "c"
+        state_dict = {0: 1.0, 1: 0.0}
+        ie.do_intervention(var, state_dict)
+        # assert the intervention node has indeed the right state
+        assert ie.query()[var][0] == state_dict[0]
+        assert ie.query()[var][1] == state_dict[1]
+
+        # assert the upstream nodes have the default marginals (no info
+        # propagates in the upstream graph)
+        assert ie.query()["a"][0] == original_margs["a"][0]
+        assert ie.query()["a"][1] == original_margs["a"][1]
+        assert ie.query()["b"][0] == original_margs["b"][0]
+        assert ie.query()["b"][1] == original_margs["b"][1]
+
+        # assert the _cpds of the upstream nodes are stored correctly
+        orig_cpds = ie._cpds_original  # pylint: disable=protected-access
+        upstream_cpds = ie._upstream_cpds  # pylint: disable=protected-access
+        assert orig_cpds["a"] == upstream_cpds["a"]
+        assert orig_cpds["b"] == upstream_cpds["b"]
+
+        ie.reset_do(var)
+        reset_margs = ie.query()
+
+        for node in original_margs.keys():
+            dict_left = original_margs[node]
+            dict_right = reset_margs[node]
+            for (kl, kr) in zip(dict_left.keys(), dict_right.keys()):
+                assert math.isclose(dict_left[kl], dict_right[kr])
+
+        # repeating above tests intervening on b, so that there is one single
+        # isolate
+        var_b = "b"
+        state_dict_b = {0: 1.0, 1: 0.0}
+        ie.do_intervention(var_b, state_dict_b)
+        # assert the intervention node has indeed the right state
+        assert ie.query()[var_b][0] == state_dict[0]
+        assert ie.query()[var_b][1] == state_dict[1]
+
+        # assert the upstream nodes have the default marginals (no info
+        # propagates in the upstream graph)
+        assert ie.query()["a"][0] == original_margs["a"][0]
+        assert ie.query()["a"][1] == original_margs["a"][1]
+
+        # assert the _cpds of the upstream nodes are stored correctly
+        orig_cpds = ie._cpds_original  # pylint: disable=protected-access
+        upstream_cpds = ie._upstream_cpds  # pylint: disable=protected-access
+        assert orig_cpds["a"] == upstream_cpds["a"]
+
+        ie.reset_do(var_b)
+        reset_margs = ie.query()
+
+        for node in original_margs.keys():
+            dict_left = original_margs[node]
+            dict_right = reset_margs[node]
+            for (kl, kr) in zip(dict_left.keys(), dict_right.keys()):
+                assert math.isclose(dict_left[kl], dict_right[kr])
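The marginal-comparison loop the test repeats twice can be condensed into a small helper. A sketch with hypothetical data (the helper and the sample dicts are illustrative, not part of the commit):

```python
import math

def marginals_close(left, right, rel_tol=1e-9):
    """Compare two {node: {state: probability}} marginal dicts entry by entry."""
    if left.keys() != right.keys():
        return False
    return all(
        math.isclose(left[node][state], right[node][state], rel_tol=rel_tol)
        for node in left
        for state in left[node]
    )

before = {"a": {0: 0.3, 1: 0.7}, "b": {0: 0.5, 1: 0.5}}
after_reset = {"a": {0: 0.3, 1: 0.7}, "b": {0: 0.5, 1: 0.5}}
assert marginals_close(before, after_reset)
```

Indexing by key is slightly stricter than the test's `zip` over the two key views, which pairs states positionally and would silently compare mismatched states if the dicts ever disagreed on ordering.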
