diff --git a/functional_algorithms/tests/test_accuracy.py b/functional_algorithms/tests/test_accuracy.py index 4c34e14..df5026e 100644 --- a/functional_algorithms/tests/test_accuracy.py +++ b/functional_algorithms/tests/test_accuracy.py @@ -58,8 +58,9 @@ def test_unary(unary_func_name, backend, device, dtype, fpu): if "disable-DAZ" in fpu: register_params.update(DAZ=False) + numpy_with_backend = getattr(fa.utils, f"numpy_with_{backend}")(device=device, dtype=dtype) try: - func = getattr(getattr(fa.utils, f"numpy_with_{backend}")(device=device, dtype=dtype), unary_func_name) + func = getattr(numpy_with_backend, unary_func_name) except NotImplementedError as msg: pytest.skip(f"{unary_func_name}: {msg}") except AttributeError as msg: @@ -80,7 +81,7 @@ def test_unary(unary_func_name, backend, device, dtype, fpu): fi = numpy.finfo(dtype) x = numpy.sqrt(fi.smallest_normal) * dtype(0.5) with register(**register_params): - v1 = getattr(getattr(fa.utils, f"numpy_with_{backend}")(device=device, dtype=dtype), "square")(x) + v1 = getattr(numpy_with_backend, "square")(x) v2 = numpy.square(x) d = fa.utils.diff_ulp(v1, v2) if d > 1000: diff --git a/functional_algorithms/utils.py b/functional_algorithms/utils.py index e06aacb..bc95ff7 100644 --- a/functional_algorithms/utils.py +++ b/functional_algorithms/utils.py @@ -374,6 +374,7 @@ def backend_is_available(cls, device): def __init__(self, *args, **kwargs): self.device = kwargs.pop("device", "cpu") + kwargs.pop("dtype", None) super().__init__(*args, **kwargs) @@ -500,6 +501,10 @@ class vectorize_with_mpmath(vectorize_with_backend): longdouble=numpy.nextafter(numpy.longdouble(numpy.inf), numpy.longdouble(0)), ) + @classmethod + def backend_is_available(cls, device): + return device == "cpu" + def __init__(self, *args, **kwargs): self.extra_prec_multiplier = kwargs.pop("extra_prec_multiplier", 0) self.extra_prec = kwargs.pop("extra_prec", 0) @@ -508,7 +513,6 @@ def __init__(self, *args, **kwargs): self._contexts = None self._contexts_inv = None super().__init__(*args, **kwargs) - assert self.device == "cpu", self.device def __getstate__(self): state = self.__dict__.copy() @@ -1002,11 +1006,12 @@ def __getattr__(self, name): name = dict(asinh="arcsinh", acos="arccos", asin="arcsin", acosh="arccosh", atan="arctan", atanh="arctanh").get( name, name ) - if name in self._vfunc_cache: - return self._vfunc_cache[name] + key = name, tuple(sorted(self.params.items())) + if key in self._vfunc_cache: + return self._vfunc_cache[key] if hasattr(mpmath_array_api, name): vfunc = vectorize_with_mpmath(getattr(mpmath_array_api(), name), **self.params) - self._vfunc_cache[name] = vfunc + self._vfunc_cache[key] = vfunc return vfunc raise NotImplementedError(f"vectorize_with_mpmath.{name}") @@ -1085,12 +1090,13 @@ def __getattr__(self, name): name = dict(asinh="arcsinh", acos="arccos", asin="arcsin", acosh="arccosh", atan="arctan", atanh="arctanh").get( name, name ) - if name in self._vfunc_cache: - return self._vfunc_cache[name] + key = name, tuple(sorted(self.params.items())) + if key in self._vfunc_cache: + return self._vfunc_cache[key] import jax vfunc = vectorize_with_jax(getattr(jax.numpy, name), **self.params) - self._vfunc_cache[name] = vfunc + self._vfunc_cache[key] = vfunc return vfunc @@ -1108,12 +1114,13 @@ def __getattr__(self, name): name = dict(asinh="arcsinh", acos="arccos", asin="arcsin", acosh="arccosh", atan="arctan", atanh="arctanh").get( name, name ) - if name in self._vfunc_cache: - return self._vfunc_cache[name] + key = name, tuple(sorted(self.params.items())) + if key in self._vfunc_cache: + return self._vfunc_cache[key] import numpy vfunc = numpy.vectorize(getattr(numpy, name), **self.params) - self._vfunc_cache[name] = vfunc + self._vfunc_cache[key] = vfunc return vfunc @@ -1132,7 +1139,7 @@ def __getattr__(self, name): name, name ) dtype = self.params["dtype"] - key = name, dtype.__name__ + key = name, tuple(sorted(self.params.items())) if key in self._vfunc_cache: return self._vfunc_cache[key] @@ -1853,6 +1860,8 @@ def function_validation_parameters(func_name, dtype): extra_prec_multiplier = 20 elif func_name in {"tanh", "tan"}: extra_prec_multiplier = 20 + elif func_name == "loq1p": + max_valid_ulp_count = 4 return dict( extra_prec_multiplier=extra_prec_multiplier, max_valid_ulp_count=max_valid_ulp_count,