Description
I have a large project that's intended to run on CPU. During development it ran in well under a second per evaluation, but it now takes almost twice as long.
I have more or less pinned the issue down to my JAX version: 0.4.26 and 0.4.28 run fast, and everything later is slower. However, in creating an MWE I can only reproduce the slowdown if I use diffrax, so I wonder whether some interaction with later versions is causing it. I also opened a related issue on the JAX GitHub.
Here is the MWE:
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt
import time

def fun(rtol=1e-8, atol=1e-10, solver=Tsit5()):
    T_EM_init = 8.6
    rho_extra_init = 830.
    Y0 = (0., T_EM_init)
    sol = diffeqsolve(
        ODETerm(dY), solver, args=(rho_extra_init),
        t0=0., t1=100., dt0=None, y0=Y0,
        saveat=SaveAt(steps=True),
        stepsize_controller=PIDController(
            rtol=rtol, atol=atol
        ),
        max_steps=512
    )
    a_vec = jnp.exp(sol.ys[0])
    return (
        a_vec
    )

def dY(t, Y, args):
    lna, T_g = Y
    rho_extra_init = args
    rho_EM = T_g**4
    rho_extra = rho_extra_init * 1. / jnp.exp(lna)**4
    H = (rho_EM + rho_extra)**0.5
    drho_EM_dt = -3 * H * rho_EM
    dT_g_dt = drho_EM_dt / (4*T_g**3)
    return H, dT_g_dt

for i in range(5):
    start = time.time()
    a_vec = jax.block_until_ready(jax.jit(fun)())
    print(time.time() - start)
Running with jax/jaxlib==0.4.28 and diffrax==0.6.0, each compiled iteration takes ~0.00015s on an M1 Mac. Running with jax/jaxlib==0.6.2 or 0.8.1 and diffrax==0.7.0, each compiled iteration takes ~0.0004s on the same hardware. I suspect this is not directly an issue with newer versions of diffrax, because I also see a slowdown in my production code with jax/jaxlib==0.4.29 and diffrax==0.6.0, but maybe it is something second-order in the way diffrax draws on JAX?
In absolute terms it's not much, but I suspect this is also what translates into the factor-of-two slowdown in my actual production code. Incidentally, if my differential equation is a function of only one variable (i.e., I track only H and not dT_g_dt), the compiled code runs faster on newer versions. Using eqx.filter_jit doesn't seem to help much in any case.
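Since the timings quoted above are only a few hundred microseconds, single `time.time()` readings can be noisy. Below is a minimal sketch of a more robust timing harness, offered as a suggestion rather than as what was actually used for the numbers in this issue. The helper names `benchmark` and `summarize` are hypothetical, and the code uses only the standard library so it is independent of the JAX or diffrax version under test.

```python
import statistics
import time


def benchmark(fn, warmup=1, repeats=100):
    """Call fn() `warmup` times untimed, then return `repeats` per-call timings in seconds."""
    for _ in range(warmup):
        fn()  # warm-up calls absorb one-off costs such as compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        timings.append(time.perf_counter() - start)
    return timings


def summarize(timings):
    """Reduce a list of timings to min / median / max, which are less noise-sensitive than one sample."""
    return {
        "min": min(timings),
        "median": statistics.median(timings),
        "max": max(timings),
    }
```

For the MWE, one way to use this (an assumption about how it would be wired up, not something tested here) is to create the jitted function once with `jitted = jax.jit(fun)` outside the loop and then call `summarize(benchmark(lambda: jax.block_until_ready(jitted()), warmup=1, repeats=100))`, so that compilation and any per-call `jax.jit` wrapper overhead are excluded from the reported figures.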