Track model logp during sampling #3121


Closed
wants to merge 3 commits

Conversation

eigenfoo
Member

Closes #2971.

This PR adds the model's log-probability (logp) to the tracked variables during sampling. The Stan developers explain why this is desirable.

Currently, the SMC sampler adds the model logp manually, but it appears that none of the other samplers do. I moved the logic into the BaseTrace class, so that we don't need to repeat this code in all our samplers.

Still a WIP, will probably need a lot more work.

@eigenfoo
Member Author

Hm, this caused a lot more tests to fail than I was expecting; I'm actually kind of bewildered. Will take another stab sometime soon.

@junpenglao junpenglao requested a review from aloctavodia July 28, 2018 06:44
@aloctavodia
Member

Hi @eigenfoo thanks for taking the time to work on this.

I am working on a new version of SMC (a pretty big refactoring), so I suggest not spending time on changes to smc.py, since (I hope) that code will change soon. It would probably be a better idea to focus on the diagnostics for the model logp. What do you think @junpenglao?

@junpenglao
Member

In that case, I would suggest not logging the logp during sampling, but computing it post-sampling. This should also break much less code.
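
Post-sampling, that could be as simple as this rough sketch (assuming a PyMC3 model in scope, where model.logp is the compiled joint log-probability function over point dicts, and trace is the MultiTrace returned by pm.sample):

import numpy as np

logp_fn = model.logp  # compiled function: point dict -> joint log-probability
# evaluate the logp at every draw in the (default) chain
model_logp = np.array([logp_fn(trace.point(i)) for i in range(len(trace))])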

@eigenfoo
Member Author

Can I make a different suggestion? I still think it's more elegant to add the model logp as a pm.Deterministic, and keep track of it in the trace during sampling. After all, it's no different from any other deterministic variable, and computing it post-sampling sounds like a quick fix that incurs technical debt.

Instead, could we add some logic to pm.sample? Right before we sample, we check whether the model already includes a model logp variable; if not, we add it. We could even add a compute_model_logp flag (defaulting to True) so that the user can opt out of tracking this variable if necessary.
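
Concretely, something like this minimal sketch (model.logpt is the model's joint log-probability tensor in PyMC3; the toy model and the name 'model_logp' are just illustrative):

import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0., sd=1.)
    obs = pm.Normal('obs', mu=mu, sd=1., observed=[0.1, -0.3, 0.2])
    # track the joint logp like any other deterministic variable
    pm.Deterministic('model_logp', model.logpt)
    trace = pm.sample(1000)

trace['model_logp']  # the model logp at each sample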

@junpenglao @aloctavodia does this sound reasonable to you?

@eigenfoo
Member Author

In other words, something along the lines of the PR above ^

@eigenfoo
Member Author

It looks like it passes most tests, and the failed tests look like easy fixes. I'll invest more time into that PR if you guys think this is a good idea.

@junpenglao
Member

Hmm, if you are going to log it during sampling, treating it as a Deterministic might be suboptimal, as you would be computing the logp twice: once in the sampler for the MCMC update, and once for the Deterministic. If the logp is expensive to compute, the overhead could be significant. If we were to do this properly, it might be better to treat it as a sampler stat and log it there. I imagine compound steps will make this difficult, though. Thoughts?

@eigenfoo
Member Author

eigenfoo commented Jul 30, 2018

@junpenglao hm, I see how this might be inefficient. I agree that adding the model logp to the trace while sampling (as a sampling stat) would be better.

I guess I don't have a good enough understanding of PyMC3 internals to understand:

  1. where sampler stats logging takes place, or what the best way to track model logp is. My intuition says the step_methods directory should have the code, but I can't find anything there. Could you point me to something?

  2. a) why compound steps would make this logging any harder. Is there a good resource for me to familiarize myself with compound step sampling?
    b) Do any other sampler stats become harder to track when using compound steps? If the model logp is really the only sampler stat that compound steps make tricky, I would agree with you that computing it post-sampling is the better choice. Per the zen:

    Special cases aren't special enough to break the rules.
    Although practicality beats purity.

In the meantime, I'll close the previous PR.

@ColCarroll
Member

Late to the discussion, and I think being able to access the logp would be great, but I also want to mention that a few utility functions would be affected by adding a deterministic variable (get_default_varnames, at least), and a bunch of places that don't use that code might start picking up a deterministic variable.

It sounds like there's a good alternative way forward, but I wanted to mention this reason too!

@ColCarroll
Member

A good place to start looking at sampler statistics is how they are implemented in Hamiltonian Monte Carlo. You can then trace that back through

  • the step method base classes to see how stats are emitted if they are supported, then
  • the sampling code which accepts the stats and sends them to the backend if the backend knows how to handle them, and then
  • NDArray, which knows how to handle all the statistics.

It is a bunch of steps, but trying to add a sampler stat to NUTS or HMC will throw errors until it doesn't, and then it will make more sense 😄 .
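
Put together, the moving parts look roughly like this simplified sketch (not the real HMC code; the class, its stub logic, and the wiring via model.fastlogp are illustrative assumptions that mirror how the simpler ArrayStep samplers are set up):

import numpy as np
from pymc3.step_methods.arraystep import ArrayStep

class DummyStep(ArrayStep):
    # advertise that this step emits stats, so the backends allocate storage
    generates_stats = True
    stats_dtypes = [{'model_logp': np.float64}]

    def __init__(self, vars, model):
        # pass the compiled logp so astep receives it as an argument
        super().__init__(vars, [model.fastlogp])

    def astep(self, q0, logp):
        # a real step method would propose and accept/reject here;
        # this stub stays at q0 and just records the logp as a sampler stat
        stats = {'model_logp': logp(q0)}
        return q0, [stats]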

@junpenglao
Member

junpenglao commented Jul 30, 2018

CompoundStep might be challenging because the logp is evaluated multiple times in each Gibbs sweep, but we only want the last one, where the logp is evaluated at the final accepted point (not the intermediate ones).
Maybe something like this, looking at
https://github.com/pymc-devs/pymc3/blob/452d5e2eaa74dcdc8fecec95095c2630d2c44ee1/pymc3/step_methods/compound.py#L21-L34

  1. Add logp to the sampler stats, so that every sampler has a stats method.
  • Maybe remove self.generates_stats, since every sampler would then have stats properties? (We also need to make sure that if a user writes a custom step method without logging the logp, we add it for them.)
  2. At the end of the for-loop, log the final logp:
states = []
for method in self.methods:
    point, state = method.step(point)
    states.extend(state)
# after the loop, keep only the logp evaluated at the final accepted point
states[-1]['model_logp'] = states[-1]['logp']

@eigenfoo
Member Author

I see, thanks for all the help! I'll read through and whip up another PR. In the meantime, I'll close this PR.

I also just realized that computing the model logp post-sampling is similarly inefficient: the logp would be computed for the MCMC update, and then recomputed after sampling finishes. So it looks like tracking it as a sampler stat is the only logical choice here!

Unfortunately I'm a bit tied up this week, but I'll have plenty of time to tinker with this on my upcoming (interminable) flight 😄
