Track model logp during sampling #3121
Conversation
Hm, this caused a lot more tests to fail than I was expecting; I'm actually kind of bewildered. Will take another stab sometime soon.
Hi @eigenfoo, thanks for taking the time to work on this. I am working on a new version of SMC (a pretty big refactoring), and hence I suggest not spending time making changes to SMC for now.
In that case, I would suggest not logging the logp during sampling, but computing it post-sampling. This should also break much less code.
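For concreteness, a minimal sketch of the post-sampling approach (the toy model and names are illustrative, not from this thread; `model.logp` is the compiled joint log-probability function and `trace.point(i)` returns the values of draw `i`):

```python
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sd=1.0)
    pm.Normal("obs", mu=mu, sd=1.0, observed=np.random.randn(100))
    trace = pm.sample(1000)

# Evaluate the joint log-probability once per posterior draw, after sampling.
logp = model.logp
model_logps = np.array([logp(trace.point(i)) for i in range(len(trace))])
```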
Can I make a different suggestion? I still think it's more elegant to add the model likelihood as a Deterministic. Instead, could we not add some logic to do this automatically at sampling time? @junpenglao @aloctavodia does this sound reasonable to you?
In other words, something along the lines of the PR above ^
It looks like it passes most tests, and the failed tests look like easy fixes. I'll invest more time into that PR if you guys think this is a good idea.
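A minimal sketch of the Deterministic idea (the toy model is illustrative, not the PR's code; `model.logpt` is the Theano tensor for the model's joint log-probability):

```python
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sd=1.0)
    pm.Normal("obs", mu=mu, sd=1.0, observed=np.random.randn(100))
    # Wrap the joint log-probability tensor so it is recorded with every draw.
    pm.Deterministic("model_logp", model.logpt)
    trace = pm.sample(1000)

trace["model_logp"]  # model logp at each posterior sample
```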
Hmm, if you are going to log it during sampling, treating it as a Deterministic might be suboptimal, as you are computing the logp twice: once in the sampler for the MC update, and once in the Deterministic. If the logp is expensive to compute, the overhead could be significant. If we were to do this properly, it is probably better to treat it as a sampler stat and log it there. I imagine some difficulty for compound steps, though. Thoughts?
@junpenglao hm, I see how this might be inefficient. I agree that adding the model logp to the trace while sampling (as a sampling stat) would be better. I guess I don't have a good enough understanding of PyMC3 internals to understand:
In the meantime, I'll close the previous PR.
Late to the discussion, but I think being able to access the model logp from the trace would be very useful. It sounds like there's a good different way forward, but I wanted to mention this too!
A good place to start looking at sampler statistics is how they are implemented in Hamiltonian Monte Carlo. You can then trace that back through the step method base classes and the trace backends.
It is a bunch of steps, but trying to add a sampler stat to NUTS or HMC will throw errors until it doesn't, and then it will make more sense 😄.
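As a rough illustration of that plumbing (a stripped-down sketch, not PyMC3's actual HMC code; the class and stat names are made up): a step method advertises its statistics through `generates_stats` and `stats_dtypes`, and returns a list of stat dicts alongside the new point, which the trace backend stores using the declared dtypes.

```python
import numpy as np
from pymc3.step_methods.arraystep import ArrayStep

class LogpTrackingStep(ArrayStep):
    """Toy step method that records the model logp as a sampler statistic."""

    # The trace backends look at these attributes to allocate storage and to
    # expose the values later via trace.get_sampler_stats("model_logp").
    generates_stats = True
    stats_dtypes = [{"model_logp": np.float64}]

    def astep(self, q0, logp):
        q = q0  # (proposal and accept/reject logic would go here)
        stats = {"model_logp": logp(q)}
        # Step methods that generate stats return (new_point, [stats_dict]).
        return q, [stats]
```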
CompoundStep might be challenging because logp is evaluated multiple times in each Gibbs step, but we only want the last one, where logp is evaluated on the final accepted point (not the intermediate ones):

```python
for method in self.methods:
    point, state = method.step(point)
    states.extend(state)
# Keep only the logp reported by the last step method, i.e. the one
# evaluated at the final accepted point.
states[-1]['model_logp'] = states[-1]['logp']
```
I see, thanks for all the help! I'll read through and whip up another PR. In the meantime, I'll close this PR. I also just realized that computing the model logp post-sampling is similarly inefficient: the logp would be computed for the MC update, and then recomputed again after sampling is finished. So it looks like tracking it as a sampler stat is the only logical choice here! Unfortunately I'm a bit tied up this week, but I'll have plenty of time to tinker with this on my upcoming (interminable) flight 😄
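Assuming the stat ends up being called `model_logp` (the name is hypothetical at this point), the user-facing access would then be something like:

```python
# Per-draw model log-probability recorded by the step method during sampling.
model_logps = trace.get_sampler_stats("model_logp")
```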
Closes #2971.
This PR adds the model's likelihood to the tracked variables during sampling. The guys over at Stan explain why this is desirable.
Currently, the SMC sampler adds the model logp manually, but it appears that none of the other samplers do. I moved the logic into the BaseTrace class, so that we don't need to repeat this code in all our samplers.
Still a WIP, will probably need a lot more work.