10
10
from scipy .stats import pearsonr
11
11
12
12
13
- def predict (idata , rng , X_new = None , size = None , excluded = None ):
13
+ def predict (idata , rng , X = None , size = None , excluded = None ):
14
14
"""
15
15
Generate samples from the BART-posterior.
16
16
17
17
Parameters
18
18
----------
19
- idata: InferenceData
19
+ idata : InferenceData
20
20
InferenceData containing a collection of BART_trees in sample_stats group
21
21
rng: NumPy random generator
22
- X_new : array-like
23
- A new covariate matrix. Use it to obtain out-of-sample predictions
24
- size: int or tuple
22
+ X : array-like
23
+ A covariate matrix. Use the same X used to fit BART for in-sample predictions, or a new one for
24
+ out-of-sample predictions.
25
+ size : int or tuple
25
26
Number of samples.
26
- excluded: list
27
+ excluded : list
27
28
indexes of the variables to exclude when computing predictions
28
29
"""
29
30
bart_trees = idata .sample_stats .bart_trees
@@ -39,16 +40,10 @@ def predict(idata, rng, X_new=None, size=None, excluded=None):
39
40
40
41
idx = rng .randint (len (stacked_trees .trees ), size = flatten_size )
41
42
42
- if X_new is None :
43
- pred = np .zeros ((flatten_size , stacked_trees [0 , 0 ].item ().num_observations ))
44
- for ind , p in enumerate (pred ):
45
- for tree in stacked_trees .isel (trees = idx [ind ]).values :
46
- p += tree .predict_output (excluded = excluded )
47
- else :
48
- pred = np .zeros ((flatten_size , X_new .shape [0 ]))
49
- for ind , p in enumerate (pred ):
50
- for tree in stacked_trees .isel (trees = idx [ind ]).values :
51
- p += np .array ([tree .predict_out_of_sample (x , excluded ) for x in X_new ])
43
+ pred = np .zeros ((flatten_size , X .shape [0 ]))
44
+ for ind , p in enumerate (pred ):
45
+ for tree in stacked_trees .isel (trees = idx [ind ]).values :
46
+ p += np .array ([tree .predict (x , excluded ) for x in X ])
52
47
return pred .reshape ((* size , - 1 ))
53
48
54
49
@@ -210,13 +205,13 @@ def plot_dependence(
210
205
for x_i in new_X_i :
211
206
new_X [:, indices_mi ] = X [:, indices_mi ]
212
207
new_X [:, i ] = x_i
213
- y_pred .append (np .mean (predict (idata , rng , X_new = new_X , size = samples ), 1 ))
208
+ y_pred .append (np .mean (predict (idata , rng , X = new_X , size = samples ), 1 ))
214
209
new_X_target .append (new_X_i )
215
210
else :
216
211
for instance in instances :
217
212
new_X = X [idx_s ]
218
213
new_X [:, indices_mi ] = X [:, indices_mi ][instance ]
219
- y_pred .append (np .mean (predict (idata , rng , X_new = new_X , size = samples ), 0 ))
214
+ y_pred .append (np .mean (predict (idata , rng , X = new_X , size = samples ), 0 ))
220
215
new_X_target .append (new_X [:, i ])
221
216
y_mins .append (np .min (y_pred ))
222
217
new_Y .append (np .array (y_pred ).T )
@@ -302,16 +297,21 @@ def plot_dependence(
302
297
return axes
303
298
304
299
305
- def plot_variable_importance (idata , labels = None , figsize = None , samples = 100 , random_seed = None ):
300
+ def plot_variable_importance (
301
+ idata , X = None , labels = None , figsize = None , samples = 100 , random_seed = None
302
+ ):
306
303
"""
307
304
Estimates variable importance from the BART-posterior.
308
305
309
306
Parameters
310
307
----------
311
308
idata: InferenceData
312
309
InferenceData containing a collection of BART_trees in sample_stats group
310
+ X : array-like
311
+ The covariate matrix.
313
312
labels: list
314
- List of the names of the covariates.
313
+ List of the names of the covariates. If X is a DataFrame the names of the covariates will
314
+ be taken from it and this argument will be ignored.
315
315
figsize : tuple
316
316
Figure size. If None it will be defined automatically.
317
317
samples : int
@@ -326,6 +326,10 @@ def plot_variable_importance(idata, labels=None, figsize=None, samples=100, rand
326
326
rng = RandomState (seed = random_seed )
327
327
_ , axes = plt .subplots (2 , 1 , figsize = figsize )
328
328
329
+ if hasattr (X , "columns" ) and hasattr (X , "values" ):
330
+ labels = list (X .columns )
331
+ X = X .values
332
+
329
333
VI = idata .sample_stats ["variable_inclusion" ].mean (("chain" , "draw" )).values
330
334
if labels is None :
331
335
labels = range (len (VI ))
@@ -341,12 +345,12 @@ def plot_variable_importance(idata, labels=None, figsize=None, samples=100, rand
341
345
axes [0 ].set_xlabel ("variable index" )
342
346
axes [0 ].set_ylabel ("relative importance" )
343
347
344
- predicted_all = predict (idata , rng , size = samples , excluded = None )
348
+ predicted_all = predict (idata , rng , X = X , size = samples , excluded = None )
345
349
346
350
EV_mean = np .zeros (len (VI ))
347
351
EV_hdi = np .zeros ((len (VI ), 2 ))
348
352
for idx , subset in enumerate (subsets ):
349
- predicted_subset = predict (idata , rng , size = samples , excluded = subset )
353
+ predicted_subset = predict (idata , rng , X = X , size = samples , excluded = subset )
350
354
pearson = np .zeros (samples )
351
355
for j in range (samples ):
352
356
pearson [j ] = pearsonr (predicted_all [j ], predicted_subset [j ])[0 ]
0 commit comments