Skip to content
39 changes: 34 additions & 5 deletions effectful/handlers/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,12 @@ def _pos_base_dist(self) -> dist.Distribution:

@functools.cached_property
def _is_eager(self) -> bool:
return all(
(not isinstance(x, Term) or is_eager_array(x))
for x in (*self.args, *self.kwargs.values())
)
def _arg_is_eager(x):
if isinstance(x, _DistributionTerm):
return x._is_eager
return not isinstance(x, Term) or is_eager_array(x)

return all(_arg_is_eager(x) for x in (*self.args, *self.kwargs.values()))

@property
def op(self):
Expand Down Expand Up @@ -357,7 +359,7 @@ def to_event(self, reinterpreted_batch_ndims=None) -> dist.Distribution:
raise NotHandled

@defop
def expand(self, batch_shape) -> jax.Array:
def expand(self, batch_shape) -> dist.Distribution:
if not self._is_eager:
raise NotHandled

Expand Down Expand Up @@ -396,6 +398,33 @@ def __str__(self):
expand = _DistributionTerm.expand


@defdata.register(dist.Distribution)
class _DistributionMethodTerm(_DistributionTerm):
"""Term for distribution-method ops returning the abstract ``dist.Distribution``
(``expand``, ``to_event``). Catches the ``defdata`` fallthrough that would
otherwise hit ``_CallableTerm``. See #666."""

def __init__(self, ty, op, *args, **kwargs):
receiver = args[0] if args else None
constr = (
receiver._constr
if isinstance(receiver, _DistributionTerm)
else dist.Distribution
)
super().__init__(constr, op, *args, **kwargs)

@functools.cached_property
def _pos_base_dist(self) -> dist.Distribution:
# Delegate to NumPyro's method of the same name on the materialised receiver.
receiver = self._args[0]
base = (
receiver._pos_base_dist
if isinstance(receiver, _DistributionTerm)
else receiver
)
return getattr(base, self._op.__name__)(*self._args[1:], **self._kwargs)


@defop
def Cauchy(loc=0.0, scale=1.0, **kwargs) -> dist.Cauchy:
raise NotHandled
Expand Down
97 changes: 94 additions & 3 deletions tests/test_handlers_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None):
("concentration0", f"exp(rand({batch_shape + indep_shape}))"),
),
batch_shape,
xfail="to_event not implemented",
xfail="to_event composed with expand_by on indexed dims not implemented",
)

# Dirichlet.to_event
Expand All @@ -494,7 +494,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None):
),
),
batch_shape,
xfail="to_event not implemented",
xfail="to_event composed with expand_by on indexed dims not implemented",
)

# TransformedDistribution.to_event
Expand All @@ -513,7 +513,7 @@ def add_case(raw_dist, raw_params, batch_shape, xfail=None):
("high", f"2. + rand({batch_shape + indep_shape})"),
),
batch_shape,
xfail="to_event not implemented",
xfail="TransformedDistribution not implemented",
)


Expand Down Expand Up @@ -929,3 +929,94 @@ def test_distribution_typeof():
typeof(dist.Normal(jax_getitem(jnp.array([0, 1, 2]), [defop(jax.Array)()])))
is numpyro.distributions.continuous.Normal
)


def test_distribution_method_chain_on_non_eager_term():
"""Regression test for #666 (narrow).

``Normal(mu_term, 1.0).expand([J]).to_event(1)`` must not raise
``AttributeError`` mid-chain. Previously ``_DistributionTerm.expand`` was
``@defop``-annotated to return ``jax.Array``, routing ``.expand([J])``'s
result through ``_ArrayTerm`` (no ``.to_event``). The fix annotates
``expand`` to return ``dist.Distribution`` and registers a fallback
``_DistributionMethodTerm`` for ``defdata`` dispatch on the abstract base,
so the chain stays in the distribution-term surface.
"""
mu = defop(jax.Array, name="mu")

expanded = dist.Normal(mu(), 1.0).expand([3])
assert isinstance(expanded, numpyro.distributions.Distribution)

chained = expanded.to_event(1)
assert isinstance(chained, numpyro.distributions.Distribution)


def test_expand_to_event_shape_laws():
"""Equational laws for ``.expand`` and ``.to_event`` on a distribution term
whose free-variable arg has been bound by an effectful handler.

These hold for any NumPyro distribution and should survive any future
refactor of how deferred method ops are encoded:

d.expand(s).batch_shape == tuple(s)
d.expand(s).event_shape == d.event_shape
d.to_event(k).event_shape == d.batch_shape[-k:] + d.event_shape
d.to_event(k).batch_shape == d.batch_shape[:-k]
"""
import jax.numpy as jnp

from effectful.ops.semantics import handler

mu = defop(jax.Array, name="mu")

with handler({mu: lambda: jnp.array(0.0)}):
d = dist.Normal(mu(), 1.0)
assert d.batch_shape == ()
assert d.event_shape == ()

expanded = d.expand([3, 4])
assert expanded.batch_shape == (3, 4)
assert expanded.event_shape == ()

indep = expanded.to_event(1)
assert indep.batch_shape == (3,)
assert indep.event_shape == (4,)

chained = d.expand([3]).to_event(1)
assert chained.batch_shape == ()
assert chained.event_shape == (3,)
assert not chained.support.is_discrete


def test_expand_to_event_chain_end_to_end_mcmc():
"""End-to-end regression: the literal #666 idiom — ``Normal(mu_term, 1.0)
.expand([J]).to_event(1)`` with ``mu_term`` bound by an effectful handler —
must trace, build a potential, and run MCMC to completion.

Before the fix this raised ``AttributeError: '_ArrayTerm' object has no
attribute 'to_event'`` at chain construction. After the fix, the chain
constructs a ``_DistributionMethodTerm`` whose materialised
``_pos_base_dist`` resolves to a real ``dist.Independent`` wrapping the
handler-bound receiver, so NumPyro's downstream property/sample/log_prob
accesses all resolve.
"""
import jax.numpy as jnp
import jax.random as jr

from effectful.ops.semantics import handler

mu = defop(jax.Array, name="mu")

def model():
numpyro.sample("theta", dist.Normal(mu(), 1.0).expand([3]).to_event(1))

with handler({mu: lambda: jnp.array(0.0)}):
mcmc = numpyro.infer.MCMC(
numpyro.infer.NUTS(model),
num_warmup=20,
num_samples=20,
progress_bar=False,
)
mcmc.run(jr.PRNGKey(0))

assert mcmc.get_samples()["theta"].shape == (20, 3)
Loading