
Old versions much faster? #711

@cgiovanetti


I have a large project that's intended to run on CPU. During development it ran in well under a second per evaluation, but now it takes almost twice as long.

I have more or less pinned the issue down to my JAX version: 0.4.26 and 0.4.28 run fast, and every later version is slower. However, when creating an MWE I can only reproduce the slowdown if I use diffrax, so I wonder whether some interaction with later versions is causing it. I have also opened a related issue on the JAX GitHub.
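When bisecting versions like this, it helps to record exactly which packages are installed in the fast and the slow environment. JAX itself provides `jax.print_environment_info()` for this. The sketch below is a small standard-library alternative that also covers diffrax and equinox. The helper name `installed_versions` and the package list are illustrative assumptions.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib", "diffrax", "equinox")):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Print one line per package so the fast/slow environments can be diffed
for name, ver in installed_versions().items():
    print(f"{name:8s} {ver or 'not installed'}")
```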

Here is the MWE:

import time

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt

def fun(rtol=1e-8, atol=1e-10, solver=Tsit5()):
    T_EM_init = 8.6
    rho_extra_init = 830.

    # State is a tuple pytree: (ln a, T_g)
    Y0 = (0., T_EM_init)

    sol = diffeqsolve(
        ODETerm(dY), solver, args=rho_extra_init,  # args is a plain float
        t0=0., t1=100., dt0=None, y0=Y0,
        saveat=SaveAt(steps=True),
        stepsize_controller=PIDController(rtol=rtol, atol=atol),
        max_steps=512,
    )

    # a = exp(ln a) at each saved step
    a_vec = jnp.exp(sol.ys[0])
    return a_vec

def dY(t, Y, args):
    lna, T_g = Y
    rho_extra_init = args

    rho_EM = T_g**4
    rho_extra = rho_extra_init * 1. / jnp.exp(lna)**4

    H = (rho_EM + rho_extra)**0.5
    drho_EM_dt = -3 * H * rho_EM
    dT_g_dt = drho_EM_dt / (4*T_g**3)

    # d(ln a)/dt = H
    return H, dT_g_dt

# The first iteration includes tracing and compilation
for i in range(5):
    start = time.time()
    a_vec = jax.block_until_ready(jax.jit(fun)())
    print(time.time() - start)

Running with jax/jaxlib==0.4.28 and diffrax==0.6.0, each compiled iteration runs in ~0.00015 s on an M1 Mac. Running with either jax/jaxlib==0.6.2 or 0.8.1 and diffrax==0.7.0, each compiled iteration runs in ~0.0004 s on the same hardware. I suspect this is not directly caused by newer versions of diffrax, because I also see a slowdown in my production code with jax/jaxlib==0.4.29 and diffrax==0.6.0. Could something second-order in the way diffrax uses JAX be responsible?
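The MWE re-wraps `fun` in `jax.jit` on every iteration and times single calls with `time.time()`, so at the ~0.1 ms scale the numbers can be sensitive to dispatch overhead and timer noise. A steadier way to compare versions is to jit once, make one warm-up call to trigger compilation, and report the median of many timed calls. This is a minimal sketch under that assumption: `bench` and `rhs` are hypothetical names, and `rhs` is just the issue's right-hand side evaluated directly (no diffrax) so the example runs on its own.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def bench(f, *args, repeats=200):
    """Median wall-clock seconds per call of an already-jitted f(*args)."""
    jax.block_until_ready(f(*args))  # warm-up call: compilation happens here, untimed
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(f(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def rhs(lna, T_g, rho_extra_init):
    # Hypothetical stand-in: the ODE right-hand side from the MWE
    rho_EM = T_g**4
    rho_extra = rho_extra_init / jnp.exp(lna) ** 4
    H = (rho_EM + rho_extra) ** 0.5
    dT_g_dt = -3 * H * rho_EM / (4 * T_g**3)
    return H, dT_g_dt


rhs_jit = jax.jit(rhs)  # jit once, outside the timing loop
print(f"{bench(rhs_jit, 0.0, 8.6, 830.0):.2e} s per call")
```

With the MWE's definitions loaded, `bench(jax.jit(fun))` should time the full `diffeqsolve` call the same way, since `fun` only takes default arguments.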

In absolute terms the difference is small, but I suspect it also accounts for the factor-of-two slowdown in my production code. Incidentally, if my differential equation has only one variable (i.e., I track only H and not dT_g_dt), the compiled code runs faster on newer versions. Using eqx.filter_jit doesn't seem to help much in either case.
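Since the slowdown apparently depends on whether the state has one or two components, one experiment worth trying (an untested assumption, not a known fix) is to pack the tuple state `(lna, T_g)` into a single length-2 array, so the solver handles one array leaf instead of a two-leaf pytree. The sketch below only checks that such a packed right-hand side agrees with the original; `dY_packed` is a hypothetical name, and whether it changes performance on any JAX version is not established here.

```python
import jax
import jax.numpy as jnp


def dY(t, Y, args):
    # Original form from the MWE: state is a tuple pytree (lna, T_g)
    lna, T_g = Y
    rho_extra_init = args
    rho_EM = T_g**4
    rho_extra = rho_extra_init / jnp.exp(lna) ** 4
    H = (rho_EM + rho_extra) ** 0.5
    dT_g_dt = -3 * H * rho_EM / (4 * T_g**3)
    return H, dT_g_dt


def dY_packed(t, y, args):
    # Hypothetical variant: same equations, state packed as y = [lna, T_g]
    dlna_dt, dT_g_dt = dY(t, (y[0], y[1]), args)
    return jnp.stack([dlna_dt, dT_g_dt])


# In the MWE this would be used as (untested assumption):
#   ODETerm(dY_packed), y0=jnp.array([0., T_EM_init]),
#   then a_vec = jnp.exp(sol.ys[:, 0]), since ys has shape (max_steps, 2).
y0 = jnp.array([0.0, 8.6])
print(dY_packed(0.0, y0, 830.0))
```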
