Skip to content

Remove the (NumPy-) backend-specific hierarchy of weighting classes #1686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
38 changes: 36 additions & 2 deletions odl/space/base_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""Base classes for implementations of tensor spaces."""

from __future__ import absolute_import, division, print_function

from typing import Optional
from numbers import Integral

import numpy as np
Expand All @@ -24,6 +24,7 @@
signature_string, writable_array)
from odl.util.ufuncs import TensorSpaceUfuncs
from odl.util.utility import TYPE_MAP_C2R, TYPE_MAP_R2C, nullcontext
from .weighting import Weighting, ArrayWeighting, ConstWeighting

__all__ = ('TensorSpace',)

Expand Down Expand Up @@ -62,7 +63,7 @@ class TensorSpace(LinearSpace):
.. _Wikipedia article on tensors: https://en.wikipedia.org/wiki/Tensor
"""

def __init__(self, shape, dtype):
def __init__(self, shape, dtype, weighting : Optional[Weighting] =None ):
"""Initialize a new instance.

Parameters
Expand Down Expand Up @@ -109,6 +110,36 @@ def __init__(self, shape, dtype):
else:
field = None

# Set the weighting
if weighting is not None:
if isinstance(weighting, Weighting):
if weighting.impl != self.impl:
raise ValueError(f"The Weighting and the TensorSpace implementations must match, but {weighting.impl} and {self.impl} were provided")

# Check (afterwards) that the weighting input was sane
if isinstance(self.weighting, ArrayWeighting):
if self.weighting.array.dtype == object:
raise ValueError('invalid `weighting` argument: {}'
''.format(weighting))
elif not np.can_cast(self.weighting.array.dtype, self.dtype):
raise ValueError(
'cannot cast from `weighting` data type {} to '
'the space `dtype` {}'
''.format(dtype_str(self.weighting.array.dtype),
dtype_str(self.dtype)))
if self.weighting.array.shape != self.shape:
raise ValueError('array-like weights must have same '
'shape {} as this space, got {}'
''.format(self.shape,
self.weighting.array.shape))
self.__weighting = weighting
else:
raise TypeError(f"The weighting can only be an ODL Weighting, but {type(weighting)} was provided.")
else:
if self.dtype == bool:
self.__weighting = Weighting(self.impl)
else:
self.__weighting = ConstWeighting(const=1, impl=self.impl)
LinearSpace.__init__(self, field)

########## static methods ##########
Expand Down Expand Up @@ -294,6 +325,9 @@ def size(self):
return (0 if self.shape == () else
int(np.prod(self.shape, dtype='int64')))

@property
def weighting(self):
return self.__weighting
########## public methods ##########
def astype(self, dtype):
"""Return a copy of this space with new ``dtype``.
Expand Down
Loading