diff --git a/src/tinygp/means.py b/src/tinygp/means.py index 9811d634..0db37b17 100644 --- a/src/tinygp/means.py +++ b/src/tinygp/means.py @@ -39,18 +39,18 @@ class Mean(MeanBase): signature. """ - value: JAXArray | None = None + value: JAXArray func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True) def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]): if callable(value): self.func = value + self.value = jax.numpy.zeros(()) else: self.value = value def __call__(self, X: JAXArray) -> JAXArray: - if self.value is None: - assert self.func is not None + if self.func is not None: return self.func(X) return self.value diff --git a/src/tinygp/test_utils.py b/src/tinygp/test_utils.py index dd469d84..e55d211b 100644 --- a/src/tinygp/test_utils.py +++ b/src/tinygp/test_utils.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Any import jax @@ -30,3 +31,29 @@ def assert_pytrees_allclose(calculated: Any, expected: Any, *args: Any, **kwargs jax.tree_util.tree_map( lambda a, b: assert_allclose(a, b, *args, **kwargs), calculated, expected ) + + +def _as_context_manager(obj): + # If it's already a context manager + if hasattr(obj, "__enter__") and hasattr(obj, "__exit__"): + return obj + + # If it's a generator, wrap it + if hasattr(obj, "__iter__") and hasattr(obj, "send"): + return contextmanager(lambda: obj)() + + raise TypeError("Object is neither a context manager nor a generator") + + +@contextmanager +def jax_enable_x64(): + if hasattr(jax, "enable_x64"): + cm = jax.enable_x64(True) + else: + # deprecated in jax>=0.9 + from jax.experimental import enable_x64 as _enable_x64 + + cm = _enable_x64() + + with _as_context_manager(cm): + yield diff --git a/tests/test_gp.py b/tests/test_gp.py index 0c478e0c..be77c11b 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -6,7 +6,7 @@ from numpy import random as np_random from tinygp import GaussianProcess, kernels -from tinygp.test_utils import assert_allclose +from tinygp.test_utils import assert_allclose, jax_enable_x64 @pytest.fixture @@ -24,7 +24,7 @@ def data(random): def test_sample(data): X, _ = data - with jax.experimental.enable_x64(True): + with jax_enable_x64(): gp = GaussianProcess( kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x) ) diff --git a/tests/test_kernels/test_kernels.py b/tests/test_kernels/test_kernels.py index 6e05883d..f41347fc 100644 --- a/tests/test_kernels/test_kernels.py +++ b/tests/test_kernels/test_kernels.py @@ -5,7 +5,7 @@ from tinygp import kernels, noise from tinygp.solvers import DirectSolver -from tinygp.test_utils import assert_allclose +from tinygp.test_utils import assert_allclose, jax_enable_x64 @pytest.fixture @@ -71,7 +71,7 @@ def test_ops(data): def test_conditioned(data): x1, x2 = data - with jax.experimental.enable_x64(): # type: ignore + with jax_enable_x64(): # type: ignore k1 = 1.5 * kernels.Matern32(2.5) k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3) K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0])