Skip to content

Conversation

@hannes-holey
Copy link

Hi, thank you for tinygp.

This PR attempts to replace the jax.experimental.enable_x64 context manager which is deprecated in jax>=0.9.0.

While testing this change on current GitHub runners some other JAX-related issues surfaced:

See details
Downloading cpython-3.10.19-linux-x86_64-gnu (download) (28.5MiB)
 Downloaded cpython-3.10.19-linux-x86_64-gnu (download)
Using CPython 3.10.19
Creating virtual environment at: .venv
   Building tinygp @ file:///home/runner/work/tinygp/tinygp
Downloading pygments (1.2MiB)
Downloading scipy (35.9MiB)
Downloading numpy (16.0MiB)
Downloading ml-dtypes (4.8MiB)
Downloading jaxlib (85.8MiB)
Downloading jax (2.6MiB)
      Built tinygp @ file:///home/runner/work/tinygp/tinygp
 Downloaded ml-dtypes
 Downloaded pygments
 Downloaded jax
 Downloaded jaxlib
 Downloaded numpy
 Downloaded scipy
Installed 20 packages in 41ms
============================= test session starts ==============================
platform linux -- Python 3.10.19, pytest-9.0.2, pluggy-1.6.0
rootdir: /home/runner/work/tinygp/tinygp
configfile: pyproject.toml
plugins: xdist-3.8.0, jaxtyping-0.3.6
created: 4/4 workers
4 workers [130 items]

.F...F.................................................................. [ 55%]
.....................................................s....               [100%]
=================================== FAILURES ===================================
_________________________________ test_sample __________________________________
[gw0] linux -- Python 3.10.19 /home/runner/work/tinygp/tinygp/.venv/bin/python
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

data = (array([[ 1.44739616, -2.54217712, -1.74711698, -1.02426555,  2.73445279],
       [ 0.68050516,  0.00431724, -0.836095....52507338,  2.89057525],
       [ 0.59347335, -2.85087174, -1.22763876,  0.39915777, -0.62558987]]), 50.35871298268382)

    def test_sample(data):
        X, _ = data
    
        with jax_enable_x64():
            gp = GaussianProcess(
                kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x)
            )
>           y = gp.sample(jax.random.PRNGKey(543))

tests/test_gp.py:31: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = GaussianProcess(
  num_data=50,
  dtype=dtype('float64'),
  kernel=Matern32(scale=1.5, distance=L1Distance()),
  X=f64...r(
    X=f64[50,5](numpy),
    variance_value=f64[50],
    covariance_value=f64[50,50],
    scale_tril=f64[50,50]
  )
)
key = Array([  0, 543], dtype=uint32), shape = None

    def sample(
        self,
        key: jax.random.KeyArray,
        shape: Sequence[int] | None = None,
    ) -> JAXArray:
        """Generate samples from the prior process
    
        Args:
            key: A ``jax`` random number key array. shape (tuple, optional): The
            number and shape of samples to
                generate.
    
        Returns:
            The sampled realizations from the process with shape ``(N_data,) +
            shape`` where ``N_data`` is the zeroth dimension of the ``X``
            coordinates provided when instantiating this process.
        """
>       return self._sample(key, shape)
E       TypeError: Error interpreting argument to <function GaussianProcess._sample at 0x7f8a7a0f1750> as an abstract array. The problematic value is of type <class 'equinox._module._flatten._Missing'> and was passed to the function at path self.mean_function.value.
E       This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

src/tinygp/gp.py:291: TypeError
__________________________________ test_means __________________________________
[gw0] linux -- Python 3.10.19 /home/runner/work/tinygp/tinygp/.venv/bin/python
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

data = (array([[ 1.44739616, -2.54217712, -1.74711698, -1.02426555,  2.73445279],
       [ 0.68050516,  0.00431724, -0.836095....52507338,  2.89057525],
       [ 0.59347335, -2.85087174, -1.22763876,  0.39915777, -0.62558987]]), 50.35871298268382)

    def test_means(data):
        X, y = data
    
        gp1 = GaussianProcess(kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: 0.0)
        gp2 = GaussianProcess(kernels.Matern32(1.5), X, diag=0.01, mean=0.0)
        gp3 = GaussianProcess(kernels.Matern32(1.5), X, diag=0.01)
    
        assert_allclose(gp1.mean, gp2.mean)
        assert_allclose(gp1.mean, gp3.mean)
>       assert_allclose(gp1.log_probability(y), gp2.log_probability(y))

tests/test_gp.py:52: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = GaussianProcess(
  num_data=50,
  dtype=dtype('float32'),
  kernel=Matern32(scale=1.5, distance=L1Distance()),
  X=f64...r(
    X=f64[50,5](numpy),
    variance_value=f32[50],
    covariance_value=f32[50,50],
    scale_tril=f32[50,50]
  )
)
y = 50.35871298268382

    def log_probability(self, y: JAXArray) -> JAXArray:
        """Compute the log probability of this multivariate normal
    
        Args:
            y (JAXArray): The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
    
        Returns:
            The marginal log probability of this multivariate normal model,
            evaluated at ``y``.
        """
>       return self._compute_log_prob(self._get_alpha(y))
E       TypeError: Error interpreting argument to <function GaussianProcess._get_alpha at 0x7f8a7a0f1ab0> as an abstract array. The problematic value is of type <class 'equinox._module._flatten._Missing'> and was passed to the function at path self.mean_function.value.
E       This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

src/tinygp/gp.py:139: TypeError
=========================== short test summary info ============================
FAILED tests/test_gp.py::test_sample - TypeError: Error interpreting argument to <function GaussianProcess._sample at 0x7f8a7a0f1750> as an abstract array. The problematic value is of type <class 'equinox._module._flatten._Missing'> and was passed to the function at path self.mean_function.value.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
FAILED tests/test_gp.py::test_means - TypeError: Error interpreting argument to <function GaussianProcess._get_alpha at 0x7f8a7a0f1ab0> as an abstract array. The problematic value is of type <class 'equinox._module._flatten._Missing'> and was passed to the function at path self.mean_function.value.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
============= 2 failed, 127 passed, 2 skipped in 62.09s (0:01:02) ==============

As far as I can tell, the problem was the None attribute of the Mean class when initiated with a callable. This PR should fix this by setting its value attribute to a valid (dummy) JAX array.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant