Skip to content

Commit e444766

Browse files
committed
Test JAX support
1 parent d7672ec commit e444766

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

modules/foo/test/test_restraint2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import IMP.core
55
import IMP.foo
66
import pickle
7+
try:
8+
import jax
9+
except ImportError:
10+
jax = None
711

812

913
class Tests(IMP.test.TestCase):
@@ -58,6 +62,24 @@ def test_serialize_polymorphic(self):
5862
newsf = pickle.loads(dump)
5963
self.assertAlmostEqual(newsf.evaluate(False), 45.0, delta=1e-3)
6064

65+
@IMP.test.skipIf(jax is None or IMP.__version__ == '2.23.0',
66+
'No JAX support')
67+
def test_my_restraint_jax(self):
68+
"""Test scoring of MyRestraint2 using JAX"""
69+
m = IMP.Model()
70+
p = m.add_particle("p")
71+
d = IMP.core.XYZ.setup_particle(m, p, IMP.algebra.Vector3D(1,2,3))
72+
r = IMP.foo.MyRestraint2(m, p, 10.)
73+
ji = r._get_jax()
74+
jm = ji.get_model_state()
75+
score = jax.jit(ji.score_func)
76+
self.assertAlmostEqual(score(jm), 45.0, delta=1e-4)
77+
deriv = jax.jit(jax.grad(ji.score_func))
78+
d = deriv(jm)['xyz']
79+
self.assertLess(IMP.algebra.get_distance(d[0],
80+
IMP.algebra.Vector3D(0,0,30)),
81+
1e-4)
82+
6183

6284
if __name__ == '__main__':
6385
IMP.test.main()

support/setup_ci.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ if [ ${imp_branch} = "develop" ]; then
1818
else
1919
IMP_CONDA="imp"
2020
fi
21-
conda create --yes -q -n python${python_version} -c salilab python=${python_version} pip ${IMP_CONDA} libboost-devel gxx_linux-64 eigen cereal swig cmake numpy
21+
conda create --yes -q -n python${python_version} -c salilab python=${python_version} pip ${IMP_CONDA} libboost-devel gxx_linux-64 eigen cereal swig cmake numpy jax
2222
eval "$(conda shell.bash hook)"
2323
conda activate python${python_version}
2424

0 commit comments

Comments
 (0)