diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 79562f9ed..cc1fa4c08 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -235,7 +235,6 @@ def fit( progressbar: bool = True, random_seed: RandomState = None, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, - *args: Any, **kwargs: Any, ) -> az.InferenceData: """ @@ -269,6 +268,10 @@ def fit( if self.sampler_config is None: self.sampler_config = sampler_config self.build_model(self.model_data, self.model_config) + + sampler_config["progressbar"] = progressbar + sampler_config["random_seed"] = random_seed + with self.model: self.idata = pm.sample(**self.sampler_config, **kwargs) self.idata.extend(pm.sample_prior_predictive())