Skip to content

Commit 1dc56f6

Browse files
committed
Fix literally everything else that I broke
1 parent c8db5a4 commit 1dc56f6

35 files changed

+349
-526
lines changed

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8181
end
8282

8383
adbackend = to_backend(adbackend)
84-
context = DynamicPPL.DefaultContext()
8584

8685
if islinked
8786
vi = DynamicPPL.link(vi, model)
8887
end
8988

90-
f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
89+
f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
9190
# The parameters at which we evaluate f.
9291
θ = vi[:]
9392

docs/src/api.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ getargnames
3636
getmissings
3737
```
3838

39+
The context of a model can be set using [`contextualize`](@ref):
40+
41+
```@docs
42+
contextualize
43+
```
44+
3945
## Evaluation
4046

4147
With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref).
@@ -438,13 +444,21 @@ DynamicPPL.varname_and_value_leaves
438444

439445
### Evaluation Contexts
440446

441-
Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref).
447+
Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref).
442448

443449
```@docs
444450
AbstractPPL.evaluate!!
445451
```
446452

447-
The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function.
453+
This method mutates the `varinfo` used for execution.
454+
By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`.
455+
To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method:
456+
457+
```@docs
458+
DynamicPPL.sample!!
459+
```
460+
461+
The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model.
448462
Contexts are subtypes of `AbstractPPL.AbstractContext`.
449463

450464
```@docs

docs/src/internals/submodel_condition.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,13 +102,13 @@ Take these models, for example:
102102
unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context
103103
unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx
104104
@model function inner()
105-
println("inner context: $(unwrap_sampling_context(__context__))")
105+
println("inner context: $(unwrap_sampling_context(__model__.context))")
106106
x ~ Normal()
107107
return y ~ Normal()
108108
end
109109
110110
@model function outer()
111-
println("outer context: $(unwrap_sampling_context(__context__))")
111+
println("outer context: $(unwrap_sampling_context(__model__.context))")
112112
return a ~ to_submodel(inner())
113113
end
114114
@@ -118,7 +118,7 @@ with_outer_cond = outer() | (@varname(a.x) => 1.0)
118118
# 'Inner conditioning'
119119
inner_cond = inner() | (@varname(x) => 1.0)
120120
@model function outer2()
121-
println("outer context: $(unwrap_sampling_context(__context__))")
121+
println("outer context: $(unwrap_sampling_context(__model__.context))")
122122
return a ~ to_submodel(inner_cond)
123123
end
124124
with_inner_cond = outer2()

ext/DynamicPPLJETExt.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL
44
using JET: JET
55

66
function DynamicPPL.Experimental.is_suitable_varinfo(
7-
model::DynamicPPL.Model,
8-
context::DynamicPPL.AbstractContext,
9-
varinfo::DynamicPPL.AbstractVarInfo;
10-
only_ddpl::Bool=true,
7+
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
118
)
129
# Let's make sure that both evaluation and sampling doesn't result in type errors.
13-
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
14-
model, varinfo, context
15-
)
10+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1611
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1712
# This way we don't just fall back to untyped if the user's code is the issue.
1813
result = if only_ddpl
@@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo(
2419
end
2520

2621
function DynamicPPL.Experimental._determine_varinfo_jet(
27-
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
22+
model::DynamicPPL.Model; only_ddpl::Bool=true
2823
)
24+
# Use SamplingContext to test type stability.
25+
sampling_model = DynamicPPL.contextualize(
26+
model, DynamicPPL.SamplingContext(model.context)
27+
)
28+
2929
# First we try with the typed varinfo.
30-
varinfo = DynamicPPL.typed_varinfo(model, context)
30+
varinfo = DynamicPPL.typed_varinfo(sampling_model)
3131

3232
# Let's make sure that both evaluation and sampling doesn't result in type errors.
3333
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34-
model, context, varinfo; only_ddpl
34+
sampling_model, varinfo; only_ddpl
3535
)
3636

3737
if !issuccess
@@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
4646
else
4747
# Warn the user that we can't use the type stable one.
4848
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49-
DynamicPPL.untyped_varinfo(model, context)
49+
DynamicPPL.untyped_varinfo(sampling_model)
5050
end
5151
end
5252

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function DynamicPPL.predict(
115115
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
116116
predictive_samples = map(iters) do (sample_idx, chain_idx)
117117
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
118-
model(rng, varinfo, DynamicPPL.SampleFromPrior())
118+
varinfo = last(DynamicPPL.sample!!(rng, model, varinfo))
119119

120120
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
121121
varname_vals = mapreduce(

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ export AbstractVarInfo,
102102
# LogDensityFunction
103103
LogDensityFunction,
104104
# Contexts
105+
contextualize,
105106
SamplingContext,
106107
DefaultContext,
107108
PrefixContext,

src/compiler.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)
1+
const INTERNALNAMES = (:__model__, :__varinfo__)
22

33
"""
44
need_concretize(expr)
@@ -63,9 +63,9 @@ used in its place.
6363
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
6464
return quote
6565
if $(DynamicPPL.contextual_isassumption)(
66-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
66+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
6767
)
68-
# Considered an assumption by `__context__` which means either:
68+
# Considered an assumption by `__model__.context` which means either:
6969
# 1. We hit the default implementation, e.g. using `DefaultContext`,
7070
# which in turn means that we haven't considered if it's one of
7171
# the model arguments, hence we need to check this.
@@ -116,7 +116,7 @@ end
116116
isfixed(expr, vn) = false
117117
function isfixed(::Union{Symbol,Expr}, vn)
118118
return :($(DynamicPPL.contextual_isfixed)(
119-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
119+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
120120
))
121121
end
122122

@@ -417,7 +417,7 @@ function generate_assign(left, right)
417417
return quote
418418
$right_val = $right
419419
if $(DynamicPPL.is_extracting_values)(__varinfo__)
420-
$vn = $(DynamicPPL.prefix)(__context__, $(make_varname_expression(left)))
420+
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
421421
__varinfo__ = $(map_accumulator!!)(
422422
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
423423
)
@@ -431,7 +431,11 @@ function generate_tilde_literal(left, right)
431431
@gensym value
432432
return quote
433433
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
434-
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, __varinfo__
434+
__model__.context,
435+
$(DynamicPPL.check_tilde_rhs)($right),
436+
$left,
437+
nothing,
438+
__varinfo__,
435439
)
436440
$value
437441
end
@@ -456,20 +460,20 @@ function generate_tilde(left, right)
456460
$isassumption = $(DynamicPPL.isassumption(left, vn))
457461
if $(DynamicPPL.isfixed(left, vn))
458462
$left = $(DynamicPPL.getfixed_nested)(
459-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
463+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
460464
)
461465
elseif $isassumption
462466
$(generate_tilde_assume(left, dist, vn))
463467
else
464468
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
465469
if !$(DynamicPPL.inargnames)($vn, __model__)
466470
$left = $(DynamicPPL.getconditioned_nested)(
467-
__context__, $(DynamicPPL.prefix)(__context__, $vn)
471+
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
468472
)
469473
end
470474

471475
$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
472-
__context__,
476+
__model__.context,
473477
$(DynamicPPL.check_tilde_rhs)($dist),
474478
$(maybe_view(left)),
475479
$vn,
@@ -494,7 +498,7 @@ function generate_tilde_assume(left, right, vn)
494498

495499
return quote
496500
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
497-
__context__,
501+
__model__.context,
498502
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
499503
__varinfo__,
500504
)
@@ -652,11 +656,7 @@ function build_output(modeldef, linenumbernode)
652656

653657
# Add the internal arguments to the user-specified arguments (positional + keywords).
654658
evaluatordef[:args] = vcat(
655-
[
656-
:(__model__::$(DynamicPPL.Model)),
657-
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
658-
:(__context__::$(DynamicPPL.AbstractContext)),
659-
],
659+
[:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))],
660660
args,
661661
)
662662

0 commit comments

Comments
 (0)