Skip to content

Commit a70fad4

Browse files
authored
reduce tree size before saving them (#33)
1 parent 2343803 commit a70fad4

File tree

4 files changed

+57
-39
lines changed

4 files changed

+57
-39
lines changed

pymc_experimental/bart/pgbart.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@ class PGBART(ArrayStepShared):
5555
generates_stats = True
5656
stats_dtypes = [{"variable_inclusion": object, "bart_trees": object}]
5757

58-
def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", model=None):
58+
def __init__(
59+
self,
60+
vars=None,
61+
num_particles=40,
62+
max_stages=100,
63+
batch="auto",
64+
model=None,
65+
):
5966
model = modelcontext(model)
6067
initial_values = model.compute_initial_point()
6168
if vars is None:
@@ -135,7 +142,7 @@ def astep(self, _):
135142
# at the end of the algorithm we return one of these particles as the new tree
136143
particles = self.init_particles(tree_id)
137144
# Compute the sum of trees without the old tree that we are attempting to replace
138-
self.sum_trees_noi = self.sum_trees - particles[0].tree.predict_output()
145+
self.sum_trees_noi = self.sum_trees - particles[0].tree._predict()
139146
# Resample leaf values for particle 1 which is a copy of the old tree
140147
particles[1].sample_leafs(
141148
self.sum_trees,
@@ -191,10 +198,11 @@ def astep(self, _):
191198
# Get the new tree and update
192199
new_particle = np.random.choice(particles, p=normalized_weights)
193200
new_tree = new_particle.tree
194-
self.all_trees[tree_id] = new_tree
201+
195202
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
196203
self.all_particles[tree_id] = new_particle
197-
self.sum_trees = self.sum_trees_noi + new_tree.predict_output()
204+
self.sum_trees = self.sum_trees_noi + new_tree._predict()
205+
self.all_trees[tree_id] = new_tree.trim()
198206

199207
if self.tune:
200208
self.ssv = SampleSplittingVariable(self.alpha_vec)
@@ -239,7 +247,7 @@ def update_weight(self, particle, old=False):
239247
Since the prior is used as the proposal, the weights are updated additively as the ratio of
240248
the new and old log-likelihoods.
241249
"""
242-
new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree.predict_output())
250+
new_likelihood = self.likelihood_logp(self.sum_trees_noi + particle.tree._predict())
243251
if old:
244252
particle.log_weight = new_likelihood
245253
particle.old_likelihood_logp = new_likelihood

pymc_experimental/bart/tree.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,28 @@ def delete_node(self, index):
7474
self.idx_leaf_nodes.remove(index)
7575
del self.tree_structure[index]
7676

77-
def predict_output(self, excluded=None):
77+
def trim(self):
78+
a_tree = self.copy()
79+
del a_tree.num_observations
80+
del a_tree.idx_leaf_nodes
81+
for k, v in a_tree.tree_structure.items():
82+
current_node = a_tree[k]
83+
del current_node.depth
84+
if isinstance(current_node, LeafNode):
85+
del current_node.idx_data_points
86+
return a_tree
87+
88+
def _predict(self):
7889
output = np.zeros(self.num_observations)
7990
for node_index in self.idx_leaf_nodes:
8091
leaf_node = self.get_node(node_index)
81-
if excluded is None:
82-
output[leaf_node.idx_data_points] = leaf_node.value
83-
else:
84-
parent_node = leaf_node.get_idx_parent_node()
85-
if self.get_node(parent_node).idx_split_variable not in excluded:
86-
output[leaf_node.idx_data_points] = leaf_node.value
92+
output[leaf_node.idx_data_points] = leaf_node.value
8793

8894
return output.astype(aesara.config.floatX)
8995

90-
def predict_out_of_sample(self, X, excluded=None):
96+
def predict(self, X, excluded=None):
9197
"""
92-
Predict output of tree for an unobserved point x.
98+
Predict output of tree for an (un)observed point X.
9399
94100
Parameters
95101
----------

pymc_experimental/bart/utils.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@
1010
from scipy.stats import pearsonr
1111

1212

13-
def predict(idata, rng, X_new=None, size=None, excluded=None):
13+
def predict(idata, rng, X=None, size=None, excluded=None):
1414
"""
1515
Generate samples from the BART-posterior.
1616
1717
Parameters
1818
----------
19-
idata: InferenceData
19+
idata : InferenceData
2020
InferenceData containing a collection of BART_trees in sample_stats group
2121
rng: NumPy random generator
22-
X_new : array-like
23-
A new covariate matrix. Use it to obtain out-of-sample predictions
24-
size: int or tuple
22+
X : array-like
23+
A covariate matrix. Use the same matrix used to fit BART for in-sample predictions or a new one for
24+
out-of-sample predictions.
25+
size : int or tuple
2526
Number of samples.
26-
excluded: list
27+
excluded : list
2728
indexes of the variables to exclude when computing predictions
2829
"""
2930
bart_trees = idata.sample_stats.bart_trees
@@ -39,16 +40,10 @@ def predict(idata, rng, X_new=None, size=None, excluded=None):
3940

4041
idx = rng.randint(len(stacked_trees.trees), size=flatten_size)
4142

42-
if X_new is None:
43-
pred = np.zeros((flatten_size, stacked_trees[0, 0].item().num_observations))
44-
for ind, p in enumerate(pred):
45-
for tree in stacked_trees.isel(trees=idx[ind]).values:
46-
p += tree.predict_output(excluded=excluded)
47-
else:
48-
pred = np.zeros((flatten_size, X_new.shape[0]))
49-
for ind, p in enumerate(pred):
50-
for tree in stacked_trees.isel(trees=idx[ind]).values:
51-
p += np.array([tree.predict_out_of_sample(x, excluded) for x in X_new])
43+
pred = np.zeros((flatten_size, X.shape[0]))
44+
for ind, p in enumerate(pred):
45+
for tree in stacked_trees.isel(trees=idx[ind]).values:
46+
p += np.array([tree.predict(x, excluded) for x in X])
5247
return pred.reshape((*size, -1))
5348

5449

@@ -210,13 +205,13 @@ def plot_dependence(
210205
for x_i in new_X_i:
211206
new_X[:, indices_mi] = X[:, indices_mi]
212207
new_X[:, i] = x_i
213-
y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 1))
208+
y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 1))
214209
new_X_target.append(new_X_i)
215210
else:
216211
for instance in instances:
217212
new_X = X[idx_s]
218213
new_X[:, indices_mi] = X[:, indices_mi][instance]
219-
y_pred.append(np.mean(predict(idata, rng, X_new=new_X, size=samples), 0))
214+
y_pred.append(np.mean(predict(idata, rng, X=new_X, size=samples), 0))
220215
new_X_target.append(new_X[:, i])
221216
y_mins.append(np.min(y_pred))
222217
new_Y.append(np.array(y_pred).T)
@@ -302,16 +297,21 @@ def plot_dependence(
302297
return axes
303298

304299

305-
def plot_variable_importance(idata, labels=None, figsize=None, samples=100, random_seed=None):
300+
def plot_variable_importance(
301+
idata, X=None, labels=None, figsize=None, samples=100, random_seed=None
302+
):
306303
"""
307304
Estimates variable importance from the BART-posterior.
308305
309306
Parameters
310307
----------
311308
idata: InferenceData
312309
InferenceData containing a collection of BART_trees in sample_stats group
310+
X : array-like
311+
The covariate matrix.
313312
labels: list
314-
List of the names of the covariates.
313+
List of the names of the covariates. If X is a DataFrame the names of the covariates will
314+
be taken from it and this argument will be ignored.
315315
figsize : tuple
316316
Figure size. If None it will be defined automatically.
317317
samples : int
@@ -326,6 +326,10 @@ def plot_variable_importance(idata, labels=None, figsize=None, samples=100, rand
326326
rng = RandomState(seed=random_seed)
327327
_, axes = plt.subplots(2, 1, figsize=figsize)
328328

329+
if hasattr(X, "columns") and hasattr(X, "values"):
330+
labels = list(X.columns)
331+
X = X.values
332+
329333
VI = idata.sample_stats["variable_inclusion"].mean(("chain", "draw")).values
330334
if labels is None:
331335
labels = range(len(VI))
@@ -341,12 +345,12 @@ def plot_variable_importance(idata, labels=None, figsize=None, samples=100, rand
341345
axes[0].set_xlabel("variable index")
342346
axes[0].set_ylabel("relative importance")
343347

344-
predicted_all = predict(idata, rng, size=samples, excluded=None)
348+
predicted_all = predict(idata, rng, X=X, size=samples, excluded=None)
345349

346350
EV_mean = np.zeros(len(VI))
347351
EV_hdi = np.zeros((len(VI), 2))
348352
for idx, subset in enumerate(subsets):
349-
predicted_subset = predict(idata, rng, size=samples, excluded=subset)
353+
predicted_subset = predict(idata, rng, X=X, size=samples, excluded=subset)
350354
pearson = np.zeros(samples)
351355
for j in range(samples):
352356
pearson[j] = pearsonr(predicted_all[j], predicted_subset[j])[0]

pymc_experimental/tests/test_bart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ class TestUtils:
7777

7878
def test_predict(self):
7979
rng = RandomState(12345)
80-
pred_all = pmx.bart.utils.predict(self.idata, rng, size=2)
80+
pred_all = pmx.bart.utils.predict(self.idata, rng, X=self.X, size=2)
8181
rng = RandomState(12345)
82-
pred_first = pmx.bart.utils.predict(self.idata, rng, X_new=self.X[:10])
82+
pred_first = pmx.bart.utils.predict(self.idata, rng, X=self.X[:10])
8383

8484
assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
8585
assert pred_all.shape == (2, 50)
@@ -112,7 +112,7 @@ def test_pdp(self, kwargs):
112112
],
113113
)
114114
def test_vi(self, kwargs):
115-
pmx.bart.utils.plot_variable_importance(self.idata, **kwargs)
115+
pmx.bart.utils.plot_variable_importance(self.idata, X=self.X, **kwargs)
116116

117117
def test_pdp_pandas_labels(self):
118118
pd = pytest.importorskip("pandas")

0 commit comments

Comments
 (0)