Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tinygp/means.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ class Mean(MeanBase):
signature.
"""

value: JAXArray | None = None
value: JAXArray
func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True)

def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]):
if callable(value):
self.func = value
self.value = jax.numpy.zeros(())
else:
self.value = value

def __call__(self, X: JAXArray) -> JAXArray:
if self.value is None:
assert self.func is not None
if self.func is not None:
return self.func(X)
return self.value

Expand Down
27 changes: 27 additions & 0 deletions src/tinygp/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import contextmanager
from typing import Any

import jax
Expand Down Expand Up @@ -30,3 +31,29 @@ def assert_pytrees_allclose(calculated: Any, expected: Any, *args: Any, **kwargs
jax.tree_util.tree_map(
lambda a, b: assert_allclose(a, b, *args, **kwargs), calculated, expected
)


def _as_context_manager(obj):
# If it's already a context manager
if hasattr(obj, "__enter__") and hasattr(obj, "__exit__"):
return obj

# If it's a generator, wrap it
if hasattr(obj, "__iter__") and hasattr(obj, "send"):
return contextmanager(lambda: obj)()

raise TypeError("Object is neither a context manager nor a generator")


@contextmanager
def jax_enable_x64():
if hasattr(jax, "enable_x64"):
cm = jax.enable_x64(True)
else:
# deprecated in jax>=0.9
from jax.experimental import enable_x64 as _enable_x64

cm = _enable_x64()

with _as_context_manager(cm):
yield
4 changes: 2 additions & 2 deletions tests/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy import random as np_random

from tinygp import GaussianProcess, kernels
from tinygp.test_utils import assert_allclose
from tinygp.test_utils import assert_allclose, jax_enable_x64


@pytest.fixture
Expand All @@ -24,7 +24,7 @@ def data(random):
def test_sample(data):
X, _ = data

with jax.experimental.enable_x64(True):
with jax_enable_x64():
gp = GaussianProcess(
kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x)
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_kernels/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from tinygp import kernels, noise
from tinygp.solvers import DirectSolver
from tinygp.test_utils import assert_allclose
from tinygp.test_utils import assert_allclose, jax_enable_x64


@pytest.fixture
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_ops(data):

def test_conditioned(data):
x1, x2 = data
with jax.experimental.enable_x64(): # type: ignore
with jax_enable_x64(): # type: ignore
k1 = 1.5 * kernels.Matern32(2.5)
k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3)
K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0])
Expand Down