Commit 0b315a5 (parent 8d79592)

Run pre-commit with new configuration


59 files changed (+395, -331 lines)

notebooks/SARMA Example.ipynb

Lines changed: 2 additions & 2 deletions

@@ -1554,7 +1554,7 @@
 "        hdi_forecast.coords[\"time\"].values,\n",
 "        *hdi_forecast.isel(observed_state=0).values.T,\n",
 "        alpha=0.25,\n",
-"        color=\"tab:blue\"\n",
+"        color=\"tab:blue\",\n",
 "    )\n",
 "ax.set_title(\"Porcupine Graph of 10-Period Forecasts (parameters estimated on all data)\")\n",
 "plt.show()"

@@ -2692,7 +2692,7 @@
 "    *forecast_hdi.values.T,\n",
 "    label=\"Forecast 94% HDI\",\n",
 "    color=\"tab:orange\",\n",
-"    alpha=0.25\n",
+"    alpha=0.25,\n",
 ")\n",
 "ax.legend()\n",
 "plt.show()"

notebooks/Structural Timeseries Modeling.ipynb

Lines changed: 3 additions & 3 deletions

@@ -1657,7 +1657,7 @@
 "        nile.index,\n",
 "        *component_hdi.smoothed_posterior.sel(state=state).values.T,\n",
 "        color=\"tab:blue\",\n",
-"        alpha=0.15\n",
+"        alpha=0.15,\n",
 "    )\n",
 "    axis.set_title(state.title())"
 ]

@@ -1706,7 +1706,7 @@
 "    *hdi.smoothed_posterior.sum(dim=\"state\").values.T,\n",
 "    color=\"tab:blue\",\n",
 "    alpha=0.15,\n",
-"    label=\"HDI 94%\"\n",
+"    label=\"HDI 94%\",\n",
 ")\n",
 "ax.legend()\n",
 "plt.show()"

@@ -2750,7 +2750,7 @@
 "ax.fill_between(\n",
 "    blossom_data.index,\n",
 "    *hdi_post.predicted_posterior_observed.isel(observed_state=0).values.T,\n",
-"    alpha=0.25\n",
+"    alpha=0.25,\n",
 ")\n",
 "blossom_data.plot(ax=ax)"
 ]

pymc_experimental/__init__.py

Lines changed: 14 additions & 4 deletions

@@ -13,6 +13,10 @@
 # limitations under the License.
 import logging
 
+from pymc_experimental import distributions, gp, statespace, utils
+from pymc_experimental.inference.fit import fit
+from pymc_experimental.model.marginal_model import MarginalModel
+from pymc_experimental.model.model_api import as_model
 from pymc_experimental.version import __version__
 
 _log = logging.getLogger("pmx")

@@ -23,7 +27,13 @@
         handler = logging.StreamHandler()
         _log.addHandler(handler)
 
-from pymc_experimental import distributions, gp, statespace, utils
-from pymc_experimental.inference.fit import fit
-from pymc_experimental.model.marginal_model import MarginalModel
-from pymc_experimental.model.model_api import as_model
+__all__ = [
+    "distributions",
+    "gp",
+    "statespace",
+    "utils",
+    "fit",
+    "MarginalModel",
+    "as_model",
+    "__version__",
+]
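Moving the imports above the logging setup and listing the re-exported names in __all__ is the usual way to mark intentional re-exports; unused-import lints (e.g. Ruff's F401) then leave them alone. A minimal sketch of the pattern, with a hypothetical stand-in symbol:

# some_package/__init__.py -- hypothetical illustration of the re-export pattern
from math import sqrt  # stands in for a symbol imported only to re-export it

__all__ = [
    "sqrt",  # declares the public API; `from some_package import *` exports only these
]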

pymc_experimental/distributions/continuous.py

Lines changed: 4 additions & 5 deletions

@@ -19,10 +19,9 @@
 The imports from pymc are not fully replicated here: add imports as necessary.
 """
 
-from typing import Tuple, Union
-
 import numpy as np
 import pytensor.tensor as pt
+
 from pymc import ChiSquared, CustomDist
 from pymc.distributions import transforms
 from pymc.distributions.dist_math import check_parameters

@@ -39,19 +38,19 @@ class GenExtremeRV(RandomVariable):
     name: str = "Generalized Extreme Value"
     signature = "(),(),()->()"
     dtype: str = "floatX"
-    _print_name: Tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
+    _print_name: tuple[str, str] = ("Generalized Extreme Value", "\\operatorname{GEV}")
 
     def __call__(self, mu=0.0, sigma=1.0, xi=0.0, size=None, **kwargs) -> TensorVariable:
         return super().__call__(mu, sigma, xi, size=size, **kwargs)
 
     @classmethod
     def rng_fn(
         cls,
-        rng: Union[np.random.RandomState, np.random.Generator],
+        rng: np.random.RandomState | np.random.Generator,
         mu: np.ndarray,
         sigma: np.ndarray,
         xi: np.ndarray,
-        size: Tuple[int, ...],
+        size: tuple[int, ...],
     ) -> np.ndarray:
         # Notice negative here, since remainder of GenExtreme is based on Coles parametrization
         return stats.genextreme.rvs(c=-xi, loc=mu, scale=sigma, random_state=rng, size=size)
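These replacements follow PEP 585 (builtin generics such as tuple[int, ...], Python 3.9+) and PEP 604 (X | Y unions, Python 3.10+), so typing.Tuple and typing.Union are no longer needed. A minimal sketch of both forms (draw is a hypothetical function, not from this module):

import numpy as np


def draw(
    rng: np.random.RandomState | np.random.Generator,  # PEP 604, was Union[...]
    size: tuple[int, ...],                              # PEP 585, was Tuple[int, ...]
) -> np.ndarray:
    # standard_normal exists on both RandomState and Generator
    return rng.standard_normal(size)


print(draw(np.random.default_rng(0), (2, 3)).shape)  # -> (2, 3)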

pymc_experimental/distributions/discrete.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 
 import numpy as np
 import pymc as pm
+
 from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
 from pymc.distributions.shape_utils import rv_size_is_none
 from pytensor import tensor as pt

pymc_experimental/distributions/histogram_utils.py

Lines changed: 6 additions & 7 deletions

@@ -13,18 +13,17 @@
 # limitations under the License.
 
 
-from typing import Dict
-
 import numpy as np
 import pymc as pm
+
 from numpy.typing import ArrayLike
 
 __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
 
 
 def quantile_histogram(
     data: ArrayLike, n_quantiles=1000, zero_inflation=False
-) -> Dict[str, ArrayLike]:
+) -> dict[str, ArrayLike]:
     try:
         import xhistogram.core
     except ImportError as e:

@@ -34,7 +33,7 @@ def quantile_histogram(
         import dask.dataframe
     except ImportError:
         dask = None
-    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
+    if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
         data = data.to_dask_array(lengths=True)
     if zero_inflation:
         zeros = (data == 0).sum(0)

@@ -67,7 +66,7 @@
     return result
 
 
-def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
+def discrete_histogram(data: ArrayLike, min_count=None) -> dict[str, ArrayLike]:
     try:
         import xhistogram.core
     except ImportError as e:

@@ -78,7 +77,7 @@ def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
     except ImportError:
         dask = None
 
-    if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
+    if dask and isinstance(data, dask.dataframe.Series | dask.dataframe.DataFrame):
         data = data.to_dask_array(lengths=True)
     mid, count_uniq = np.unique(data, return_counts=True)
     if min_count is not None:

@@ -153,7 +152,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
         import dask.dataframe
     except ImportError:
         dask = None
-    if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)):
+    if dask and isinstance(observed, dask.dataframe.Series | dask.dataframe.DataFrame):
         observed = observed.to_dask_array(lengths=True)
     if np.issubdtype(observed.dtype, np.integer):
         histogram = discrete_histogram(observed, **h_kwargs)
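Since Python 3.10, isinstance also accepts a PEP 604 union object directly, equivalent to passing a tuple of types; that is what the rewritten dask checks rely on. A minimal standalone sketch:

value = 3.5

# the tuple form and the union form are interchangeable on Python 3.10+:
assert isinstance(value, (int, float))
assert isinstance(value, int | float)
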
pymc_experimental/distributions/multivariate/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -1 +1,3 @@
 from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
+
+__all__ = ["R2D2M2CP"]

pymc_experimental/distributions/multivariate/r2d2m2cp.py

Lines changed: 21 additions & 21 deletions

@@ -14,7 +14,7 @@
 
 
 from collections import namedtuple
-from typing import Sequence, Tuple, Union
+from collections.abc import Sequence
 
 import numpy as np
 import pymc as pm

@@ -26,8 +26,8 @@
 def _psivar2musigma(
     psi: pt.TensorVariable,
     explained_var: pt.TensorVariable,
-    psi_mask: Union[pt.TensorLike, None],
-) -> Tuple[pt.TensorVariable, pt.TensorVariable]:
+    psi_mask: pt.TensorLike | None,
+) -> tuple[pt.TensorVariable, pt.TensorVariable]:
     sign = pt.sign(psi - 0.5)
     if psi_mask is not None:
         # any computation might be ignored for ~psi_mask

@@ -55,7 +55,7 @@ def _R2D2M2CP_beta(
     psi: pt.TensorVariable,
     *,
     psi_mask,
-    dims: Union[str, Sequence[str]],
+    dims: str | Sequence[str],
     centered=False,
 ) -> pt.TensorVariable:
     """R2D2M2CP beta prior.

@@ -120,7 +120,7 @@ def _R2D2M2CP_beta(
 def _broadcast_as_dims(
     *values: np.ndarray,
     dims: Sequence[str],
-) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
+) -> tuple[np.ndarray, ...] | np.ndarray:
     model = pm.modelcontext(None)
     shape = [len(model.coords[d]) for d in dims]
     ret = tuple(np.broadcast_to(v, shape) for v in values)

@@ -135,7 +135,7 @@ def _psi_masked(
     positive_probs_std: pt.TensorLike,
     *,
     dims: Sequence[str],
-) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
+) -> tuple[pt.TensorLike | None, pt.TensorVariable]:
     if not (
         isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
     ):

@@ -172,10 +172,10 @@ def _psi_masked(
 
 def _psi(
     positive_probs: pt.TensorLike,
-    positive_probs_std: Union[pt.TensorLike, None],
+    positive_probs_std: pt.TensorLike | None,
     *,
     dims: Sequence[str],
-) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
+) -> tuple[pt.TensorLike | None, pt.TensorVariable]:
     if positive_probs_std is not None:
         mask, psi = _psi_masked(
             positive_probs=pt.as_tensor(positive_probs),

@@ -194,9 +194,9 @@ def _psi(
 
 
 def _phi(
-    variables_importance: Union[pt.TensorLike, None],
-    variance_explained: Union[pt.TensorLike, None],
-    importance_concentration: Union[pt.TensorLike, None],
+    variables_importance: pt.TensorLike | None,
+    variance_explained: pt.TensorLike | None,
+    importance_concentration: pt.TensorLike | None,
     *,
     dims: Sequence[str],
 ) -> pt.TensorVariable:

@@ -210,15 +210,15 @@ def _phi(
         variables_importance = pt.as_tensor(variables_importance)
         if importance_concentration is not None:
             variables_importance *= importance_concentration
-        return pm.Dirichlet("phi", variables_importance, dims=broadcast_dims + [dim])
+        return pm.Dirichlet("phi", variables_importance, dims=[*broadcast_dims, dim])
     elif variance_explained is not None:
         if len(model.coords[dim]) <= 1:
             raise TypeError("Can't use variance explained with less than two variables")
         phi = pt.as_tensor(variance_explained)
     else:
         phi = _broadcast_as_dims(1.0, dims=dims)
     if importance_concentration is not None:
-        return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim])
+        return pm.Dirichlet("phi", importance_concentration * phi, dims=[*broadcast_dims, dim])
     else:
         return phi
 

@@ -233,12 +233,12 @@ def R2D2M2CP(
     *,
     dims: Sequence[str],
     r2: pt.TensorLike,
-    variables_importance: Union[pt.TensorLike, None] = None,
-    variance_explained: Union[pt.TensorLike, None] = None,
-    importance_concentration: Union[pt.TensorLike, None] = None,
-    r2_std: Union[pt.TensorLike, None] = None,
-    positive_probs: Union[pt.TensorLike, None] = 0.5,
-    positive_probs_std: Union[pt.TensorLike, None] = None,
+    variables_importance: pt.TensorLike | None = None,
+    variance_explained: pt.TensorLike | None = None,
+    importance_concentration: pt.TensorLike | None = None,
+    r2_std: pt.TensorLike | None = None,
+    positive_probs: pt.TensorLike | None = 0.5,
+    positive_probs_std: pt.TensorLike | None = None,
     centered: bool = False,
 ) -> R2D2M2CPOut:
     """R2D2M2CP Prior.

@@ -413,7 +413,7 @@ def R2D2M2CP(
         year = {2023}
     }
     """
-    if not isinstance(dims, (list, tuple)):
+    if not isinstance(dims, list | tuple):
         dims = (dims,)
     *broadcast_dims, dim = dims
     input_sigma = pt.as_tensor(input_sigma)

@@ -438,7 +438,7 @@
         r2,
         phi,
         psi,
-        dims=broadcast_dims + [dim],
+        dims=[*broadcast_dims, dim],
         centered=centered,
         psi_mask=mask,
     )
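The broadcast_dims + [dim] to [*broadcast_dims, dim] rewrites swap list concatenation for iterable unpacking (likely what Ruff's RUF005 rule suggests); the result is identical, and unpacking also accepts any iterable prefix, not just a list. A minimal sketch:

broadcast_dims = ["chain", "draw"]
dim = "feature"

assert broadcast_dims + [dim] == [*broadcast_dims, dim]

# unpacking also works when the prefix is a tuple or other iterable:
assert [*("chain", "draw"), dim] == ["chain", "draw", "feature"]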

pymc_experimental/distributions/timeseries.py

Lines changed: 3 additions & 3 deletions

@@ -1,10 +1,10 @@
 import warnings
-from typing import List, Union
 
 import numpy as np
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt
+
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
     Distribution,

@@ -26,7 +26,7 @@
 from pytensor.tensor.random.op import RandomVariable
 
 
-def _make_outputs_info(n_lags: int, init_dist: Distribution) -> List[Union[Distribution, dict]]:
+def _make_outputs_info(n_lags: int, init_dist: Distribution) -> list[Distribution | dict]:
     """
     Two cases are needed for outputs_info in the scans used by DiscreteMarkovRv. If n_lags = 1, we need to throw away
     the first dimension of init_dist_ or else markov_chain will have shape (steps, 1, *batch_size) instead of

@@ -142,7 +142,7 @@ def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwar
 
         if init_dist is not None:
             if not isinstance(init_dist, TensorVariable) or not isinstance(
-                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+                init_dist.owner.op, RandomVariable | SymbolicRandomVariable
             ):
                 raise ValueError(
                     f"Init dist must be a distribution created via the `.dist()` API, "

pymc_experimental/gp/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -14,3 +14,5 @@
 
 
 from pymc_experimental.gp.latent_approx import KarhunenLoeveExpansion, ProjectedProcess
+
+__all__ = ["KarhunenLoeveExpansion", "ProjectedProcess"]

pymc_experimental/gp/latent_approx.py

Lines changed: 8 additions & 6 deletions

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Optional
 
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
+
 from pymc.gp.util import JITTER_DEFAULT, stabilize
 from pytensor.tensor.linalg import cholesky, solve_triangular
 

@@ -33,7 +33,7 @@ class ProjectedProcess(pm.gp.Latent):
     ## AKA: DTC
     def __init__(
         self,
-        n_inducing: Optional[int] = None,
+        n_inducing: int | None = None,
         *,
         mean_func=pm.gp.mean.Zero(),
         cov_func=pm.gp.cov.Constant(0.0),

@@ -59,20 +59,22 @@ def prior(
         self,
         name: str,
         X: np.ndarray,
-        X_inducing: Optional[np.ndarray] = None,
+        X_inducing: np.ndarray | None = None,
         jitter: float = JITTER_DEFAULT,
         **kwargs,
     ) -> np.ndarray:
         """
         Builds the GP prior with optional inducing points locations.
 
-        Parameters:
+        Parameters
+        ----------
         - name: Name for the GP variable.
         - X: Input data.
         - X_inducing: Optional. Inducing points for the GP.
         - jitter: Jitter to ensure numerical stability.
 
-        Returns:
+        Returns
+        -------
         - GP function
         """
         # Check if X is a numpy array

@@ -137,7 +139,7 @@ def __init__(
         super().__init__(mean_func=mean_func, cov_func=cov_func)
 
     def _build_prior(self, name, X, jitter=1e-6, **kwargs):
-        mu = self.mean_func(X)
+        mu = self.mean_func(X)  # noqa: F841
         Kxx = pm.gp.util.stabilize(self.cov_func(X), jitter)
         vals, vecs = pt.linalg.eigh(Kxx)
         ## NOTE: REMOVED PRECISION CUTOFF
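Two conventions appear in this last file: numpydoc section headers are underlined with dashes so documentation tooling can parse them, and # noqa: F841 suppresses the unused-variable lint on a single line. A hypothetical sketch of both (add_jitter is illustrative, not from this module):

def add_jitter(x: float, jitter: float = 1e-6) -> float:
    """Add a small jitter to a value.

    Parameters
    ----------
    x : float
        Input value.
    jitter : float
        Small constant added for numerical stability.

    Returns
    -------
    float
        The stabilized value.
    """
    scratch = x * 0.0  # noqa: F841 -- deliberately unused; the comment silences F841
    return x + jitter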
