Skip to content

Commit 69f3304

Browse files
Refactor SharedVariable type and interface
1 parent d7be8fb commit 69f3304

File tree

4 files changed

+64
-39
lines changed

4 files changed

+64
-39
lines changed

aesara/compile/sharedvalue.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
import copy
44
from contextlib import contextmanager
55
from functools import singledispatch
6-
from typing import List, Optional
6+
from typing import TYPE_CHECKING, List, Optional
77

88
from aesara.graph.basic import Variable
99
from aesara.graph.utils import add_tag_trace
1010
from aesara.link.basic import Container
1111
from aesara.link.c.type import generic
1212

1313

14+
if TYPE_CHECKING:
15+
from aesara.graph.type import Type
16+
17+
1418
__SHARED_CONTEXT__: Optional[List[Variable]] = None
1519

1620

@@ -30,14 +34,39 @@ def collect_new_shareds():
3034
class SharedVariable(Variable):
3135
"""Variable that is shared between compiled functions."""
3236

33-
container: Optional[Container] = None
34-
"""
35-
A container to use for this SharedVariable when it is an implicit
36-
function parameter.
37-
"""
37+
def __init__(
38+
self,
39+
type: "Type",
40+
value,
41+
strict: bool,
42+
allow_downcast=None,
43+
container: Optional[Container] = None,
44+
name: Optional[str] = None,
45+
):
46+
r"""
47+
Parameters
48+
----------
49+
type
50+
The `Type` for this variable (see `Variable`).
51+
value
52+
A value to associate with this variable (a new container will be
53+
created).
54+
strict
55+
``True`` means that values assigned to this variable will not be
56+
cast or copied, so they must have the correct `Type`\s.
57+
allow_downcast
58+
Only applies if `strict` is ``False``.
59+
``True`` means that the assigned value can lose precision when cast
60+
during assignment. ``None`` means that only down-casting of a Python
61+
float to a scalar ``floatX`` is allowed.
62+
container
63+
The container to use for this variable. Illegal to pass this as well as
64+
a value.
65+
name
66+
The name for this variable (see `Variable`).
3867
39-
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
40-
super().__init__(type=type, name=name, owner=None, index=None)
68+
"""
69+
super().__init__(type=type, owner=None, index=None, name=name)
4170

4271
if container is not None:
4372
self.container = container
@@ -107,26 +136,6 @@ def set_value(self, new_value, borrow=False):
107136
def get_test_value(self):
108137
return self.get_value(borrow=True, return_internal_type=True)
109138

110-
def zero(self, borrow=False):
111-
"""
112-
Set the values of a shared variable to 0.
113-
114-
Parameters
115-
----------
116-
borrow : bbol
117-
True to modify the value of a shared variable directly by using
118-
its previous value. Potentially this can cause problems
119-
regarding to the aliased memory.
120-
121-
Changes done with this function will be visible to all functions using
122-
this SharedVariable.
123-
124-
"""
125-
if borrow:
126-
self.container.value[...] = 0
127-
else:
128-
self.container.value = 0 * self.container.value
129-
130139
def clone(self, **kwargs):
131140
name = kwargs.get("name", self.name)
132141
cp = self.__class__(
@@ -209,7 +218,7 @@ def shared_constructor(value, name=None, strict=False, allow_downcast=None, **kw
209218
return SharedVariable(
210219
type=generic,
211220
value=value,
212-
name=name,
213221
strict=strict,
214222
allow_downcast=allow_downcast,
223+
name=name,
215224
)

aesara/sparse/sharedvar.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
import scipy.sparse
44

5-
from aesara.compile import SharedVariable, shared_constructor
5+
from aesara.compile import shared_constructor
66
from aesara.sparse.basic import SparseTensorType, _sparse_py_operators
7+
from aesara.tensor.sharedvar import TensorSharedVariable
78

89

9-
class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
10-
dtype = property(lambda self: self.type.dtype)
11-
format = property(lambda self: self.type.format)
10+
class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators):
11+
pass
1212

1313

1414
@shared_constructor.register(scipy.sparse.spmatrix)
@@ -24,5 +24,5 @@ def sparse_constructor(
2424
value = copy.deepcopy(value)
2525

2626
return SparseTensorSharedVariable(
27-
type=type, value=value, name=name, strict=strict, allow_downcast=allow_downcast
27+
type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name
2828
)

aesara/tensor/random/var.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def randomgen_constructor(
3737
return rng_sv_type(
3838
type=rng_type,
3939
value=value,
40-
name=name,
4140
strict=strict,
4241
allow_downcast=allow_downcast,
42+
name=name,
4343
)

aesara/tensor/sharedvar.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,25 @@ def load_shared_variable(val):
1919
return tensor_constructor(val)
2020

2121

22-
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
2322
class TensorSharedVariable(_tensor_py_operators, SharedVariable):
24-
pass
23+
def zero(self, borrow: bool = False):
24+
r"""Set the values of a shared variable to 0.
25+
26+
Parameters
27+
----------
28+
borrow
29+
``True`` to modify the value of a shared variable directly by using
30+
its previous value. Potentially this can cause problems regarding
31+
to the aliased memory.
32+
33+
Changes done with this function will be visible to all functions using
34+
this `SharedVariable`.
35+
36+
"""
37+
if borrow:
38+
self.container.value[...] = 0
39+
else:
40+
self.container.value = 0 * self.container.value
2541

2642

2743
@_get_vector_length.register(TensorSharedVariable)
@@ -69,13 +85,13 @@ def tensor_constructor(
6985
return TensorSharedVariable(
7086
type=type,
7187
value=np.array(value, copy=(not borrow)),
72-
name=name,
7388
strict=strict,
7489
allow_downcast=allow_downcast,
90+
name=name,
7591
)
7692

7793

78-
class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
94+
class ScalarSharedVariable(TensorSharedVariable):
7995
pass
8096

8197

0 commit comments

Comments
 (0)