Skip to content

Add sub-package-level deprecations #1434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions aesara/scalar/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,28 @@
from .basic import *
from .math import *
from typing import List, Tuple

from aesara.scalar.basic import *
from aesara.scalar.math import *


# isort: off
from aesara.scalar.basic import DEPRECATED_NAMES as BASIC_DEPRECATIONS

# isort: on

DEPRECATED_NAMES: List[Tuple[str, str, object]] = BASIC_DEPRECATIONS


def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.

Adapted from https://stackoverflow.com/a/55139609/3006474.

"""
from warnings import warn

for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object

raise AttributeError(f"module {__name__} has no attribute {name}")
237 changes: 233 additions & 4 deletions aesara/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from copy import copy
from itertools import chain
from textwrap import dedent
from typing import Any, Dict, Mapping, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union

import numpy as np
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -2036,8 +2036,6 @@ def grad(self, inputs, gout):


true_divide = TrueDiv(upcast_out, name="true_divide")

true_div = true_divide
divide = true_divide


Expand Down Expand Up @@ -4454,7 +4452,7 @@ def handle_composite(node, mapping):
Compositef32.special[Composite] = handle_composite


DEPRECATED_NAMES = [
DEPRECATED_NAMES: List[Tuple[str, str, object]] = [
("Inv", "`Inv` is deprecated; use `Reciprocal` instead.", Reciprocal),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
("Scalar", "`Scalar` is deprecated; use `ScalarType` instead.", ScalarType),
Expand All @@ -4480,3 +4478,234 @@ def __getattr__(name):
return old_object

raise AttributeError(f"module {__name__} has no attribute {name}")


__all__ = [
"constant",
"as_scalar",
"ComplexError",
"IntegerDivisionError",
"upcast",
"as_common_dtype",
"NumpyAutocaster",
"autocast_float_as",
"convert",
"ScalarType",
"get_scalar_type",
"ScalarVariable",
"ScalarConstant",
"upcast_out",
"upcast_out_nobool",
"upcast_out_min8",
"upcast_out_no_complex",
"same_out_float_only",
"transfer_type",
"specific_out",
"int_out",
"float_out",
"upgrade_to_float_no_complex",
"same_out_nocomplex",
"int_out_nocomplex",
"float_out_nocomplex",
"unary_out_lookup",
"real_out",
"ScalarOp",
"UnaryScalarOp",
"BinaryScalarOp",
"LogicalComparison",
"FixedLogicalComparison",
"LT",
"GT",
"LE",
"GE",
"EQ",
"NEQ",
"IsNan",
"IsInf",
"InRange",
"Switch",
"UnaryBitOp",
"BinaryBitOp",
"OR",
"XOR",
"AND",
"Invert",
"ScalarMaximum",
"ScalarMinimum",
"Add",
"Mean",
"Mul",
"Sub",
"TrueDiv",
"IntDiv",
"mod_check",
"Mod",
"Pow",
"Clip",
"Second",
"Identity",
"Cast",
"cast",
"Abs",
"Sgn",
"Ceil",
"Floor",
"Trunc",
"RoundHalfToEven",
"round_half_away_from_zero_",
"round_half_away_from_zero_vec",
"RoundHalfAwayFromZero",
"Neg",
"Reciprocal",
"Log",
"Log2",
"Log10",
"Log1p",
"Exp",
"Exp2",
"Expm1",
"Sqr",
"Sqrt",
"Deg2Rad",
"Rad2Deg",
"Cos",
"ArcCos",
"Sin",
"ArcSin",
"Tan",
"ArcTan",
"ArcTan2",
"Cosh",
"ArcCosh",
"Sinh",
"ArcSinh",
"Tanh",
"ArcTanh",
"Real",
"Imag",
"Angle",
"Complex",
"Conj",
"ComplexFromPolar",
"Composite",
"Compositef32",
"handle_cast",
"handle_composite",
"autocast_int",
"autocast_float",
"bool",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
"int_types",
"uint_types",
"float_types",
"complex_types",
"integer_types",
"discrete_types",
"continuous_types",
"all_types",
"discrete_dtypes",
"ints",
"floats",
"complexs",
"complexs64",
"complexs128",
"lt",
"gt",
"le",
"ge",
"eq",
"neq",
"isnan",
"isinf",
"inopenrange",
"inclosedrange",
"switch",
"or_",
"xor",
"and_",
"invert",
"scalar_maximum",
"scalar_minimum",
"add",
"mean",
"mul",
"sub",
"true_divide",
"divide",
"int_div",
"floor_div",
"mod",
"pow",
"clip",
"second",
"identity",
"_cast_mapping",
"convert_to_bool",
"convert_to_int8",
"convert_to_int16",
"convert_to_int32",
"convert_to_int64",
"convert_to_uint8",
"convert_to_uint16",
"convert_to_uint32",
"convert_to_uint64",
"convert_to_float16",
"convert_to_float32",
"convert_to_float64",
"convert_to_complex64",
"convert_to_complex128",
"abs",
"sgn",
"ceil",
"floor",
"trunc",
"round_half_to_even",
"round_half_away_from_zero_vec64",
"round_half_away_from_zero_vec32",
"round_half_away_from_zero",
"neg",
"reciprocal",
"log",
"log2",
"log10",
"log1p",
"exp",
"exp2",
"expm1",
"sqr",
"sqrt",
"deg2rad",
"rad2deg",
"cos",
"arccos",
"sin",
"arcsin",
"tan",
"arctan",
"arctan2",
"cosh",
"arccosh",
"sinh",
"arcsinh",
"tanh",
"arctanh",
"real",
"imag",
"angle",
"complex",
"conj",
"complex_from_polar",
"composite_f32",
]
24 changes: 24 additions & 0 deletions aesara/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,27 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:


__all__ = ["random"] # noqa: F405

# isort: off
from aesara.tensor.math import DEPRECATED_NAMES as MATH_DEPRECATED_NAMES

# isort: on


DEPRECATED_NAMES = MATH_DEPRECATED_NAMES


def __getattr__(name):
"""Intercept module-level attribute access of deprecated symbols.

Adapted from https://stackoverflow.com/a/55139609/3006474.

"""
from warnings import warn

for old_name, msg, old_object in DEPRECATED_NAMES:
if name == old_name:
warn(msg, DeprecationWarning, stacklevel=2)
return old_object

raise AttributeError(f"module {__name__} has no attribute {name}")
5 changes: 5 additions & 0 deletions aesara/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3145,6 +3145,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
DEPRECATED_NAMES = [
("abs_", "`abs_` is deprecated; use `abs` instead.", abs),
("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal),
(
"true_div",
"`true_div` is deprecated; use `true_divide` or `divide` instead.",
true_divide,
),
]


Expand Down
28 changes: 28 additions & 0 deletions tests/scalar/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,31 @@ def test_shape():
assert b.shape.type.ndim == 1
assert b.shape.type.shape == (0,)
assert b.shape.type.dtype == "int64"


def test_deprecations():
"""Make sure we can import from deprecated modules."""

with pytest.deprecated_call():
from aesara.scalar.basic import true_div # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar import true_div # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar.basic import Inv # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar import Inv # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar.basic import inv # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar import inv # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar.basic import Scalar # noqa: F401 F811

with pytest.deprecated_call():
from aesara.scalar import Scalar # noqa: F401 F811
22 changes: 22 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3535,3 +3535,25 @@ def test_logdiffexp():
if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Exp)
]
assert len(ops_graph) != 2


def test_deprecations():
"""Make sure we can import from deprecated modules."""

with pytest.deprecated_call():
from aesara.tensor.math import abs_ # noqa: F401 F811

with pytest.deprecated_call():
from aesara.tensor import abs_ # noqa: F401 F811

with pytest.deprecated_call():
from aesara.tensor import inv # noqa: F401 F811

with pytest.deprecated_call():
from aesara.tensor.math import inv # noqa: F401 F811

with pytest.deprecated_call():
from aesara.tensor import true_div # noqa: F401 F811

with pytest.deprecated_call():
from aesara.tensor.math import true_div # noqa: F401 F811