Skip to content

Issue warning if RVs are present in derived probability graphs #6651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,33 @@
TensorLike: TypeAlias = Union[Variable, float, np.ndarray]


def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
def _warn_rvs_in_inferred_graph(graph: Sequence[TensorVariable]):
    """Issue a ``UserWarning`` if any RandomVariables are found in ``graph``.

    RVs are usually an (implicit) conditional input of the derived probability expression,
    and are meant to be replaced by their respective value variables before evaluation.
    However, when the IR graph is built, any non-input nodes (including RVs) are cloned,
    breaking the link with the original ones.
    This makes it impossible (or difficult) to replace them by the respective values afterward,
    so we instruct users to do it beforehand.

    Parameters
    ----------
    graph
        Output variable(s) of a derived probability graph to inspect for RVs.
    """
    # Imported locally — presumably to avoid a circular import at module load; confirm.
    from pymc.testing import assert_no_rvs

    try:
        assert_no_rvs(graph)
    except AssertionError:
        # Grammar fix: plural subject requires plural predicate ("are clones", "do not match").
        # stacklevel=3 points the warning at the caller of logp/logcdf/icdf, not this helper.
        warnings.warn(
            "RandomVariables were found in the derived graph. "
            "These variables are clones and do not match the original ones on identity.\n"
            "If you are deriving a quantity that depends on model RVs, use `model.replace_rvs_by_values` first. For example: "
            "`logp(model.replace_rvs_by_values([rv])[0], value)`",
            stacklevel=3,
        )


def logp(
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
) -> TensorVariable:
"""Return the log-probability graph of a Random Variable"""

value = pt.as_tensor_variable(value, dtype=rv.dtype)
Expand All @@ -74,10 +100,15 @@ def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
except NotImplementedError:
fgraph, _, _ = construct_ir_fgraph({rv: value})
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
return _logprob_helper(ir_rv, ir_value, **kwargs)
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
if warn_missing_rvs:
_warn_rvs_in_inferred_graph(expr)
return expr


def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
def logcdf(
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
) -> TensorVariable:
"""Create a graph for the log-CDF of a Random Variable."""
value = pt.as_tensor_variable(value, dtype=rv.dtype)
try:
Expand All @@ -86,10 +117,15 @@ def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
# Try to rewrite rv
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
return _logcdf_helper(ir_rv, value, **kwargs)
expr = _logcdf_helper(ir_rv, value, **kwargs)
if warn_missing_rvs:
_warn_rvs_in_inferred_graph(expr)
return expr


def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
def icdf(
rv: TensorVariable, value: TensorLike, warn_missing_rvs: bool = True, **kwargs
) -> TensorVariable:
"""Create a graph for the inverse CDF of a Random Variable."""
value = pt.as_tensor_variable(value, dtype=rv.dtype)
try:
Expand All @@ -98,7 +134,10 @@ def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
# Try to rewrite rv
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
return _icdf_helper(ir_rv, value, **kwargs)
expr = _icdf_helper(ir_rv, value, **kwargs)
if warn_missing_rvs:
_warn_rvs_in_inferred_graph(expr)
return expr


def factorized_joint_logprob(
Expand Down Expand Up @@ -215,7 +254,8 @@ def factorized_joint_logprob(
if warn_missing_rvs:
warnings.warn(
"Found a random variable that was neither among the observations "
f"nor the conditioned variables: {node.outputs}"
f"nor the conditioned variables: {outputs}.\n"
"This variables is a clone and does not match the original one on identity."
)
continue

Expand Down
57 changes: 55 additions & 2 deletions tests/logprob/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@
import pymc as pm

from pymc.logprob.basic import factorized_joint_logprob, icdf, joint_logp, logcdf, logp
from pymc.logprob.transforms import LogTransform
from pymc.logprob.utils import rvs_to_value_vars, walk_model
from pymc.pytensorf import replace_rvs_by_values
from pymc.testing import assert_no_rvs
from tests.logprob.utils import joint_logprob

Expand Down Expand Up @@ -248,16 +250,25 @@ def test_persist_inputs():
y_vv_2 = y_vv * 2
logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})

assert y_vv in ancestors([logp_2])
assert y_vv_2 in ancestors([logp_2])

# Even when they are random
y_vv = pt.random.normal(name="y_vv2")
y_vv_2 = y_vv * 2
logp_2 = joint_logprob({beta_rv: beta_vv, Y_rv: y_vv_2})

assert y_vv in ancestors([logp_2])
assert y_vv_2 in ancestors([logp_2])


def test_warn_random_not_found():
def test_warn_random_found_factorized_joint_logprob():
x_rv = pt.random.normal(name="x")
y_rv = pt.random.normal(x_rv, 1, name="y")

y_vv = y_rv.clone()

with pytest.warns(UserWarning):
with pytest.warns(UserWarning, match="Found a random variable that was neither among"):
factorized_joint_logprob({y_rv: y_vv})

with warnings.catch_warnings():
Expand Down Expand Up @@ -457,3 +468,45 @@ def test_probability_inference_fails(func, func_name):
match=f"{func_name} method not implemented for Elemwise{{cos,no_inplace}}",
):
func(pt.cos(pm.Normal.dist()), 1)


@pytest.mark.parametrize(
    "func, scipy_func, test_value",
    [
        (logp, "logpdf", 5.0),
        (logcdf, "logcdf", 5.0),
        (icdf, "ppf", 0.7),
    ],
)
def test_warn_random_found_probability_inference(func, scipy_func, test_value):
    """Check that logp/logcdf/icdf warn when cloned RVs remain in the derived graph,
    that ``warn_missing_rvs=False`` silences the warning, and that replacing RVs by
    value variables beforehand (the prescribed fix) works without warning.
    """
    # Fail if unexpected warning is issued
    with warnings.catch_warnings():
        warnings.simplefilter("error")

        input_rv = pm.Normal.dist(0, name="input")
        # Note: This graph could correspond to a convolution of two normals
        # In which case the inference should either return that or fail explicitly
        # For now, the logprob submodule treats the input as a stochastic value.
        rv = pt.exp(pm.Normal.dist(input_rv))
        with pytest.warns(UserWarning, match="RandomVariables were found in the derived graph"):
            assert func(rv, 0.0)

        # Opting out of the warning must not raise (simplefilter("error") would surface it)
        res = func(rv, 0.0, warn_missing_rvs=False)
        # This is the problem we are warning about, as now we can no longer identify the original rv in the graph
        # or replace it by the respective value
        assert rv not in ancestors([res])

        # Test that the prescribed solution does not raise a warning and works as expected
        input_vv = input_rv.clone()
        [new_rv] = replace_rvs_by_values(
            [rv],
            rvs_to_values={input_rv: input_vv},
            rvs_to_transforms={input_rv: LogTransform()},
        )
        input_vv_test = 1.3
        # With the log-transformed input, exp(Normal(exp(input_vv_test))) matches a
        # scipy lognorm with scale exp(exp(input_vv_test)) — compare against scipy.
        np.testing.assert_almost_equal(
            func(new_rv, test_value).eval({input_vv: input_vv_test}),
            getattr(sp.lognorm(s=1, loc=0, scale=np.exp(np.exp(input_vv_test))), scipy_func)(
                test_value
            ),
        )