Skip to content

Commit 630e970

Browse files
committed
WIP
Signed-off-by: Adam Li <[email protected]>
1 parent 1938c19 commit 630e970

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

doc/references.bib

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,28 @@ @inproceedings{Park2021conditional
7070
organization = {PMLR}
7171
}
7272

73+
@article{peters2016causal,
74+
title={Causal inference by using invariant prediction: identification and confidence intervals},
75+
author={Peters, Jonas and B{\"u}hlmann, Peter and Meinshausen, Nicolai},
76+
journal={Journal of the Royal Statistical Society Series B: Statistical Methodology},
77+
volume={78},
78+
number={5},
79+
pages={947--1012},
80+
year={2016},
81+
publisher={Oxford University Press}
82+
}
83+
84+
@article{shah2018goodness,
85+
title={Goodness-of-fit tests for high dimensional linear models},
86+
author={Shah, Rajen D and B{\"u}hlmann, Peter},
87+
journal={Journal of the Royal Statistical Society Series B: Statistical Methodology},
88+
volume={80},
89+
number={1},
90+
pages={113--135},
91+
year={2018},
92+
publisher={Oxford University Press}
93+
}
94+
7395
@inproceedings{Runge2018cmi,
7496
title = {Conditional independence testing based on a nearest-neighbor estimator of conditional mutual information},
7597
author = {Runge, Jakob},

pywhy_stats/residual.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""KSample testing for equality of (conditional) distributions.
2+
3+
If the distributions are marginal distributions being compared, then
4+
the test is a standard two-sample test, where the KS statistic, or
5+
Mann-Whitney U statistic, is used to test for equality of distributions.
6+
7+
If the distributions are conditional distributions being compared, then
8+
the test is a conditional two-sample test, where the KS statistic, or
9+
Mann-Whitney U statistic, is used to test for equality of the
10+
residual distributions, where the residuals are computed by regressing
11+
the target variable, Y, on the conditioning variable, X.
12+
13+
The test statistic is described fully in :footcite:`peters2016causal`
14+
and :footcite:`shah2018goodness`.
15+
"""
16+
17+
import numpy as np
18+
from scipy.stats import kstest
19+
20+
from .pvalue_result import PValueResult
21+
22+
23+
def ksample(Y, Z):
    """Test equality of two marginal distributions with the KS test.

    Parameters
    ----------
    Y : ndarray, shape (n_samples,)
        Target or outcome samples pooled over both groups.
    Z : list or ndarray, shape (n_samples,)
        Binary indicator (zeros and ones) of group membership.

    Returns
    -------
    result : PValueResult
        The KS statistic and its p-value.
    """
    # Build an explicit boolean mask: with a plain-list ``Z`` the original
    # ``Y[Z == 1]`` evaluated ``Z == 1`` to the scalar ``False`` and silently
    # indexed ``Y[0]`` instead of selecting the group.
    mask = np.asarray(Z, dtype=bool)
    stat, pval = kstest(Y[mask], Y[~mask])

    return PValueResult(pvalue=pval, statistic=stat)
27+
28+
29+
def condksample(Y, Z, X, residual_test="ks", target_predictor=None, combine_pvalues=True):
    r"""
    Calculate the conditional 2-sample test statistic.

    The target ``Y`` is regressed on the covariates ``X``, and the residual
    distributions of the two groups indicated by ``Z`` are compared. The
    test statistic is described fully in :footcite:`peters2016causal`
    and :footcite:`shah2018goodness`.

    Parameters
    ----------
    Y : ndarray, shape (n_samples,)
        Target or outcome features
    Z : list or ndarray, shape (n_samples,)
        List of zeros and ones indicating which samples belong to
        which groups.
    X : ndarray, shape (n_samples, n_features)
        Features to condition on
    residual_test : {"whitney_levene", "ks"}, default="ks"
        Test of the residuals between the groups
    target_predictor : sklearn.BaseEstimator, default=None
        Method to predict the target given the covariates. If None,
        uses a cross-validated spline regression of degree 3 as
        described in :footcite:`peters2016causal`.
    combine_pvalues : bool, default=True
        If True, returns the minimum of the corrected p-values; only
        relevant for ``residual_test="whitney_levene"``, which runs two
        tests. If False, a tuple of both corrected p-values is returned.

    Returns
    -------
    result : PValueResult
        ``statistic`` is the :math:`R^2` score of the regression fit,
        ``pvalue`` is the computed *k*-sample p-value (a tuple of p-values
        when ``combine_pvalues=False`` with ``"whitney_levene"``), and the
        fitted regression model is stored in
        ``additional_info["target_predictor"]``.

    Raises
    ------
    ValueError
        If ``residual_test`` is not one of the supported options.
    """
    from sklearn.metrics import r2_score

    # Validate before the (potentially expensive) regression fit.
    if residual_test not in ("whitney_levene", "ks"):
        raise ValueError(f"Test {residual_test} not a valid option.")

    if target_predictor is None:
        from sklearn.linear_model import LinearRegression
        from sklearn.model_selection import GridSearchCV
        from sklearn.pipeline import Pipeline
        from sklearn.preprocessing import SplineTransformer

        # Degree-3 spline regression; the number of knots is selected by
        # cross-validation, following :footcite:`peters2016causal`.
        pipe = Pipeline(
            steps=[
                ("spline", SplineTransformer(n_knots=4, degree=3)),
                ("linear", LinearRegression()),
            ]
        )
        param_grid = {
            "spline__n_knots": [3, 5, 7, 9],
        }
        target_predictor = GridSearchCV(
            pipe, param_grid, n_jobs=-2, refit=True, scoring="neg_mean_squared_error"
        )

    # Regress Y on X; the test is performed on the residuals.
    target_predictor.fit(X, Y)
    Y_pred = target_predictor.predict(X)
    residuals = Y - Y_pred
    r2 = r2_score(Y, Y_pred)

    # Compute the group masks once. The original ``1 - Z`` raised a
    # TypeError when Z was a plain list (as the docstring permits), so we
    # convert to a boolean ndarray first and negate with ``~``.
    group_mask = np.asarray(Z, dtype=bool)
    res_one = residuals[group_mask]
    res_zero = residuals[~group_mask]

    if residual_test == "whitney_levene":
        from scipy.stats import levene, mannwhitneyu

        # Location shift (Mann-Whitney U) and scale shift (Levene) tests.
        _, mean_pval = mannwhitneyu(res_one, res_zero)
        _, var_pval = levene(res_one, res_zero)
        # Bonferroni correction for running two tests.
        if combine_pvalues:
            pval = min(mean_pval * 2, var_pval * 2, 1)
        else:
            pval = (min(mean_pval * 2, 1), min(var_pval * 2, 1))
    elif residual_test == "ks":
        from scipy.stats import kstest

        _, pval = kstest(res_one, res_zero)

    return PValueResult(
        statistic=r2, pvalue=pval, additional_info={"target_predictor": target_predictor}
    )

0 commit comments

Comments
 (0)