Skip to content

Cleanup trace initialization and type hints #5420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 3, 2022

Conversation

michaelosthege
Copy link
Member

@michaelosthege michaelosthege commented Jan 29, 2022

This PR refactors the code and type hints in sampling.py that are related to trace backend initialization and related type hints.

I found several type hints that were too loose or even wrong.

  • In multiple places tune was advertised as int | None but the code inside assumed it to be int.
  • Similarly, multiple trace kwargs were advertised as BaseTrace | MultiTrace but the code inside called methods that are only available with BaseTrace.
  • After tightening the input types, I fixed the return type hints in a similar fashion.

Finally, this tightening enabled a consolidation of the trace initialization (_choose_backend() + strace.setup()) such that now all of _iter_sample, _mp_sample, _prepare_iter_population make the same call to _init_trace.

To reduce type confusion.
@codecov
Copy link

codecov bot commented Jan 29, 2022

Codecov Report

Merging #5420 (2d87b9d) into main (0dca647) will increase coverage by 0.01%.
The diff coverage is 85.89%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5420      +/-   ##
==========================================
+ Coverage   81.39%   81.40%   +0.01%     
==========================================
  Files          82       82              
  Lines       14213    14213              
==========================================
+ Hits        11568    11570       +2     
+ Misses       2645     2643       -2     
Impacted Files Coverage Δ
pymc/distributions/shape_utils.py 96.73% <75.00%> (-2.14%) ⬇️
pymc/sampling.py 86.06% <88.33%> (+0.22%) ⬆️
pymc/distributions/distribution.py 91.43% <100.00%> (+0.03%) ⬆️
pymc/parallel_sampling.py 87.70% <0.00%> (+0.99%) ⬆️

The code largely assumed it to be `int`,
even though it was often advertised as `int | None`.
With this change, the `_iter_sample` and `_prepare_iter_population`
take the `tune` into account for the expected length of traces.
This was not the case beforehand and I don't understand why it worked.
Also fixes an incorrectly named kwarg in some logger calls.
@michaelosthege
Copy link
Member Author

The changes reduce the number of errors found by mypy -p pymc from 362 to 266.

https://www.diffchecker.com/aHofAXhG

@michaelosthege michaelosthege marked this pull request as ready for review January 30, 2022 00:21
@@ -61,7 +62,7 @@
"NoDistribution",
]

DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]
DIST_PARAMETER_TYPES: TypeAlias = Union[np.ndarray, int, float, TensorVariable]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im just curiouys, what does TypeAlias do?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a type hint itself. The mypy errors and documentation told me that aliases should be annotated like this.

@@ -1393,6 +1397,27 @@ def _choose_backend(trace: Optional[Union[BaseTrace, List[str]]], **kwds) -> Bac
return NDArray(vars=trace, **kwds)


def _init_trace(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need a test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's covered extensively by the existing tests, though I agree that after extraction it is now easier to test it separately.
But my plan is to refactor this into using McBackend. This _init_trace() will then become run.init_chain() and then I'd also invest in the tests.

@michaelosthege michaelosthege added this to the v4.0.0b3 milestone Jan 31, 2022
@@ -1217,7 +1220,7 @@ def _prepare_iter_population(
step,
start: Sequence[PointType],
parallelize: bool,
tune=None,
tune: int,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the optional?

Copy link
Member Author

@michaelosthege michaelosthege Jan 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. The (only) call to _prepare_iter_population is in a context where tune is already int.
  2. It's passed down into _iter_population which assumed tune: int.

...because its an internal function called in exactly one place, it's safe to make it non-optional. The alternative would be to pick an arbitrary default, which we do already in pm.sample([tune=1000]).

@twiecki twiecki merged commit 2bb0c7c into pymc-devs:main Feb 3, 2022
@michaelosthege michaelosthege deleted the cleanup-trace-init branch February 3, 2022 20:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants