Skip to content

dsx.infer for filtering and smoothing without numpyro#238

Open
theorashid wants to merge 6 commits into
BasisResearch:dsx-infer-stagingfrom
theorashid:refactor/decouple-numpyro-filters-smoothers
Open

dsx.infer for filtering and smoothing without numpyro#238
theorashid wants to merge 6 commits into
BasisResearch:dsx-infer-stagingfrom
theorashid:refactor/decouple-numpyro-filters-smoothers

Conversation

@theorashid
Copy link
Copy Markdown

This is for #213.

So far it is just for Filter and Smoother. Simulator will have to come.

Basic idea is that Filter/Smoother no longer call numpyro.factor inline, it happens in dsx.sample. test_filters and test_smoothers still pass, happy days.

# non-numpyro path: returns InferResult directly (no side effects)
with Filter(filter_config=KFConfig(...)):
    result = dsx.infer("f", dynamics, obs_times=t, obs_values=y)
loss = -result.marginal_loglik

# numpyro path (unchanged from before):
with Filter(filter_config=KFConfig(...)):
    dsx.sample("f", dynamics, obs_times=t, obs_values=y)

In the background, sample just calls infer and then registers sites. dsx.infer returns InferResult which carries marginal_loglik, states, dists, and a private _register_numpyro_sites callback.

def sample(...):
    result = infer(...)
    if isinstance(result, InferResult) and result._register_numpyro_sites is not None:
        result._register_numpyro_sites(name)
    return result

The main work can be seen in the test_*_standalone.py, which I based from my cuthbert-models repo. This is how I would see it being used.

def test_infer_optax_mle():
    """Use dsx.infer + optax to do MLE without numpyro."""
    obs_times, obs_values = _make_data()

    def neg_loglik(alpha):
        dynamics = _make_lti_dynamics(alpha)
        with Filter(filter_config=KFConfig(filter_source="cuthbert")):
            result = dsx.infer(
                "f", dynamics, obs_times=obs_times, obs_values=obs_values
            )
        return -result.marginal_loglik

    optimizer = optax.adam(1e-2)
    alpha = jnp.array(0.3)
    opt_state = optimizer.init(alpha)

    initial_loss = neg_loglik(alpha)
    grad_fn = jax.grad(neg_loglik)

    for _ in range(20):
        grads = grad_fn(alpha)
        updates, opt_state = optimizer.update(grads, opt_state)
        alpha = optax.apply_updates(alpha, updates)

    final_loss = neg_loglik(alpha)
    assert final_loss < initial_loss

how these changes were made

Of course, this was largely done by burning tokens and handholding so that it matches the design that I (and then with feedback from both of you) wanted. I am not as familiar with the internals of the library, so if there are other places beyond the diff that you think these changes might affect, let me know where to look – I was relying a bit on existing tests not breaking.

smaller design things

I largely tried to keep everything that was there before to reduce the size of the refactor. But:

  • I did make BaseLogFactorAdder (and the Filter equivalent) an ABC because I think it made sense.
  • Plate, as before, does not register per-field sites.
  • Maybe we should rename _sample_intp to _infer_intp.
  • InferResult a __call__ shim to satisfy FunctionOfTime protocol (needed because some model functions return dsx.sample(...)). Effectful uses the @defop return annotation to decide the type of fwd() returns. InferResult.__call__ raises NotImplementedError to satisfy the protocol — we can't change it without reworking Simulator.

Decompose dsx.sample into dsx.infer (pure computation, returns InferResult)
+ numpyro site registration (via callback). Filter and Smoother are now
numpyro-free: they compute results and return InferResult with a deferred
_register_numpyro_sites callback that dsx.sample fires. All integration
backends return (marginal_loglik, states, dists) tuples with no side effects.
@theorashid
Copy link
Copy Markdown
Author

On the testing suite, sometimes tests/test_hierarchical_simulator_discretizer_smokes.py, tests/test_science/test_discrete_time_l63_mcmc.py are flaky. tests/test_science/test_hmm.py::test_mcmc_inference breaks for TypeError: only 0-dimensional arrays can be converted to Python scalars. test_science is pretty slow in general

@DanWaxman
Copy link
Copy Markdown
Collaborator

Thanks Theo! This seems directionally about right, I like the implementation strategy!

In the background, sample just calls infer and then registers sites. dsx.infer returns InferResult which carries marginal_loglik, states, dists, and a private _register_numpyro_sites callback.

That makes sense! It will be a tiny bit tricky to get working with Simulators, I think, but that's probably okay. One tricky part is you're allowed to stack simulators with filters, i.e., with Simulator(), Filter(): dsx.sample(...); this allows one to sample from the filtering/posterior predictive. We'll need to pass the corresponding InferResult and append the necessary information.

I am not as familiar with the internals of the library, so if there are other places beyond the diff that you think these changes might affect, let me know where to look

I don't have any off the top of my head, besides the aforementioned interaction with Simulators. The other main test will be running all the notebooks in the documentation from scratch and making sure the results are qualitatively similar, but I think that can come in a bit.

I did make BaseLogFactorAdder (and the Filter equivalent) an ABC because I think it made sense.

Agreed, thanks!

Maybe we should rename _sample_intp to _infer_intp.

Sure, I more-or-less agree.

sometimes tests/test_hierarchical_simulator_discretizer_smokes.py, tests/test_science/test_discrete_time_l63_mcmc.py are flaky

I'm surprised the simulator/discretizer smokes are flaky, I haven't run into that before... but it's not super surprising to me that the discrete_time_l63_mcmc is flaky. I think those were written before we made EnKF the discrete-time default.

test_science is pretty slow in general

Right... I think the test_science suite has somewhat fallen into disuse, and our tests in general are a bit of a mess (though I think with decent coverage -- just disorganized). It's been sitting on the backlog for a bit. I wouldn't worry too much about particularly slow test_science tests as long as docs in the notebooks are looking okay.


At the risk of getting ahead of myself on a draft PR, I think it makes sense to set up a staging branch for this. Then, we can try to land this PR in the staging branch; worry about Simulators and its various interactions afterwards; then worry about the documentation lift that this implies.

@theorashid
Copy link
Copy Markdown
Author

I think the test_science suite has somewhat fallen into disuse

Right, now I've seen the workflows I can see they just run the tests ignoring test_science.

I can give Simulator a go if you want it in this PR, but I'm wary of keeping the size small so it's easy for you to review – up to you.

@DanWaxman DanWaxman changed the base branch from main to dsx-infer-staging May 29, 2026 13:28
@DanWaxman
Copy link
Copy Markdown
Collaborator

I can give Simulator a go if you want it in this PR, but I'm wary of keeping the size small so it's easy for you to review – up to you.

I think it makes sense to keep the PRs small, but also want to minimize the amount of half-working features on the upstream. So I've changed the base branch to dsx-infer-staging, where we can work bit-by-bit in implementing dsx.infer(...).

From that perspective, feel free to mark as ready to review whenever you feel it's ready and I'll take a close look :) thanks again!!

@theorashid theorashid marked this pull request as ready for review May 29, 2026 19:55
@theorashid
Copy link
Copy Markdown
Author

Just renamed _sample_intp to _infer_intp.

I think this is a good point. Smaller change, tests passing, not tooooo many files to check over. Then I'll use any feedback before doing the Simulator refactor.

@mattlevine22
Copy link
Copy Markdown
Collaborator

  1. Yeah, you can ignore test_science for now.
  2. Did you say tests/test_hierarchical_simulator_discretizer_smokes.py is having issues? That one should be a solid test, but it looks like everything passed in (green checks on your previous commits).
  3. I just ran tutorials 04 (discrete-time filter + simulator roll-out) and 06 (SDE filter + simulator roll-out), and both worked and look right locally, so that is a good sign!
  4. I think I prefer dsx.condition, as I worry that infer will sound "all-powerful" to some users trying to do parameter estimation. Infer will sound a bit weird when we use it for Simulator rollouts, but it is not CRAZY to say that a simulator rollout is "conditioned", even if only on obs_times.

Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty close on my end! A few things:

  • After talking with Matt, I agree that dsx.condition(...) is likely to be clearer than dsx.infer(...). condition is much closer to the statistical "work" that infer is actually doing.
  • Some inconsistencies in docstring should be fixed
  • The guards during registering results are on marginal_loglikelihood being None, but currently, empty filters return an increment of 0.0. It's not super clear what this means for registering sites as a result.
    • Would also be nice to have a test for empty filters

registering numpyro.factor / numpyro.deterministic if needed.

Returns:
tuple: (marginal_loglik, posterior, filtered_dists).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +83 to +84
Returns:
tuple: (marginal_loglik, posterior, smoothed_dists).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here (should match other more detailed docstring)

Comment thread dynestyx/inference/filters.py Outdated
obs_len = int(obs_values.shape[0])
if obs_len == 0:
return []
return jnp.array(0.0), None, []
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return jnp.array(0.0), None, []
return None, None, []

I think the regestration guard is actually on the MLL:

https://github.com/theorashid/dynestyx/blob/536ead8c9a99b71ad6ab7dd26b4c02b4cab767cd/dynestyx/inference/filters.py#L348-L350

t1 = int(obs_values.shape[0])
if t1 == 0:
return []
return jnp.array(0.0), None, []
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return jnp.array(0.0), None, []
return None, None, []

As before, I think the guarding is on the MLL:

https://github.com/theorashid/dynestyx/blob/536ead8c9a99b71ad6ab7dd26b4c02b4cab767cd/dynestyx/inference/smoothers.py#L333-L335

Comment thread dynestyx/inference/numpyro_sites.py Outdated
Comment thread dynestyx/handlers.py Outdated
@theorashid
Copy link
Copy Markdown
Author

I've still kept InferResult. ConditionResult sounds weird, it sounds like the result of some if-else logic rather than what we do here.

@DanWaxman
Copy link
Copy Markdown
Collaborator

Thanks, will take another look tomorrow!

I've still kept InferResult. ConditionResult sounds weird, it sounds like the result of some if-else logic rather than what we do here.

Hmm, I agree... Maybe ConditioningResult? I don't think this has to be a blocking point.

@mattlevine22
Copy link
Copy Markdown
Collaborator

Agree not to let it block...I'd probably do ConditionedResult fwiw

@DanWaxman DanWaxman self-requested a review June 2, 2026 03:25
Copy link
Copy Markdown
Collaborator

@DanWaxman DanWaxman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks very reasonable. Certainly good enough to merge into a staging branch, for me. There is an actual limitation of this API right now (we don't have a great way to deal with non-CRN filters without invoking the numpyro.seed handler), though. That should be dealt with at some point.

Comment on lines +254 to 262
if config.crn_seed is not None:
key = config.crn_seed
else:
import warnings # noqa: PLC0415

with warnings.catch_warnings():
warnings.simplefilter("ignore")
key = numpyro.prng_key() # returns None outside seed handler

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needn't be blocking right now, but this is a legitimate shortcoming; we don't really have a way to deal with non-CRN filters without using the numpyro seed handler. This is non-trivial to fix (though one option is to basically replicate the numpyro.prng_key() function/seed handler, which are not very complex and are a rather straightforward way to implement global random seeds).

@theorashid
Copy link
Copy Markdown
Author

ConditionedResult sounds good. I changed it – start as we mean to go on

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants