Skip to content

Commit 2b97177

Browse files
torfjeldegithub-actions[bot]mhauruyebai
authored
Fix for LogDensityFunction (#621)
* lazily resolve context to avoid overriding the model context * bump patch version * Update src/logdensityfunction.jl * Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * replaces more references to `f.context` with `getcontext(f)` * Bump version to v0.28 Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent d384da2 commit 2b97177

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.27.2"
3+
version = "0.28"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/logdensityfunction.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ struct LogDensityFunction{V,M,C}
4949
varinfo::V
5050
"model used for evaluation"
5151
model::M
52-
"context used for evaluation"
52+
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
5353
context::C
5454
end
5555

@@ -66,15 +66,20 @@ end
6666
function LogDensityFunction(
6767
model::Model,
6868
varinfo::AbstractVarInfo=VarInfo(model),
69-
context::AbstractContext=model.context,
69+
context::Union{Nothing,AbstractContext}=nothing,
7070
)
7171
return LogDensityFunction(varinfo, model, context)
7272
end
7373

74+
# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
75+
function getcontext(f::LogDensityFunction)
76+
return f.context === nothing ? leafcontext(f.model.context) : f.context
77+
end
78+
7479
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
7580
# we need to define these annoying methods to ensure that we stay compatible with everything.
76-
getsampler(f::LogDensityFunction) = getsampler(f.context)
77-
hassampler(f::LogDensityFunction) = hassampler(f.context)
81+
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
82+
hassampler(f::LogDensityFunction) = hassampler(getcontext(f))
7883

7984
_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx)
8085
_get_indexer(ctx::SamplingContext) = ctx.sampler
@@ -86,12 +91,13 @@ _get_indexer(::IsLeaf, ctx::AbstractContext) = Colon()
8691
8792
Return the parameters of the wrapped varinfo as a vector.
8893
"""
89-
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)]
94+
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))]
9095

9196
# LogDensityProblems interface
9297
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
93-
vi_new = unflatten(f.varinfo, f.context, θ)
94-
return getlogp(last(evaluate!!(f.model, vi_new, f.context)))
98+
context = getcontext(f)
99+
vi_new = unflatten(f.varinfo, context, θ)
100+
return getlogp(last(evaluate!!(f.model, vi_new, context)))
95101
end
96102
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
97103
return LogDensityProblems.LogDensityOrder{0}()

0 commit comments

Comments
 (0)