dsx.infer for filtering and smoothing without numpyro #238
Decompose dsx.sample into dsx.infer (pure computation, returns InferResult) + numpyro site registration (via callback). Filter and Smoother are now numpyro-free: they compute results and return InferResult with a deferred _register_numpyro_sites callback that dsx.sample fires. All integration backends return (marginal_loglik, states, dists) tuples with no side effects.
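The split described above can be sketched roughly as follows. This is a toy illustration only, not the library's code: `REGISTERED_SITES` is a hypothetical stand-in for numpyro's site registration, the log-likelihood is made up, and the `infer`/`sample` signatures are assumptions chosen for brevity.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical stand-in for numpyro's trace: (site_name, value) pairs.
REGISTERED_SITES: List[Tuple[str, float]] = []


@dataclass
class InferResult:
    marginal_loglik: Optional[float]
    states: Any
    dists: list
    # Deferred side effect; only `sample` is expected to fire it.
    _register_numpyro_sites: Optional[Callable[[str], None]] = None


def infer(name: str, observations: List[float]) -> InferResult:
    """Pure computation: nothing is registered while this runs."""
    mll = -0.5 * sum(y * y for y in observations)  # toy log-likelihood
    states = list(observations)

    def register(site_name: str) -> None:
        # In the real library this is where numpyro.factor would be called.
        REGISTERED_SITES.append((f"{site_name}_marginal_loglik", mll))

    return InferResult(mll, states, [], register)


def sample(name: str, observations: List[float]) -> InferResult:
    """Thin wrapper: run `infer`, then fire the deferred registration."""
    result = infer(name, observations)
    if result._register_numpyro_sites is not None:
        result._register_numpyro_sites(name)
    return result
```

Calling `infer` alone leaves `REGISTERED_SITES` empty, while `sample` appends one entry, which is the behaviour the standalone tests rely on.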
On the testing suite, sometimes |
Thanks Theo! This seems directionally about right, I like the implementation strategy!
That makes sense! It will be a tiny bit tricky to get working with Simulators, I think, but that's probably okay. One tricky part is you're allowed to stack simulators with filters, i.e.,
I don't have any off the top of my head, besides the aforementioned interaction with Simulators. The other main test will be running all the notebooks in the documentation from scratch and making sure the results are qualitatively similar, but I think that can come in a bit.
Agreed, thanks!
Sure, I more-or-less agree.
I'm surprised the simulator/discretizer smokes are flaky, I haven't run into that before... but it's not super surprising to me that the discrete_time_l63_mcmc is flaky. I think those were written before we made EnKF the discrete-time default.
Right... I think the test_science suite has somewhat fallen into disuse, and our tests in general are a bit of a mess (though I think with decent coverage -- just disorganized). It's been sitting on the backlog for a bit. I wouldn't worry too much about particularly slow test_science tests as long as docs in the notebooks are looking okay. At the risk of getting ahead of myself on a draft PR, I think it makes sense to set up a staging branch for this. Then, we can try to land this PR in the staging branch; worry about Simulators and its various interactions afterwards; then worry about the documentation lift that this implies.
Right, now I've seen the workflows I can see they just run the tests ignoring I can give |
I think it makes sense to keep the PRs small, but also want to minimize the amount of half-working features on the upstream. So I've changed the base branch to From that perspective, feel free to mark as ready to review whenever you feel it's ready and I'll take a close look :) thanks again!! |
Just renamed I think this is a good point. Smaller change, tests passing, not tooooo many files to check over. Then I'll use any feedback before doing the |
DanWaxman
left a comment
This is pretty close on my end! A few things:
- After talking with Matt, I agree that `dsx.condition(...)` is likely to be clearer than `dsx.infer(...)`. `condition` is much closer to the statistical "work" that `infer` is actually doing.
- Some inconsistencies in docstring should be fixed
- The guards during registering results are on `marginal_loglikelihood` being `None`, but currently, empty filters return an increment of `0.0`. It's not super clear what this means for registering sites as a result.
  - Would also be nice to have a test for empty filters
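To make the empty-filter concern concrete, here is a minimal sketch, assuming the registration guard is a plain `is None` check. `run_filter` and `register_sites` are hypothetical names, and a dict stands in for numpyro's sites; none of this is the library's actual code.

```python
from typing import Dict, List, Optional, Tuple


def run_filter(obs_values: List[float]) -> Tuple[Optional[float], Optional[list], list]:
    """Toy backend returning (marginal_loglik, states, dists)."""
    if len(obs_values) == 0:
        # Signal "nothing computed" with None rather than 0.0.
        return None, None, []
    mll = -0.5 * sum(y * y for y in obs_values)  # made-up log-likelihood
    return mll, list(obs_values), []


def register_sites(name: str, mll: Optional[float], sites: Dict[str, float]) -> None:
    """Guard on the marginal log-likelihood, as described in the review."""
    if mll is None:
        return
    sites[f"{name}_factor"] = mll


def test_empty_filter_registers_nothing() -> None:
    sites: Dict[str, float] = {}
    mll, states, dists = run_filter([])
    register_sites("f", mll, sites)
    assert (mll, states, dists) == (None, None, [])
    assert sites == {}


def test_zero_increment_slips_past_guard() -> None:
    # With 0.0 instead of None, the guard does not trigger and a site is registered.
    sites: Dict[str, float] = {}
    register_sites("f", 0.0, sites)
    assert sites == {"f_factor": 0.0}
```

The second test shows why returning `0.0` is ambiguous: the guard treats it like any other value and a site is registered.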
    registering numpyro.factor / numpyro.deterministic if needed.

Returns:
    tuple: (marginal_loglik, posterior, filtered_dists).
Should match the other return docstring block
Returns:
    tuple: (marginal_loglik, posterior, smoothed_dists).
Same here (should match other more detailed docstring)
  obs_len = int(obs_values.shape[0])
  if obs_len == 0:
-     return []
+     return jnp.array(0.0), None, []
- return jnp.array(0.0), None, []
+ return None, None, []
I think the registration guard is actually on the MLL:
  t1 = int(obs_values.shape[0])
  if t1 == 0:
-     return []
+     return jnp.array(0.0), None, []
- return jnp.array(0.0), None, []
+ return None, None, []
As before, I think the guarding is on the MLL:
I've still kept |
Thanks, will take another look tomorrow!
Hmm, I agree... Maybe |
Agree not to let it block...I'd probably do ConditionedResult fwiw |
DanWaxman
left a comment
I think this looks very reasonable. Certainly good enough to merge into a staging branch, for me. There is an actual limitation of this API right now (we don't have a great way to deal with non-CRN filters without invoking the numpyro.seed handler), though. That should be dealt with at some point.
    if config.crn_seed is not None:
        key = config.crn_seed
    else:
        import warnings  # noqa: PLC0415

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            key = numpyro.prng_key()  # returns None outside seed handler
This also needn't be blocking right now, but this is a legitimate shortcoming; we don't really have a way to deal with non-CRN filters without using the numpyro seed handler. This is non-trivial to fix (though one option is to basically replicate the numpyro.prng_key() function/seed handler, which are not very complex and are a rather straightforward way to implement global random seeds).
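For illustration, a global seed handler of the kind mentioned here could look roughly like the sketch below. It is a stdlib-only toy under stated assumptions: keys are plain integers, `_split` is a hash-based stand-in for `jax.random.split`, and the `seed` / `prng_key` names simply mirror numpyro's naming rather than reproduce its implementation.

```python
import hashlib
from contextlib import contextmanager
from typing import Iterator, List, Optional, Tuple

# Stack of active handlers; each entry is a one-element list holding the current key.
_KEY_STACK: List[List[int]] = []


def _split(key: int) -> Tuple[int, int]:
    """Toy stand-in for jax.random.split: derive two new keys from one."""
    digest = hashlib.sha256(str(key).encode()).digest()
    return int.from_bytes(digest[:8], "big"), int.from_bytes(digest[8:16], "big")


@contextmanager
def seed(rng_seed: int) -> Iterator[None]:
    """Install a global key for the duration of the `with` block."""
    _KEY_STACK.append([rng_seed])
    try:
        yield
    finally:
        _KEY_STACK.pop()


def prng_key() -> Optional[int]:
    """Return a fresh subkey, or None when no handler is active."""
    if not _KEY_STACK:
        return None
    cell = _KEY_STACK[-1]
    # Advance the handler's state and hand out the other half of the split.
    cell[0], subkey = _split(cell[0])
    return subkey
```

Inside `with seed(0):` successive `prng_key()` calls give distinct keys that are reproducible across runs; outside any handler the call returns `None`, matching the behaviour noted in the diff comment.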
This is for #213.
So far it is just for `Filter` and `Smoother`. Simulator will have to come.

Basic idea is that Filter/Smoother no longer call `numpyro.factor` inline, it happens in `dsx.sample`. `test_filters` and `test_smoothers` still pass, happy days.

In the background, `sample` just calls `infer` and then registers sites. `dsx.infer` returns `InferResult` which carries `marginal_loglik`, `states`, `dists`, and a private `_register_numpyro_sites` callback.

The main work can be seen in the `test_*_standalone.py`, which I based on my cuthbert-models repo. This is how I would see it being used.

how these changes were made
Of course, this was largely done by burning tokens and handholding so that it matches the design that I (and then with feedback from both of you) wanted. I am not as familiar with the internals of the library, so if there are other places beyond the diff that you think these changes might affect, let me know where to look – I was relying a bit on existing tests not breaking.
smaller design things
I largely tried to keep everything that was there before to reduce the size of the refactor. But:
- Made `BaseLogFactorAdder` (and the `Filter` equivalent) an `ABC` because I think it made sense.
- Renamed `_sample_intp` to `_infer_intp`.
- Gave `InferResult` a `__call__` shim to satisfy the `FunctionOfTime` protocol (needed because some model functions `return dsx.sample(...)`). Effectful uses the `@defop` return annotation to decide the type of `fwd()` returns. `InferResult.__call__` raises `NotImplementedError` to satisfy the protocol; we can't change it without reworking `Simulator`.
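The `__call__` shim can be pictured with a small sketch like the one below. The `FunctionOfTime` protocol here is a hypothetical reconstruction with an assumed single-argument signature, and this `InferResult` is a simplified stand-in, so treat it as an illustration of the pattern rather than the library's definitions.

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class FunctionOfTime(Protocol):
    """Assumed shape of the protocol: anything callable with a time argument."""

    def __call__(self, t: float) -> Any: ...


@dataclass
class InferResult:
    marginal_loglik: Any
    states: Any
    dists: Any

    def __call__(self, t: float) -> Any:
        # Present only so the object structurally matches FunctionOfTime;
        # evaluating an inference result at a time point is not supported.
        raise NotImplementedError("InferResult cannot be evaluated as a function of time")


result = InferResult(marginal_loglik=-1.0, states=[0.0], dists=[])
print(isinstance(result, FunctionOfTime))
```

The structural check passes because the method exists, while any actual call fails loudly, which keeps return annotations satisfied without pretending the result is evaluable.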