
ModelBuilder's predict_posterior returns draws from just one chain? #139

Closed
mbjoseph opened this issue Apr 12, 2023 · 3 comments

@mbjoseph
Contributor

I've been experimenting with ModelBuilder to get a cleaner API for fitting/saving/loading/predicting, but I'm a bit confused by the shape of the predict_posterior() output.

I would expect to get samples from the posterior predictive distribution for each chain and draw, but I'm getting what appear to be values from just one chain.

Here's a reproducible example:

import numpy as np
import pandas as pd
from pymc_experimental.tests.test_model_builder import test_ModelBuilder

model = test_ModelBuilder.initial_build_and_fit()
print(model.idata.posterior_predictive["y_model"].shape) # (3, 1000, 100)


x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred = model.predict_posterior(prediction_data)

print(pred["y_model"].shape) # (1000, 100), but I expect (3, 1000, 100)

I would expect sampling from the posterior predictive distribution with prediction_data to yield an array of shape (chains, draws, observations), i.e. one prediction per row of prediction_data for every chain and draw. In this case, I'd like posterior predictive samples of shape (3, 1000, 100) rather than (1000, 100).
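To make the expected shape concrete, here is a minimal NumPy sketch. It is not ModelBuilder's code: it assumes a hypothetical linear model with parameters a, b and sigma whose posterior draws each have shape (chains, draws), and shows how every (chain, draw) pair produces one prediction per new input, which gives a (chains, draws, observations) array.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior draws for a simple linear model y ~ Normal(a + b * x, sigma);
# each parameter array has shape (chains, draws).
n_chains, n_draws, n_obs = 3, 1000, 100
a = rng.normal(0.0, 1.0, size=(n_chains, n_draws))
b = rng.normal(2.0, 0.5, size=(n_chains, n_draws))
sigma = rng.uniform(0.1, 0.5, size=(n_chains, n_draws))

# New inputs, analogous to x_pred in the example above.
x_pred = rng.uniform(low=0, high=1, size=n_obs)

# Broadcasting (chains, draws, 1) against (n_obs,) gives (chains, draws, n_obs):
# one predictive draw per observation for every chain/draw pair.
mu = a[..., None] + b[..., None] * x_pred
y_pred = rng.normal(mu, sigma[..., None])

print(y_pred.shape)  # (3, 1000, 100)
```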

Perhaps this is due to indexing with [0] here, which would appear to select values from just the first chain? https://github.com/pymc-devs/pymc-experimental/blob/5f1c2bbcdd3aceea4a53bbe2db509d8e88e7595d/pymc_experimental/model_builder.py#L350
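Whether the linked line indexes the result exactly this way is the hypothesis raised above; the sketch below only illustrates the NumPy semantics involved, using a stand-in array with assumed dims (chain, draw, observation). Indexing the leading axis with [0] returns the first chain alone, which matches the (1000, 100) shape reported in the example.

```python
import numpy as np

# Stand-in for posterior predictive values with dims (chain, draw, observation);
# the values are arbitrary, only the shape matters here.
n_chains, n_draws, n_obs = 3, 1000, 100
pp = np.arange(n_chains * n_draws * n_obs, dtype=float).reshape(n_chains, n_draws, n_obs)

# Indexing the leading (chain) axis with [0] keeps only the first chain
# and silently drops the other two.
first_chain = pp[0]
print(first_chain.shape)  # (1000, 100)

# The draws from chain 1 (and chain 2) are not part of the result.
print(np.array_equal(first_chain, pp[1]))  # False
```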

Are my expectations not aligned with the intended behavior? Happy to have missed something. Thanks! 😃

@twiecki
Member

twiecki commented Apr 12, 2023

Ah, that looks like a bug; it should use az.extract(). Want to do a PR to fix it?
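As background on the suggestion: by default az.extract() (for example az.extract(idata, group="posterior_predictive", var_names="y_model")) combines the chain and draw dimensions into a single stacked "sample" dimension, so every chain contributes draws instead of only the first. The sketch below reproduces that flattening in plain NumPy on a hypothetical (chain, draw, observation) array rather than calling ArviZ, so it illustrates the reshaping only. ArviZ places the stacked dimension last, so the real call would give (observation, sample) rather than the (sample, observation) layout used here, and which layout the fix in this thread settled on isn't stated.

```python
import numpy as np

# Hypothetical posterior predictive array with dims (chain, draw, observation).
n_chains, n_draws, n_obs = 3, 1000, 100
rng = np.random.default_rng(1)
pp = rng.normal(size=(n_chains, n_draws, n_obs))

# Merge chain and draw into one sample axis, chain-major, mirroring how
# az.extract(..., combined=True) stacks ("chain", "draw").
stacked = pp.reshape(n_chains * n_draws, n_obs)
print(stacked.shape)  # (3000, 100)

# Sample k corresponds to chain k // n_draws and draw k % n_draws,
# so all three chains are represented.
k = 1500
assert np.array_equal(stacked[k], pp[k // n_draws, k % n_draws])

# Keeping the chain axis instead gives the (3, 1000, 100) shape the issue expects.
print(pp.shape)  # (3, 1000, 100)
```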

@mbjoseph
Contributor Author

Sure - I'll take a stab at a PR 👍🏼

@mbjoseph
Contributor Author

Fixed with #140

Labels: none · Projects: none · Development: no branches or pull requests · 2 participants