-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Cleanup trace initialization and type hints #5420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
To reduce type confusion.
Codecov Report
@@ Coverage Diff @@
## main #5420 +/- ##
==========================================
+ Coverage 81.39% 81.40% +0.01%
==========================================
Files 82 82
Lines 14213 14213
==========================================
+ Hits 11568 11570 +2
+ Misses 2645 2643 -2
|
The code largely assumed it to be `int`, even though it was often advertised as `int | None`.
With this change, the `_iter_sample` and `_prepare_iter_population` take the `tune` into account for the expected length of traces. This was not the case beforehand and I don't understand why it worked.
00c229d
to
40f82c3
Compare
Also fixes an incorrectly named kwarg in some logger calls.
603636d
to
2d87b9d
Compare
The changes reduce the number of errors found by |
@@ -61,7 +62,7 @@ | |||
"NoDistribution", | |||
] | |||
|
|||
DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable] | |||
DIST_PARAMETER_TYPES: TypeAlias = Union[np.ndarray, int, float, TensorVariable] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Im just curiouys, what does TypeAlias do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a type hint itself. The mypy errors and documentation told me that aliases should be annotated like this.
@@ -1393,6 +1397,27 @@ def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> Bac | |||
return NDArray(vars=trace, **kwds) | |||
|
|||
|
|||
def _init_trace( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this need a test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's covered extensively by the existing tests, though I agree that after extraction it is now easier to test it separately.
But my plan is to refactor this into using McBackend. This _init_trace()
will then become run.init_chain()
and then I'd also invest in the tests.
@@ -1217,7 +1220,7 @@ def _prepare_iter_population( | |||
step, | |||
start: Sequence[PointType], | |||
parallelize: bool, | |||
tune=None, | |||
tune: int, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove the optional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- The (only) call to
_prepare_iter_population
is in a context where tune is alreadyint
. - It's passed down into
_iter_population
which assumedtune: int
.
...because its an internal function called in exactly one place, it's safe to make it non-optional. The alternative would be to pick an arbitrary default, which we do already in pm.sample([tune=1000])
.
This PR refactors the code and type hints in
sampling.py
that are related to trace backend initialization and related type hints.I found several type hints that were too loose or even wrong.
tune
was advertised asint | None
but the code inside assumed it to beint
.trace
kwargs were advertised asBaseTrace | MultiTrace
but the code inside called methods that are only available withBaseTrace
.Finally, this tightening enabled a consolidation of the trace initialization (
_choose_backend()
+strace.setup()
) such that now all of_iter_sample
,_mp_sample
,_prepare_iter_population
make the same call to_init_trace
.