diff --git a/HISTORY.md b/HISTORY.md index d559e6373..617543a5f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -31,11 +31,11 @@ This version therefore excises the context argument, and instead uses `model.con The upshot of this is that many functions that previously took a context argument now no longer do. There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value). -`evaluate!!(model, varinfo, ext_context)` is deprecated, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. +`evaluate!!(model, varinfo, ext_context)` is removed, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost. If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely. -To aid with this process, `contextualize` is now exported from DynamicPPL. +**To aid with this process, `contextualize` is now exported from DynamicPPL.** The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`. Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object. @@ -54,9 +54,10 @@ However, here are the more user-facing ones: And a couple of more internal changes: - - `evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` no longer accept context arguments + - Just like `evaluate!!`, the other functions `_evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` now no longer accept context arguments - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`) - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument + - The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched. ## 0.36.12 diff --git a/docs/src/api.md b/docs/src/api.md index 886d34a2f..24efdae30 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -152,12 +152,6 @@ In the context of including models within models, it's also useful to prefix the DynamicPPL.prefix ``` -Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else - -```@docs -returned(::Model) -``` - ## Utilities It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 2b4d0e4a6..502e725ce 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -171,11 +171,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end include("utils.jl") include("chains.jl") include("model.jl") -include("submodel.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") diff --git a/src/compiler.jl b/src/compiler.jl index 22dff33a2..6384eaa7c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -176,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x)) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x -check_tilde_rhs(x::ReturnedModelWrapper) = x -function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} - model = check_tilde_rhs(x.model) - return Sampleable{typeof(model),AutoPrefix}(model) -end +check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x """ check_dot_tilde_rhs(x) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index cc75cd7e6..b11a723a5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -63,31 +63,10 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) - return if is_rhs_model(right) - # Here, we apply the PrefixContext _not_ to the parent `context`, but - # to the context of the submodel being evaluated. This means that later= - # on in `make_evaluate_args_and_kwargs`, the context stack will be - # correctly arranged such that it goes like this: - # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> - # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext - # See the docstring of `make_evaluate_args_and_kwargs`, and the internal - # DynamicPPL documentation on submodel conditioning, for more details. - # - # NOTE: This relies on the existence of `right.model.model`. Right now, - # the only thing that can return true for `is_rhs_model` is something - # (a `Sampleable`) that has a `model` field that itself (a - # `ReturnedModelWrapper`) has a `model` field. This may or may not - # change in the future. - if should_auto_prefix(right) - dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext(vn, dppl_model.context) - new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) - right = to_submodel(new_dppl_model, true) - end - rand_like!!(right, context, vi) + return if right isa DynamicPPL.Submodel + _evaluate!!(right, vi, context, vn) else - value, vi = tilde_assume(context, right, vn, vi) - return value, vi + tilde_assume(context, right, vn, vi) end end @@ -129,17 +108,14 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(context::DefaultContext, right, left, vn, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) +function tilde_observe!!(::DefaultContext, right, left, vn, vi) + right isa DynamicPPL.Submodel && + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end -function assume(rng::Random.AbstractRNG, spl::Sampler, dist) +function assume(::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end diff --git a/src/model.jl b/src/model.jl index 27551bfa2..93e77eaec 100644 --- a/src/model.jl +++ b/src/model.jl @@ -258,7 +258,7 @@ julia> # However, it's not possible to condition `inner` directly. conditioned_model_fail = model | (inner = 1.0, ); julia> conditioned_model_fail() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed [...] ``` """ @@ -864,12 +864,6 @@ If multiple threads are available, the varinfo provided will be wrapped in a Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). - - evaluate!!(model::Model, varinfo, context) - -When an extra context stack is provided, the model's context is inserted into -that context stack. See `combine_model_and_external_contexts`. This method is -deprecated. """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) return if use_threadsafe_eval(model.context, varinfo) @@ -878,17 +872,6 @@ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) evaluate_threadunsafe!!(model, varinfo) end end -function AbstractPPL.evaluate!!( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) - Base.depwarn( - "The `context` argument to evaluate!!(model, varinfo, context) is deprecated.", - :dynamicppl_evaluate_context, - ) - new_ctx = combine_model_and_external_contexts(model.context, context) - model = contextualize(model, new_ctx) - return evaluate!!(model, varinfo) -end """ evaluate_threadunsafe!!(model, varinfo) @@ -932,54 +915,14 @@ Evaluate the `model` with the given `varinfo`. This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not reset the log probability of the `varinfo` before running. - - _evaluate!!(model::Model, varinfo, context) - -If an additional `context` is provided, the model's context is combined with -that context before evaluation. """ function _evaluate!!(model::Model, varinfo::AbstractVarInfo) args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) return model.f(args...; kwargs...) end -function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - # TODO(penelopeysm): We don't really need this, but it's a useful - # convenience method. We could remove it after we get rid of the - # evaluate_threadsafe!! stuff (in favour of making users call evaluate!! - # with a TSVI themselves). - new_ctx = combine_model_and_external_contexts(model.context, context) - model = contextualize(model, new_ctx) - return _evaluate!!(model, varinfo) -end is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") -""" - combine_model_and_external_contexts(model_context, external_context) - -Combine a context from a model and an external context into a single context. - -The resulting context stack has the following structure: - - `external_context` -> `childcontext(external_context)` -> ... -> - `model_context` -> `childcontext(model_context)` -> ... -> - `leafcontext(external_context)` - -The reason for this is that we want to give `external_context` precedence over -`model_context`, while also preserving the leaf context of `external_context`. -We can do this by - -1. Set the leaf context of `model_context` to `leafcontext(external_context)`. -2. Set leaf context of `external_context` to the context resulting from (1). -""" -function combine_model_and_external_contexts( - model_context::AbstractContext, external_context::AbstractContext -) - return setleafcontext( - external_context, setleafcontext(model_context, leafcontext(external_context)) - ) -end - """ make_evaluate_args_and_kwargs(model, varinfo) diff --git a/src/submodel.jl b/src/submodel.jl index 94658b6bf..dcb107bb4 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -1,98 +1,13 @@ """ - is_rhs_model(x) + Submodel{M,AutoPrefix} -Return `true` if `x` is a model or model wrapper, and `false` otherwise. +A wrapper around a model, plus a flag indicating whether it should be automatically +prefixed with the left-hand variable in a `~` statement. """ -is_rhs_model(x) = false - -""" - Distributional - -Abstract type for type indicating that something is "distributional". -""" -abstract type Distributional end - -""" - should_auto_prefix(distributional) - -Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. -""" -function should_auto_prefix end - -""" - is_rhs_model(x) - -Return `true` if the `distributional` is a model, and `false` otherwise. -""" -function is_rhs_model end - -""" - Sampleable{M} <: Distributional - -A wrapper around a model indicating it is sampleable. -""" -struct Sampleable{M,AutoPrefix} <: Distributional +struct Submodel{M,AutoPrefix} model::M end -should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix -is_rhs_model(x::Sampleable) = is_rhs_model(x.model) - -# TODO: Export this if it end up having a purpose beyond `to_submodel`. -""" - to_sampleable(model[, auto_prefix]) - -Return a wrapper around `model` indicating it is sampleable. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. -""" -to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) - -""" - rand_like!!(model_wrap, context, varinfo) - -Returns a tuple with the first element being the realization and the second the updated varinfo. - -# Arguments -- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. -- `context::AbstractContext`: the context to use for evaluation. -- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. - """ -function rand_like!!( - model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo -) - return rand_like!!(model_wrap.model, context, varinfo) -end - -""" - ReturnedModelWrapper - -A wrapper around a model indicating it is a model over its return values. - -This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. -""" -struct ReturnedModelWrapper{M<:Model} - model::M -end - -is_rhs_model(::ReturnedModelWrapper) = true - -function rand_like!!( - model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo -) - # Return's the value and the (possibly mutated) varinfo. - return _evaluate!!(model_wrap.model, varinfo, context) -end - -""" - returned(model) - -Return a `model` wrapper indicating that it is a model over its return-values. -""" -returned(model::Model) = ReturnedModelWrapper(model) - """ to_submodel(model::Model[, auto_prefix::Bool]) @@ -106,8 +21,8 @@ the model can be sampled from but not necessarily evaluated for its log density. `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. !!! warning - To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. - If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. + To avoid variable names clashing between models, it is recommended to leave the argument `auto_prefix` equal to `true`. + If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly, i.e. `to_submodel(prefix(model, @varname(my_prefix)))` # Arguments - `model::Model`: the model to wrap. @@ -231,9 +146,50 @@ illegal_likelihood (generic function with 2 methods) julia> model = illegal_likelihood() | (a = 1.0,); julia> model() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed [...] ``` """ -to_submodel(model::Model, auto_prefix::Bool=true) = - to_sampleable(returned(model), auto_prefix) +to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) + +# When automatic prefixing is used, the submodel itself doesn't carry the +# prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel +# is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then +# passed into this function. +# +# `parent_context` here refers to the context of the model that contains the +# submodel. +function _evaluate!!( + submodel::Submodel{M,AutoPrefix}, + vi::AbstractVarInfo, + parent_context::AbstractContext, + left_vn::VarName, +) where {M<:Model,AutoPrefix} + # First, we construct the context to be used when evaluating the submodel. There + # are several considerations here: + # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but + # _only_ if automatic prefixing is supposed to be applied. + submodel_context_prefixed = if AutoPrefix + PrefixContext(left_vn, submodel.model.context) + else + submodel.model.context + end + + # (2) We need to respect the leaf-context of the parent model. This, unfortunately, + # means disregarding the leaf-context of the submodel. + submodel_context = setleafcontext( + submodel_context_prefixed, leafcontext(parent_context) + ) + + # (3) We need to use the parent model's context to wrap the whole thing, so that + # e.g. if the user conditions the parent model, the conditioned variables will be + # correctly picked up when evaluating the submodel. + eval_context = setleafcontext(parent_context, submodel_context) + + # (4) Finally, we need to store that context inside the submodel. + model = contextualize(submodel.model, eval_context) + + # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This + # returns a tuple of submodel.model's return value and the new varinfo. + return _evaluate!!(model, vi) +end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 08acdfada..863db4262 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -63,10 +63,12 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. # Untyped varinfo. varinfo_untyped = DynamicPPL.VarInfo() - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) + model_with_spl = contextualize(model, SamplingContext(context)) + model_without_spl = contextualize(model, context) + @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any + @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any # Typed varinfo. varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) + @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any + @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any end