Remove 3-argument {_,}evaluate!!; clean up submodel code #960

Open · wants to merge 5 commits into base: breaking

7 changes: 4 additions & 3 deletions HISTORY.md
@@ -31,11 +31,11 @@ This version therefore excises the context argument, and instead uses `model.context`
The upshot of this is that many functions that previously took a context argument now no longer do.
There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value).

`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`.
`evaluate!!(model, varinfo, ext_context)` is removed, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`.
If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost.
If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely.

To aid with this process, `contextualize` is now exported from DynamicPPL.
**To aid with this process, `contextualize` is now exported from DynamicPPL.**

The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`.
Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object.
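
As a rough migration sketch for the simple case (assuming `model`, `varinfo`, and an external context `ext_context` are already in scope, with `evaluate!!` available just as in the old three-argument call):

```julia
using DynamicPPL

# Old (removed): evaluate!!(model, varinfo, ext_context)
# New: bake the external context into the model, then evaluate.
new_model = contextualize(model, ext_context)
retval, varinfo = evaluate!!(new_model, varinfo)
```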
@@ -54,9 +54,10 @@ However, here are the more user-facing ones:

And a couple of more internal changes:

- `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments
- Just like `evaluate!!`, the other functions `_evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` now no longer accept context arguments
- `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`)
- The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument
- The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API, i.e. `to_submodel`, remains completely untouched. (A short before/after sketch of these changes follows this list.)
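
A rough before/after sketch for the bullets that require code changes, assuming `rng`, `model`, and `varinfo` are already defined; the exact signature of `evaluate_and_sample!!` is an assumption here:

```julia
using Random, DynamicPPL

rng = Random.default_rng()

# Previously, fresh values could be sampled by passing an rng and sampler
# directly to `evaluate!!`. One replacement (signature assumed):
retval, varinfo = DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)

# Alternatively, construct a SamplingContext yourself and bake it into the model:
sampling_model = contextualize(model, SamplingContext(rng))
retval, varinfo = evaluate!!(sampling_model, varinfo)

# Checking whether something is a submodel is now a plain type check:
DynamicPPL.to_submodel(model) isa DynamicPPL.Submodel  # true
```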

## 0.36.12

6 changes: 0 additions & 6 deletions docs/src/api.md
@@ -152,12 +152,6 @@ In the context of including models within models, it's also useful to prefix the
DynamicPPL.prefix
```

Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else

```@docs
returned(::Model)
```

## Utilities

It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -171,11 +171,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end
include("utils.jl")
include("chains.jl")
include("model.jl")
include("submodel.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("submodel.jl")
include("varnamedvector.jl")
include("accumulators.jl")
include("default_accumulators.jl")
6 changes: 1 addition & 5 deletions src/compiler.jl
@@ -176,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x))
end
check_tilde_rhs(x::Distribution) = x
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
check_tilde_rhs(x::ReturnedModelWrapper) = x
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
model = check_tilde_rhs(x.model)
return Sampleable{typeof(model),AutoPrefix}(model)
end
check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x

"""
check_dot_tilde_rhs(x)
38 changes: 7 additions & 31 deletions src/context_implementations.jl
@@ -63,31 +63,10 @@
probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
return if is_rhs_model(right)
# Here, we apply the PrefixContext _not_ to the parent `context`, but
# to the context of the submodel being evaluated. This means that later
# on in `make_evaluate_args_and_kwargs`, the context stack will be
# correctly arranged such that it goes like this:
# parent_context[1] -> parent_context[2] -> ... -> PrefixContext ->
# submodel_context[1] -> submodel_context[2] -> ... -> leafcontext
# See the docstring of `make_evaluate_args_and_kwargs`, and the internal
# DynamicPPL documentation on submodel conditioning, for more details.
#
# NOTE: This relies on the existence of `right.model.model`. Right now,
# the only thing that can return true for `is_rhs_model` is something
# (a `Sampleable`) that has a `model` field that itself (a
# `ReturnedModelWrapper`) has a `model` field. This may or may not
# change in the future.
if should_auto_prefix(right)
dppl_model = right.model.model # This isa DynamicPPL.Model
prefixed_submodel_context = PrefixContext(vn, dppl_model.context)
new_dppl_model = contextualize(dppl_model, prefixed_submodel_context)
right = to_submodel(new_dppl_model, true)
end
rand_like!!(right, context, vi)
return if right isa DynamicPPL.Submodel
_evaluate!!(right, vi, context, vn)
else
value, vi = tilde_assume(context, right, vn, vi)
return value, vi
tilde_assume(context, right, vn, vi)
end
end

@@ -129,17 +108,14 @@
Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
and indices; if needed, these can be accessed through this function, though.
"""
function tilde_observe!!(context::DefaultContext, right, left, vn, vi)
is_rhs_model(right) && throw(
ArgumentError(
"`~` with a model on the right-hand side of an observe statement is not supported",
),
)
function tilde_observe!!(::DefaultContext, right, left, vn, vi)
right isa DynamicPPL.Submodel &&
throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed"))
vi = accumulate_observe!!(vi, right, left, vn)
return left, vi
end

function assume(rng::Random.AbstractRNG, spl::Sampler, dist)
function assume(::Random.AbstractRNG, spl::Sampler, dist)

return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end

59 changes: 1 addition & 58 deletions src/model.jl
@@ -258,7 +258,7 @@ julia> # However, it's not possible to condition `inner` directly.
conditioned_model_fail = model | (inner = 1.0, );

julia> conditioned_model_fail()
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed
[...]
```
"""
@@ -864,12 +864,6 @@ If multiple threads are available, the varinfo provided will be wrapped in a

Returns a tuple of the model's return value, plus the updated `varinfo`
(unwrapped if necessary).

evaluate!!(model::Model, varinfo, context)

When an extra context stack is provided, the model's context is inserted into
that context stack. See `combine_model_and_external_contexts`. This method is
deprecated.
"""
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
return if use_threadsafe_eval(model.context, varinfo)
Expand All @@ -878,17 +872,6 @@ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
evaluate_threadunsafe!!(model, varinfo)
end
end
function AbstractPPL.evaluate!!(
model::Model, varinfo::AbstractVarInfo, context::AbstractContext
)
Base.depwarn(
"The `context` argument to evaluate!!(model, varinfo, context) is deprecated.",
:dynamicppl_evaluate_context,
)
new_ctx = combine_model_and_external_contexts(model.context, context)
model = contextualize(model, new_ctx)
return evaluate!!(model, varinfo)
end

"""
evaluate_threadunsafe!!(model, varinfo)
@@ -932,54 +915,14 @@ Evaluate the `model` with the given `varinfo`.

This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not
reset the log probability of the `varinfo` before running.

_evaluate!!(model::Model, varinfo, context)

If an additional `context` is provided, the model's context is combined with
that context before evaluation.
"""
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
return model.f(args...; kwargs...)
end
function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext)
# TODO(penelopeysm): We don't really need this, but it's a useful
# convenience method. We could remove it after we get rid of the
# evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
# with a TSVI themselves).
new_ctx = combine_model_and_external_contexts(model.context, context)
model = contextualize(model, new_ctx)
return _evaluate!!(model, varinfo)
end

is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#")

"""
combine_model_and_external_contexts(model_context, external_context)

Combine a context from a model and an external context into a single context.

The resulting context stack has the following structure:

`external_context` -> `childcontext(external_context)` -> ... ->
`model_context` -> `childcontext(model_context)` -> ... ->
`leafcontext(external_context)`

The reason for this is that we want to give `external_context` precedence over
`model_context`, while also preserving the leaf context of `external_context`.
We can do this by

1. Set the leaf context of `model_context` to `leafcontext(external_context)`.
2. Set leaf context of `external_context` to the context resulting from (1).
"""
function combine_model_and_external_contexts(
model_context::AbstractContext, external_context::AbstractContext
)
return setleafcontext(
external_context, setleafcontext(model_context, leafcontext(external_context))
)
end
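
For callers that relied on the removed `combine_model_and_external_contexts`, a sketch of reproducing its behaviour at the call site (the helper name `my_combine_contexts` is hypothetical; `model`, `ext_context`, and `varinfo` are assumed to be defined, and `setleafcontext`/`leafcontext` to be in scope as in the surrounding source):

```julia
# Hypothetical helper mirroring the removed function's body: give
# `external_context` precedence over `model_context` while preserving the
# leaf context of `external_context`.
function my_combine_contexts(model_context, external_context)
    return setleafcontext(
        external_context, setleafcontext(model_context, leafcontext(external_context))
    )
end

# Usage sketch: fold the combined context into the model before evaluating.
new_model = contextualize(model, my_combine_contexts(model.context, ext_context))
retval, varinfo = evaluate!!(new_model, varinfo)
```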

"""
make_evaluate_args_and_kwargs(model, varinfo)

144 changes: 50 additions & 94 deletions src/submodel.jl
@@ -1,98 +1,13 @@
"""
is_rhs_model(x)
Submodel{M,AutoPrefix}

Return `true` if `x` is a model or model wrapper, and `false` otherwise.
A wrapper around a model, plus a flag indicating whether it should be automatically
prefixed with the left-hand variable in a `~` statement.
"""
is_rhs_model(x) = false

"""
Distributional

Abstract type for type indicating that something is "distributional".
"""
abstract type Distributional end

"""
should_auto_prefix(distributional)

Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise.
"""
function should_auto_prefix end

"""
is_rhs_model(x)

Return `true` if the `distributional` is a model, and `false` otherwise.
"""
function is_rhs_model end

"""
Sampleable{M} <: Distributional

A wrapper around a model indicating it is sampleable.
"""
struct Sampleable{M,AutoPrefix} <: Distributional
struct Submodel{M,AutoPrefix}
model::M
end

should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix
is_rhs_model(x::Sampleable) = is_rhs_model(x.model)

# TODO: Export this if it end up having a purpose beyond `to_submodel`.
"""
to_sampleable(model[, auto_prefix])

Return a wrapper around `model` indicating it is sampleable.

# Arguments
- `model::Model`: the model to wrap.
- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`.
"""
to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model)

"""
rand_like!!(model_wrap, context, varinfo)

Returns a tuple with the first element being the realization and the second the updated varinfo.

# Arguments
- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use.
- `context::AbstractContext`: the context to use for evaluation.
- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation.
"""
function rand_like!!(
model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo
)
return rand_like!!(model_wrap.model, context, varinfo)
end

"""
ReturnedModelWrapper

A wrapper around a model indicating it is a model over its return values.

This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead.
"""
struct ReturnedModelWrapper{M<:Model}
model::M
end

is_rhs_model(::ReturnedModelWrapper) = true

function rand_like!!(
model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo
)
# Return's the value and the (possibly mutated) varinfo.
return _evaluate!!(model_wrap.model, varinfo, context)
end

"""
returned(model)

Return a `model` wrapper indicating that it is a model over its return-values.
"""
returned(model::Model) = ReturnedModelWrapper(model)

"""
to_submodel(model::Model[, auto_prefix::Bool])

@@ -106,8 +21,8 @@ the model can be sampled from but not necessarily evaluated for its log density.
`left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`.

!!! warning
To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`.
If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly.
To avoid variable names clashing between models, it is recommended to leave the argument `auto_prefix` equal to `true`.
If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly, i.e. `to_submodel(prefix(model, @varname(my_prefix)))`.

# Arguments
- `model::Model`: the model to wrap.
@@ -231,9 +146,50 @@ illegal_likelihood (generic function with 2 methods)
julia> model = illegal_likelihood() | (a = 1.0,);

julia> model()
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed
[...]
```
"""
to_submodel(model::Model, auto_prefix::Bool=true) =
to_sampleable(returned(model), auto_prefix)
to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m)

# When automatic prefixing is used, the submodel itself doesn't carry the
# prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel
# is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then
# passed into this function.
#
# `parent_context` here refers to the context of the model that contains the
# submodel.
function _evaluate!!(
submodel::Submodel{M,AutoPrefix},
vi::AbstractVarInfo,
parent_context::AbstractContext,
left_vn::VarName,
) where {M<:Model,AutoPrefix}
# First, we construct the context to be used when evaluating the submodel. There
# are several considerations here:
# (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but
# _only_ if automatic prefixing is supposed to be applied.
submodel_context_prefixed = if AutoPrefix
PrefixContext(left_vn, submodel.model.context)
else
submodel.model.context
end

# (2) We need to respect the leaf-context of the parent model. This, unfortunately,
# means disregarding the leaf-context of the submodel.
submodel_context = setleafcontext(
submodel_context_prefixed, leafcontext(parent_context)
)

# (3) We need to use the parent model's context to wrap the whole thing, so that
# e.g. if the user conditions the parent model, the conditioned variables will be
# correctly picked up when evaluating the submodel.
eval_context = setleafcontext(parent_context, submodel_context)

# (4) Finally, we need to store that context inside the submodel.
model = contextualize(submodel.model, eval_context)

# Once that's all set up nicely, we can just _evaluate!! the wrapped model. This
# returns a tuple of submodel.model's return value and the new varinfo.
return _evaluate!!(model, vi)
end
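
To make the prefixing behaviour concrete, a small usage sketch (assuming `DynamicPPL` and `Distributions` are loaded; the printed form of the prefixed variable name may differ between versions):

```julia
using DynamicPPL, Distributions

@model function inner()
    x ~ Normal()
    return x
end

@model function outer()
    # With automatic prefixing (the default), `_evaluate!!` wraps the submodel's
    # context in a `PrefixContext` keyed on the left-hand `VarName`, i.e. `a`.
    a ~ to_submodel(inner())
    return a
end

# The inner variable is therefore stored under a prefixed name, roughly `a.x`.
vi = VarInfo(outer())
keys(vi)
```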