From 9252a32ab2a09f1db680e11b20bc3bf796bba9d3 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 09:43:16 -0400 Subject: [PATCH 01/29] more precise stream type --- effectful/ops/monoid.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 70bb5002..57685aaf 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,10 +4,10 @@ import operator import typing from collections import Counter, UserDict, defaultdict -from collections.abc import Callable, Generator, Iterable, Mapping +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter -from typing import Annotated, Any +from typing import Annotated, Any, Protocol from effectful.internals.disjoint_set import DisjointSet from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler @@ -22,9 +22,14 @@ ) from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term + +class Stream[T](Protocol): + def __iter__(self) -> Iterator[T]: ... + + # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable -type Streams[T] = Mapping[Operation[[], T], Any] +type Streams[T] = Mapping[Operation[[], T], Stream[T]] type Body[T] = ( Iterable[T] @@ -35,9 +40,9 @@ ) -def outer_stream( - streams: Streams, -) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: +def outer_stream[T]( + streams: Streams[T], +) -> Iterable[tuple[Operation, Stream[T], dict[Operation, Stream[T]]]]: """Returns the streams that can be ordered outermost in the loop nest as well as the remaining streams in the nest. From 284546d7ea9b1a4b917001feaa57c9052df23a78 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 10:28:35 -0400 Subject: [PATCH 02/29] add tests for weighted rules --- effectful/ops/monoid.py | 19 +++++++++ tests/test_handlers_jax_monoid.py | 37 +++++++++++++++++- tests/test_ops_monoid.py | 64 +++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+), 1 deletion(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 57685aaf..b741b76d 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -123,6 +123,25 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero +@dataclass +class WeightedStream[T, W](Stream[tuple[T, W]]): + stream: Stream[T] + weight: Callable[[T], W] + monoid: Monoid[W] + + def __iter__(self): + if isinstance(self.stream, Term): + return iter_(self) + + return ((x, self.weight(x)) for x in self.stream) + + +@Operation.define +@functools.singledispatch +def weighted(x) -> WeightedStream: + raise NotImplementedError("Unsupported type", type(x)) + + Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 35d041fe..e21d8d42 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -5,7 +5,8 @@ from effectful.handlers.jax import bind_dims, unbind_dims from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, Product, Sum +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum, WeightedStream +from effectful.ops.semantics import coproduct from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars MONOIDS = [ @@ -94,3 +95,37 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): backend=backend, free_vars=[x, y, k1, k2, X, f, g], ) + + +@pytest.mark.xfail( + strict=True, reason="ReduceWeightedStream rewrite rule not yet implemented" +) +def test_jax_weighted_reduce(backend: Backend): + """Sum over a single ``WeightedStream`` with ``Product`` weights lowers to + ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. + + Verifies that the desugaring rule composes cleanly with the JAX lowering + so existing handlers need no changes to support weighted streams. + """ + (x, k) = define_vars("x", "k", typ=jax.Array) + X = define_vars("X", typ=backend.stream_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + w = backend.fresh_op("w", n_args=1, ret="scalar") + + ws = WeightedStream(stream=X(), weight=lambda v: w(v), monoid=Product) + lhs = Sum.reduce(body(x()), {x: ws}) + rhs = jnp.sum( + bind_dims( + unbind_dims(w(X()), k) * unbind_dims(body(X()), k), + k, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct(NormalizeIntp, ArrayReduce()), + backend=backend, + free_vars=[x, k, X, body, w], + ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index c7ee7567..d225c3d1 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -28,6 +28,7 @@ ReduceNoStreams, ReduceSplit, Sum, + WeightedStream, distributes_over, ) from effectful.ops.semantics import fvsof, handler @@ -666,3 +667,66 @@ def test_reduce_lifted_2(outer, inner, backend): backend=backend, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], ) + + +# --------------------------------------------------------------------------- +# Weighted streams +# --------------------------------------------------------------------------- + + +@pytest.mark.xfail( + strict=True, reason="ReduceWeightedStream rewrite rule not yet implemented" +) +def test_reduce_single_weighted_stream(backend): + """Single weighted stream desugars: + Sum.reduce(body, {a: WS(A, w, Product)}) + = Sum.reduce(Product.plus(w(a), body), {a: A}) + """ + a = define_vars("a", typ=backend.scalar_typ) + A = define_vars("A", typ=backend.stream_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + w = backend.fresh_op("w", n_args=1, ret="scalar") + + ws = WeightedStream(stream=A(), weight=lambda v: w(v), monoid=Product) + lhs = Sum.reduce(body(a()), {a: ws}) + rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) + + check_rewrite( + lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[A, body, w] + ) + + +@pytest.mark.xfail( + strict=True, reason="ReduceWeightedStream rewrite rule not yet implemented" +) +def test_reduce_weighted_factorization(backend): + """Two independent weighted streams under Sum with Product weights factor: + Sum.reduce(f(a)*g(b), {a: WS(A, w_a, Product), b: WS(B, w_b, Product)}) + = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) + + Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` + inside ``NormalizeIntp``. + """ + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="scalar") + w_a = backend.fresh_op("w_a", n_args=1, ret="scalar") + w_b = backend.fresh_op("w_b", n_args=1, ret="scalar") + + ws_a = WeightedStream(stream=A(), weight=lambda v: w_a(v), monoid=Product) + ws_b = WeightedStream(stream=B(), weight=lambda v: w_b(v), monoid=Product) + + lhs = Sum.reduce(Product.plus(f(a()), g(b())), {a: ws_a, b: ws_b}) + rhs = Product.plus( + Sum.reduce(Product.plus(w_a(a()), f(a())), {a: A()}), + Sum.reduce(Product.plus(w_b(b()), g(b())), {b: B()}), + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=NormalizeIntp, + backend=backend, + free_vars=[A, B, f, g, w_a, w_b], + ) From 4cc519d53164429016974f34beeb23028c485d0c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 13:28:53 -0400 Subject: [PATCH 03/29] add reduction rule for weighted streams and tests --- effectful/handlers/jax/monoid.py | 3 ++ effectful/internals/unification.py | 1 + effectful/ops/monoid.py | 32 ++++++++++++++----- tests/_monoid_helpers.py | 50 +++++++++++++++++++++++++++--- tests/test_ops_monoid.py | 23 +++++++------- 5 files changed, 84 insertions(+), 25 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index a406cda5..8630f4fd 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,4 +1,5 @@ import functools +from collections.abc import Iterable import jax @@ -19,6 +20,8 @@ from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Operation +Iterable.register(jax.Array) # required to make jax arrays compatible with Stream[T] + def cartesian_prod(x, y): if x.ndim == 1: diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e425bba6..36a4a4e9 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -801,6 +801,7 @@ def _(typ: typing._ConcatenateGenericAlias): # type: ignore @canonicalize.register def _(typ: typing._AnyMeta): # type: ignore + return typing.Any diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index b741b76d..34217d18 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,10 +10,12 @@ from typing import Annotated, Any, Protocol from effectful.internals.disjoint_set import DisjointSet +from effectful.internals.unification import canonicalize from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, + defdata, deffn, implements, iter_, @@ -22,9 +24,7 @@ ) from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term - -class Stream[T](Protocol): - def __iter__(self) -> Iterator[T]: ... +type Stream[T] = Iterable[T] # Note: The streams value type should be something like Iterable[T], but some of @@ -124,16 +124,13 @@ def __init__(self, name: str, identity: T, zero: T): @dataclass -class WeightedStream[T, W](Stream[tuple[T, W]]): +class WeightedStream[T, W](Iterable[T]): stream: Stream[T] weight: Callable[[T], W] monoid: Monoid[W] def __iter__(self): - if isinstance(self.stream, Term): - return iter_(self) - - return ((x, self.weight(x)) for x in self.stream) + return defdata(iter_, self) @Operation.define @@ -579,6 +576,24 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() +class ReduceWeightedStream(ObjectInterpretation): + """reduce(M, body, {x: WeightedStream(s, w, WM), ...}) = reduce(M, + WM.plus(w(x), body), {x: s, ...}) + + requires distributes_over(WM, M). + + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + for k, v in streams.items(): + if isinstance(v, WeightedStream) and distributes_over(v.monoid, monoid): + weighted_body = v.monoid.plus(v.weight(k()), body) + new_streams = {**streams, k: v.stream} + return monoid.reduce(weighted_body, new_streams) + return fwd() + + class MonoidOverCallable(ObjectInterpretation): """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" @@ -773,6 +788,7 @@ def extend(self, *intps: Interpretation) -> typing.Self: ReduceSplit(), ReduceFactorization(), ReduceDistributeCartesianProduct(), + ReduceWeightedStream(), PlusEmpty(), PlusSingle(), PlusIdentity(), diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f15103e3..76219b16 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,7 +1,7 @@ import itertools from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass -from typing import Any, get_args, get_origin +from typing import Annotated, Any, get_args, get_origin import jax from hypothesis import given, settings @@ -25,8 +25,19 @@ def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) +# Shape-preserving unary jax fns: scalar → scalar (counterpart of +# ``_UNARY_NUM_FNS`` for ints). Used for ops declared with ``ret="scalar"``. +_UNARY_JAX_SCALAR_FNS: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: a, + lambda a: a + 1, + lambda a: a - 1, + lambda a: -a, + lambda a: 2 * a, +] + # Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` # for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. +# Used for ops declared with ``ret="stream"``. _UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ lambda a: _jnp.stack([a, a + 1]), lambda a: _jnp.stack([a, -a]), @@ -91,14 +102,33 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: ] +def _is_stream(annotation: Any) -> bool: + """True if ``annotation`` carries the ``"stream"`` Annotated marker. + + On the JAX backend ``scalar_typ`` and ``stream_typ`` are both ``jax.Array``, + so :meth:`Backend.fresh_op` tags stream returns as + ``Annotated[jax.Array, "stream"]`` to keep them distinguishable here. + """ + return get_origin(annotation) is Annotated and "stream" in annotation.__metadata__ + + +def _strip(annotation: Any) -> Any: + """Strip an ``Annotated`` wrapper to its underlying type.""" + if get_origin(annotation) is Annotated: + return get_args(annotation)[0] + return annotation + + def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: """Pick a strategy producing a callable suitable for binding `op` in an interpretation. Inspects the operation's signature. """ sig = op.__signature__ params = list(sig.parameters.values()) - ret = sig.return_annotation - param_types = tuple(p.annotation for p in params) + ret_annot = sig.return_annotation + ret = _strip(ret_annot) + ret_is_stream = _is_stream(ret_annot) + param_types = tuple(_strip(p.annotation) for p in params) if not params: return _value_strategy_for(ret).map(deffn) @@ -109,7 +139,9 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) if ret is jax.Array and param_types == (jax.Array,): - return st.sampled_from(_UNARY_JAX_FNS) + if ret_is_stream: + return st.sampled_from(_UNARY_JAX_FNS) + return st.sampled_from(_UNARY_JAX_SCALAR_FNS) if ret is jax.Array and param_types == (jax.Array, jax.Array): return st.sampled_from(_BINARY_JAX_FNS) if ( @@ -274,7 +306,15 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation each of type ``scalar_typ``. """ scalar = self.scalar_typ - out = self.stream_typ if ret == "stream" else scalar + if ret == "stream": + out = self.stream_typ + # When scalar_typ == stream_typ (e.g. jax backend), tag the return + # with an Annotated marker so ``_strategy_for_op`` can pick the + # right (shape-changing) function family. + if scalar is out: + out = Annotated[out, "stream"] + else: + out = scalar params = ", ".join(f"_a{i}" for i in range(n_args)) ns: dict[str, Any] = {"NotHandled": NotHandled} exec(f"def _fn({params}):\n raise NotHandled\n", ns) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index d225c3d1..930cbbe6 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -27,11 +27,12 @@ ReduceFusion, ReduceNoStreams, ReduceSplit, + ReduceWeightedStream, Sum, WeightedStream, distributes_over, ) -from effectful.ops.semantics import fvsof, handler +from effectful.ops.semantics import coproduct, fvsof, handler from effectful.ops.types import Operation from tests._monoid_helpers import ( INT_BACKEND, @@ -674,9 +675,6 @@ def test_reduce_lifted_2(outer, inner, backend): # --------------------------------------------------------------------------- -@pytest.mark.xfail( - strict=True, reason="ReduceWeightedStream rewrite rule not yet implemented" -) def test_reduce_single_weighted_stream(backend): """Single weighted stream desugars: Sum.reduce(body, {a: WS(A, w, Product)}) @@ -687,18 +685,19 @@ def test_reduce_single_weighted_stream(backend): body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = WeightedStream(stream=A(), weight=lambda v: w(v), monoid=Product) + ws = WeightedStream(stream=A(), weight=w, monoid=Product) lhs = Sum.reduce(body(a()), {a: ws}) rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) check_rewrite( - lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[A, body, w] + lhs=lhs, + rhs=rhs, + rule=ReduceWeightedStream(), + backend=backend, + free_vars=[A, body, w], ) -@pytest.mark.xfail( - strict=True, reason="ReduceWeightedStream rewrite rule not yet implemented" -) def test_reduce_weighted_factorization(backend): """Two independent weighted streams under Sum with Product weights factor: Sum.reduce(f(a)*g(b), {a: WS(A, w_a, Product), b: WS(B, w_b, Product)}) @@ -719,14 +718,14 @@ def test_reduce_weighted_factorization(backend): lhs = Sum.reduce(Product.plus(f(a()), g(b())), {a: ws_a, b: ws_b}) rhs = Product.plus( - Sum.reduce(Product.plus(w_a(a()), f(a())), {a: A()}), - Sum.reduce(Product.plus(w_b(b()), g(b())), {b: B()}), + Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), + Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), ) check_rewrite( lhs=lhs, rhs=rhs, - rule=NormalizeIntp, + rule=coproduct(ReduceWeightedStream(), ReduceFactorization()), backend=backend, free_vars=[A, B, f, g, w_a, w_b], ) From 011be1eac7f3ae9cf7e04b5fe3bd4b497182b062 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 13:57:09 -0400 Subject: [PATCH 04/29] add test to demo expectation --- effectful/ops/monoid.py | 14 ++++------- tests/test_handlers_jax_monoid.py | 34 ++++++++++++++++---------- tests/test_ops_monoid.py | 40 +++++++++++++++++++++++++++---- 3 files changed, 62 insertions(+), 26 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 34217d18..e239cc65 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,13 +4,12 @@ import operator import typing from collections import Counter, UserDict, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping +from collections.abc import Callable, Generator, Iterable, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter -from typing import Annotated, Any, Protocol +from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.internals.unification import canonicalize from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, @@ -26,10 +25,7 @@ type Stream[T] = Iterable[T] - -# Note: The streams value type should be something like Iterable[T], but some of -# our target stream types (e.g. jax.Array) are not subtypes of Iterable -type Streams[T] = Mapping[Operation[[], T], Stream[T]] +type Streams = Mapping[Operation[[], Any], Stream[Any]] type Body[T] = ( Iterable[T] @@ -40,9 +36,7 @@ ) -def outer_stream[T]( - streams: Streams[T], -) -> Iterable[tuple[Operation, Stream[T], dict[Operation, Stream[T]]]]: +def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]]: """Returns the streams that can be ordered outermost in the loop nest as well as the remaining streams in the nest. diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index e21d8d42..a7d2a27f 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,11 +1,26 @@ +import functools + import jax import pytest import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax import bind_dims, unbind_dims -from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp +from effectful.handlers.jax.monoid import ( + ArrayReduce, + LogSumExp, + ProductPlusJax, + SumPlusJax, +) from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum, WeightedStream +from effectful.ops.monoid import ( + Max, + Min, + NormalizeIntp, + Product, + ReduceWeightedStream, + Sum, + WeightedStream, +) from effectful.ops.semantics import coproduct from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars @@ -97,9 +112,6 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): ) -@pytest.mark.xfail( - strict=True, reason="ReduceWeightedStream rewrite rule not yet implemented" -) def test_jax_weighted_reduce(backend: Backend): """Sum over a single ``WeightedStream`` with ``Product`` weights lowers to ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. @@ -112,20 +124,18 @@ def test_jax_weighted_reduce(backend: Backend): body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = WeightedStream(stream=X(), weight=lambda v: w(v), monoid=Product) + ws = WeightedStream(stream=X(), weight=w, monoid=Product) lhs = Sum.reduce(body(x()), {x: ws}) rhs = jnp.sum( - bind_dims( - unbind_dims(w(X()), k) * unbind_dims(body(X()), k), - k, - ), - axis=0, + bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 ) check_rewrite( lhs=lhs, rhs=rhs, - rule=coproduct(NormalizeIntp, ArrayReduce()), + rule=functools.reduce( + coproduct, [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()] + ), backend=backend, free_vars=[x, k, X, body, w], ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 930cbbe6..8296cdab 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,3 +1,4 @@ +import math import typing import pytest @@ -32,8 +33,8 @@ WeightedStream, distributes_over, ) -from effectful.ops.semantics import coproduct, fvsof, handler -from effectful.ops.types import Operation +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.types import NotHandled, Operation, Term from tests._monoid_helpers import ( INT_BACKEND, JAX_BACKEND, @@ -713,8 +714,8 @@ def test_reduce_weighted_factorization(backend): w_a = backend.fresh_op("w_a", n_args=1, ret="scalar") w_b = backend.fresh_op("w_b", n_args=1, ret="scalar") - ws_a = WeightedStream(stream=A(), weight=lambda v: w_a(v), monoid=Product) - ws_b = WeightedStream(stream=B(), weight=lambda v: w_b(v), monoid=Product) + ws_a = WeightedStream(stream=A(), weight=w_a, monoid=Product) + ws_b = WeightedStream(stream=B(), weight=w_b, monoid=Product) lhs = Sum.reduce(Product.plus(f(a()), g(b())), {a: ws_a, b: ws_b}) rhs = Product.plus( @@ -729,3 +730,34 @@ def test_reduce_weighted_factorization(backend): backend=backend, free_vars=[A, B, f, g, w_a, w_b], ) + + +def test_weighted_expectation_demo(): + """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. + + X ranges over [1, 2, 3, 4] with weights w(x) = x/10 (a valid distribution + since the weights sum to 1) and f(x) = x*x. Expected value: + 0.1·1 + 0.2·4 + 0.3·9 + 0.4·16 = 10.0 + """ + weights = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4} + + def _w(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return weights[v] + + def _f(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return float(v * v) + + a = define_vars("a", typ=int) + w = Operation.define(_w, name="w") + f = Operation.define(_f, name="f") + + ws = WeightedStream(stream=[1, 2, 3, 4], weight=w, monoid=Product) + + with handler(NormalizeIntp): + result = evaluate(Sum.reduce(f(a()), {a: ws})) + + assert math.isclose(result, 10.0) From d46123e964933b0fb8960024421c983ed1c8431f Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 16:09:36 -0400 Subject: [PATCH 05/29] add numpyro monoid module --- effectful/handlers/jax/monoid.py | 5 + effectful/handlers/numpyro/__init__.py | 1 + .../{numpyro.py => numpyro/_distributions.py} | 0 effectful/handlers/numpyro/monoid.py | 145 ++++++++++++++++++ effectful/ops/monoid.py | 15 ++ 5 files changed, 166 insertions(+) create mode 100644 effectful/handlers/numpyro/__init__.py rename effectful/handlers/{numpyro.py => numpyro/_distributions.py} (100%) create mode 100644 effectful/handlers/numpyro/monoid.py diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 8630f4fd..ddd5877b 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -14,6 +14,7 @@ NormalizeIntp, Product, Sum, + distributes_over, outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof @@ -38,6 +39,10 @@ def cartesian_prod(x, y): LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) +# ``Sum`` in log space is multiplication, which distributes over ``LogSumExp``: +# a + logsumexp(b, c) = logsumexp(a + b, a + c) +distributes_over.register(Sum, LogSumExp) + def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete diff --git a/effectful/handlers/numpyro/__init__.py b/effectful/handlers/numpyro/__init__.py new file mode 100644 index 00000000..607151df --- /dev/null +++ b/effectful/handlers/numpyro/__init__.py @@ -0,0 +1 @@ +from effectful.handlers.numpyro._distributions import * # noqa: F401, F403 diff --git a/effectful/handlers/numpyro.py b/effectful/handlers/numpyro/_distributions.py similarity index 100% rename from effectful/handlers/numpyro.py rename to effectful/handlers/numpyro/_distributions.py diff --git a/effectful/handlers/numpyro/monoid.py b/effectful/handlers/numpyro/monoid.py new file mode 100644 index 00000000..c68864cd --- /dev/null +++ b/effectful/handlers/numpyro/monoid.py @@ -0,0 +1,145 @@ +"""NumPyro distribution support for weighted streams. + +``weighted(dist)`` is the smart constructor for treating a numpyro +distribution as a weighted stream. By default it stays symbolic — i.e. +``weighted(d)`` returns a ``Term`` whose ``args[0]`` is ``d`` — so that +specialized reduction rules (closed-form expectations, quadrature, etc.) +can pattern-match on the distribution. + +Two general-purpose reduction rules are provided here: + +* :class:`NumpyroSampling` — Monte Carlo approximation. Replaces + ``weighted(d)`` with a sample-backed :class:`WeightedStream` of ``n_samples`` + i.i.d. draws and uniform weights, then delegates to the standard + :class:`ReduceWeightedStream` machinery. + +* :class:`NumpyroLogProb` — generic symbolic lowering. Replaces + ``weighted(d)`` with ``WeightedStream(stream(d.support), d.log_prob, Sum)``. + ``Sum`` acts as multiplication in log space and + ``distributes_over(Sum, LogSumExp)`` is registered, so a subsequent + ``LogSumExp.reduce`` is desugared by ``ReduceWeightedStream`` into the + standard log-space expectation integrand: + + LogSumExp.reduce(Sum.plus(d.log_prob(x), body), {x: stream(d.support)}) +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpyro.distributions as dist +import numpyro.distributions.constraints as constraints + +from effectful.handlers.jax.monoid import LogSumExp +from effectful.ops.monoid import ( + Monoid, + Product, + Sum, + WeightedStream, + stream, + weighted, +) +from effectful.ops.semantics import fwd +from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.types import NotHandled, Operation, Term + +# --- smart constructors stay symbolic for distributions / constraints ----- + + +@weighted.register(dist.Distribution) +def _weighted_dist(_d): + raise NotHandled + + +@stream.register(dist.Distribution) +def _stream_dist(_d): + raise NotHandled + + +@stream.register(constraints.Constraint) +def _stream_constraint(_c): + raise NotHandled + + +def _weighted_dist_arg(v) -> dist.Distribution | None: + """If ``v`` is ``Term(weighted, [d])`` with ``d`` a numpyro Distribution, + return ``d``; otherwise ``None``. + """ + if not (isinstance(v, Term) and v.op is weighted): + return None + (d,) = v.args + return d if isinstance(d, dist.Distribution) else None + + +# --- rule: Monte Carlo sampling ------------------------------------------- + + +@dataclass +class NumpyroSampling(ObjectInterpretation): + """Replace ``weighted(d)`` with a sample-backed :class:`WeightedStream`. + + Draws ``n_samples`` i.i.d. samples from ``d`` and attaches a uniform + weight ``1/n_samples`` (linear space) or ``-log(n_samples)`` (log + space, when the outer monoid is :data:`LogSumExp`). The resulting + :class:`WeightedStream` is then handled by the standard + :class:`ReduceWeightedStream` rewrite. + """ + + rng_key: jax.Array + n_samples: int = 1000 + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + new_streams = dict(streams) + progress = False + for k, v in streams.items(): + d = _weighted_dist_arg(v) + if d is None: + continue + samples = d.sample(self.rng_key, sample_shape=(self.n_samples,)) + if monoid is LogSumExp: + w_val = -jnp.log(self.n_samples) + w_monoid: Monoid = Sum + else: + w_val = 1.0 / self.n_samples + w_monoid = Product + new_streams[k] = WeightedStream( + stream=samples, + weight=deffn(w_val, Operation.define(k)), + monoid=w_monoid, + ) + progress = True + if progress: + return monoid.reduce(body, new_streams) + return fwd() + + +# --- rule: symbolic log-prob lowering ------------------------------------- + + +class NumpyroLogProb(ObjectInterpretation): + """Lower ``weighted(d)`` to its symbolic log-prob form. + + Generic fallback: produces a :class:`WeightedStream` whose stream is + the symbolic ``stream(d.support)``, weight is ``d.log_prob``, and + weight monoid is :data:`Sum` (log-space multiplication). With + ``distributes_over(Sum, LogSumExp)`` registered, a surrounding + ``LogSumExp.reduce`` will then desugar via :class:`ReduceWeightedStream` + into the standard expectation integrand. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + new_streams = dict(streams) + progress = False + for k, v in streams.items(): + d = _weighted_dist_arg(v) + if d is None: + continue + new_streams[k] = WeightedStream( + stream=stream(d.support), weight=d.log_prob, monoid=Sum + ) + progress = True + if progress: + return monoid.reduce(body, new_streams) + return fwd() diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index e239cc65..49a941c5 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -133,6 +133,21 @@ def weighted(x) -> WeightedStream: raise NotImplementedError("Unsupported type", type(x)) +@Operation.define +@functools.singledispatch +def stream(x) -> Iterable: + """Smart constructor lifting a value into the :data:`Stream` type. + + Used to wrap opaque ``support``-like values (e.g. numpyro + distributions or constraints) that aren't structurally iterable but + should appear in the stream slot of a :class:`WeightedStream`. Concrete + iterables can be registered to pass through unchanged; symbolic sources + register impls that ``raise NotHandled`` so the call stays a Term and + downstream rules can pattern-match on the wrapped value. + """ + raise NotImplementedError("Unsupported stream source", type(x)) + + Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) From 137f22e622dd1d75b4a84d4df074fc02d2d1471e Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 16:23:32 -0400 Subject: [PATCH 06/29] add quadrature --- effectful/handlers/numpyro/monoid.py | 118 +++++++++++++++++---------- tests/test_handlers_jax_monoid.py | 2 - 2 files changed, 77 insertions(+), 43 deletions(-) diff --git a/effectful/handlers/numpyro/monoid.py b/effectful/handlers/numpyro/monoid.py index c68864cd..455f12c8 100644 --- a/effectful/handlers/numpyro/monoid.py +++ b/effectful/handlers/numpyro/monoid.py @@ -5,46 +5,25 @@ ``weighted(d)`` returns a ``Term`` whose ``args[0]`` is ``d`` — so that specialized reduction rules (closed-form expectations, quadrature, etc.) can pattern-match on the distribution. - -Two general-purpose reduction rules are provided here: - -* :class:`NumpyroSampling` — Monte Carlo approximation. Replaces - ``weighted(d)`` with a sample-backed :class:`WeightedStream` of ``n_samples`` - i.i.d. draws and uniform weights, then delegates to the standard - :class:`ReduceWeightedStream` machinery. - -* :class:`NumpyroLogProb` — generic symbolic lowering. Replaces - ``weighted(d)`` with ``WeightedStream(stream(d.support), d.log_prob, Sum)``. - ``Sum`` acts as multiplication in log space and - ``distributes_over(Sum, LogSumExp)`` is registered, so a subsequent - ``LogSumExp.reduce`` is desugared by ``ReduceWeightedStream`` into the - standard log-space expectation integrand: - - LogSumExp.reduce(Sum.plus(d.log_prob(x), body), {x: stream(d.support)}) """ from dataclasses import dataclass import jax import jax.numpy as jnp +import numpy as np import numpyro.distributions as dist import numpyro.distributions.constraints as constraints +import effectful.handlers.jax.numpy as ejnp +from effectful.handlers.jax import jax_getitem from effectful.handlers.jax.monoid import LogSumExp -from effectful.ops.monoid import ( - Monoid, - Product, - Sum, - WeightedStream, - stream, - weighted, -) +from effectful.handlers.numpyro import NormalTerm +from effectful.ops.monoid import Monoid, Product, Sum, WeightedStream, stream, weighted from effectful.ops.semantics import fwd from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import NotHandled, Operation, Term -# --- smart constructors stay symbolic for distributions / constraints ----- - @weighted.register(dist.Distribution) def _weighted_dist(_d): @@ -71,18 +50,15 @@ def _weighted_dist_arg(v) -> dist.Distribution | None: return d if isinstance(d, dist.Distribution) else None -# --- rule: Monte Carlo sampling ------------------------------------------- - - @dataclass class NumpyroSampling(ObjectInterpretation): """Replace ``weighted(d)`` with a sample-backed :class:`WeightedStream`. - Draws ``n_samples`` i.i.d. samples from ``d`` and attaches a uniform - weight ``1/n_samples`` (linear space) or ``-log(n_samples)`` (log - space, when the outer monoid is :data:`LogSumExp`). The resulting - :class:`WeightedStream` is then handled by the standard - :class:`ReduceWeightedStream` rewrite. + Draws ``n_samples`` i.i.d. samples from ``d`` and attaches a uniform weight + ``1/n_samples`` (linear space) or ``-log(n_samples)`` (log space, when the + outer monoid is :data:`LogSumExp`). The resulting :class:`WeightedStream` is + then handled by the standard :class:`ReduceWeightedStream` rewrite. + """ rng_key: jax.Array @@ -114,18 +90,78 @@ def reduce(self, monoid, body, streams): return fwd() -# --- rule: symbolic log-prob lowering ------------------------------------- +@dataclass +class NumpyroGaussHermite(ObjectInterpretation): + """Gauss–Hermite quadrature for ``weighted(Normal(μ, σ))``. + + For ``X ∼ Normal(μ, σ²)``, the change of variable ``u = (x-μ)/(σ√2)`` gives + :: + + E[f(X)] = (1/√π) ∫ f(μ + σ√2 · u) e^{-u²} du + ≈ Σᵢ (wᵢ/√π) · f(μ + σ√2 · uᵢ) + + where ``{uᵢ, wᵢ}`` are the physicists' Hermite nodes/weights from + :func:`numpy.polynomial.hermite.hermgauss`. The rule replaces + ``weighted(d)`` with a :class:`WeightedStream` of length ``n_nodes`` and + lets the standard :class:`ReduceWeightedStream` machinery finish. + + Weight monoid is :data:`Product` for linear-space bodies (e.g. + ``Sum.reduce``) and :data:`Sum` for log-space bodies (e.g. + ``LogSumExp.reduce``); both pairs distribute correctly. + + """ + + n_nodes: int = 20 + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + new_streams = dict(streams) + progress = False + for k, v in streams.items(): + d = _weighted_dist_arg(v) + if not isinstance(d, dist.Normal | NormalTerm): + continue + new_streams[k] = self._gauss_hermite(d, monoid) + progress = True + if progress: + return monoid.reduce(body, new_streams) + return fwd() + + def _gauss_hermite(self, d, monoid: Monoid) -> WeightedStream: + u, w = np.polynomial.hermite.hermgauss(self.n_nodes) + u_jax = jnp.asarray(u, dtype=jnp.float32) + w_jax = jnp.asarray(w, dtype=jnp.float32) + + nodes = d.loc + jnp.sqrt(2.0) * d.scale * u_jax + if monoid is LogSumExp: + weights = jnp.log(w_jax) - 0.5 * jnp.log(jnp.pi) + w_monoid: Monoid = Sum + else: + weights = w_jax / jnp.sqrt(jnp.pi) + w_monoid = Product + + # Position-match the node value back to its weight via argmin. The + # weight function is invoked symbolically by ``ReduceWeightedStream`` + # (with a Term arg), so we use the effectful-wrapped jnp so the + # lookup becomes a Term that evaluates to the right scalar once the + # default reduce binds the stream variable to a concrete node. + def weight_fn(x, _nodes=nodes, _w=weights): + idx = ejnp.argmin(ejnp.abs(_nodes - x)) + return jax_getitem(_w, (idx,)) + + return WeightedStream(stream=nodes, weight=weight_fn, monoid=w_monoid) class NumpyroLogProb(ObjectInterpretation): """Lower ``weighted(d)`` to its symbolic log-prob form. - Generic fallback: produces a :class:`WeightedStream` whose stream is - the symbolic ``stream(d.support)``, weight is ``d.log_prob``, and - weight monoid is :data:`Sum` (log-space multiplication). With - ``distributes_over(Sum, LogSumExp)`` registered, a surrounding - ``LogSumExp.reduce`` will then desugar via :class:`ReduceWeightedStream` - into the standard expectation integrand. + Generic fallback: produces a :class:`WeightedStream` whose stream is the + symbolic ``stream(d.support)``, weight is ``d.log_prob``, and weight monoid + is :data:`Sum` (log-space multiplication). With ``distributes_over(Sum, + LogSumExp)`` registered, a surrounding ``LogSumExp.reduce`` will then + desugar via :class:`ReduceWeightedStream` into the standard expectation + integrand. + """ @implements(Monoid.reduce) diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index a7d2a27f..76de00dc 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -9,13 +9,11 @@ ArrayReduce, LogSumExp, ProductPlusJax, - SumPlusJax, ) from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( Max, Min, - NormalizeIntp, Product, ReduceWeightedStream, Sum, From 1b3e6ef3c0cf3bfdbfefd979348905f0992f3586 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 16:40:36 -0400 Subject: [PATCH 07/29] add tests --- effectful/handlers/numpyro/monoid.py | 66 ++++- tests/test_handlers_numpyro_monoid.py | 363 ++++++++++++++++++++++++++ 2 files changed, 428 insertions(+), 1 deletion(-) create mode 100644 tests/test_handlers_numpyro_monoid.py diff --git a/effectful/handlers/numpyro/monoid.py b/effectful/handlers/numpyro/monoid.py index 455f12c8..f28402e1 100644 --- a/effectful/handlers/numpyro/monoid.py +++ b/effectful/handlers/numpyro/monoid.py @@ -18,7 +18,11 @@ import effectful.handlers.jax.numpy as ejnp from effectful.handlers.jax import jax_getitem from effectful.handlers.jax.monoid import LogSumExp -from effectful.handlers.numpyro import NormalTerm +from effectful.handlers.numpyro import ( + CategoricalLogitsTerm, + CategoricalProbsTerm, + NormalTerm, +) from effectful.ops.monoid import Monoid, Product, Sum, WeightedStream, stream, weighted from effectful.ops.semantics import fwd from effectful.ops.syntax import ObjectInterpretation, deffn, implements @@ -152,6 +156,66 @@ def weight_fn(x, _nodes=nodes, _w=weights): return WeightedStream(stream=nodes, weight=weight_fn, monoid=w_monoid) +@dataclass +class NumpyroCategorical(ObjectInterpretation): + """Exact enumeration ('quadrature') for ``weighted(Categorical(...))``. + + A categorical with ``K`` outcomes has finite integer support + ``{0, …, K-1}``; integration reduces to an exact finite sum. The rule + replaces ``weighted(d)`` with a :class:`WeightedStream` whose stream is + ``jnp.arange(K)`` and whose weight indexes into the per-outcome + probability vector. + + Weight monoid is :data:`Product` for linear-space bodies and :data:`Sum` + for log-space bodies (under :data:`LogSumExp`), matching the + distributes-over pairs used by :class:`ReduceWeightedStream`. + + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + new_streams = dict(streams) + progress = False + for k, v in streams.items(): + d = _weighted_dist_arg(v) + ws = self._categorical(d, monoid) + if ws is None: + continue + new_streams[k] = ws + progress = True + if progress: + return monoid.reduce(body, new_streams) + return fwd() + + def _categorical(self, d, monoid: Monoid) -> WeightedStream | None: + # Pick the natural representation for the target weight monoid so we + # don't go probs→log or logits→probs→log unnecessarily. + if monoid is LogSumExp: + w_monoid: Monoid = Sum + if isinstance(d, dist.CategoricalLogits | CategoricalLogitsTerm): + weights = jax.nn.log_softmax(jnp.asarray(d.logits), axis=-1) + elif isinstance(d, dist.CategoricalProbs | CategoricalProbsTerm): + weights = jnp.log(jnp.asarray(d.probs)) + else: + return None + else: + w_monoid = Product + if isinstance(d, dist.CategoricalProbs | CategoricalProbsTerm): + weights = jnp.asarray(d.probs) + elif isinstance(d, dist.CategoricalLogits | CategoricalLogitsTerm): + weights = jax.nn.softmax(jnp.asarray(d.logits), axis=-1) + else: + return None + + indices = jnp.arange(weights.shape[-1]) + + # The support value *is* the index, so the lookup is direct. + def weight_fn(x, _w=weights): + return jax_getitem(_w, (x,)) + + return WeightedStream(stream=indices, weight=weight_fn, monoid=w_monoid) + + class NumpyroLogProb(ObjectInterpretation): """Lower ``weighted(d)`` to its symbolic log-prob form. diff --git a/tests/test_handlers_numpyro_monoid.py b/tests/test_handlers_numpyro_monoid.py new file mode 100644 index 00000000..4b34a388 --- /dev/null +++ b/tests/test_handlers_numpyro_monoid.py @@ -0,0 +1,363 @@ +"""Unit tests for the rewrite rules in ``effectful.handlers.numpyro.monoid``. + +Tests follow the conventions in ``test_ops_monoid.py``: each rule is verified +via a symbolic ``lhs`` and the expected post-rewrite ``rhs``. We assert both +syntactic equivalence after applying the rule and semantic equivalence under +random interpretations of the free body op. +""" + +import math + +import jax +import jax.numpy as jnp +import numpy as np +import numpyro.distributions as dist +import pytest +from hypothesis import HealthCheck, given, settings + +import effectful.handlers.jax.monoid # noqa: F401 # registers jax monoid handlers +import effectful.handlers.jax.numpy as ejnp +from effectful.handlers.jax import jax_getitem +from effectful.handlers.jax.monoid import LogSumExp +from effectful.handlers.numpyro.monoid import ( + NumpyroCategorical, + NumpyroGaussHermite, + NumpyroLogProb, + NumpyroSampling, +) +from effectful.ops.monoid import ( + NormalizeIntp, + Product, + Sum, + WeightedStream, + stream, + weighted, +) +from effectful.ops.semantics import coproduct, evaluate, handler +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation +from tests._monoid_helpers import ( + JAX_BACKEND, + Backend, + define_vars, + random_interpretation, + syntactic_eq_alpha, +) + + +@pytest.fixture +def backend() -> Backend: + return JAX_BACKEND + + +def check_numpyro_rewrite( + lhs, + rhs, + *, + rule, + backend: Backend, + syntactic_rule=None, + free_vars=(), + max_examples: int = 25, +) -> None: + """``check_rewrite`` variant for numpyro rules. + + ``syntactic_rule`` (default ``rule``) is installed for the syntactic + step; ``rule`` itself is installed alongside :data:`NormalizeIntp` for + the property-based semantic step so both sides can reduce to a value. + """ + syn = syntactic_rule if syntactic_rule is not None else rule + with handler(syn): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings( + max_examples=max_examples, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + ) + def _check_semantics(intp): + with handler(coproduct(NormalizeIntp, rule)), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert backend.eq(lhs_val, rhs_val) + + _check_semantics() + + +# --------------------------------------------------------------------------- +# NumpyroLogProb — pure structural rewrite +# --------------------------------------------------------------------------- + + +def test_logprob_lowering(backend): + """``NumpyroLogProb`` replaces ``weighted(d)`` with + ``WeightedStream(stream(d.support), d.log_prob, Sum)``. + """ + a = define_vars("a", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + d = dist.Normal(0.0, 1.0) + + lhs = Sum.reduce(body(a()), {a: weighted(d)}) + rhs = Sum.reduce( + body(a()), + { + a: WeightedStream( + stream=stream(d.support), weight=d.log_prob, monoid=Sum + ) + }, + ) + + check_numpyro_rewrite( + lhs=lhs, rhs=rhs, rule=NumpyroLogProb(), backend=backend, free_vars=[body] + ) + + +# --------------------------------------------------------------------------- +# NumpyroGaussHermite — replace weighted(Normal) with explicit n-node sum +# --------------------------------------------------------------------------- + + +def _gauss_hermite_nodes_weights(loc, scale, n, log_space: bool): + u, w_raw = np.polynomial.hermite.hermgauss(n) + u_jax = jnp.asarray(u, dtype=jnp.float32) + w_jax = jnp.asarray(w_raw, dtype=jnp.float32) + nodes = loc + jnp.sqrt(2.0) * scale * u_jax + if log_space: + weights = jnp.log(w_jax) - 0.5 * jnp.log(jnp.pi) + else: + weights = w_jax / jnp.sqrt(jnp.pi) + return nodes, weights + + +def test_gauss_hermite_linear(backend): + """Under ``Sum.reduce``, ``NumpyroGaussHermite`` lowers + ``weighted(Normal(μ, σ))`` to a Product-weighted stream of ``n_nodes`` + nodes, which then reduces (via ``ReduceWeightedStream`` and the default + rule) to the explicit weighted sum ``Σᵢ wᵢ · body(xᵢ)``. + """ + n = 8 + loc, scale = 0.5, 1.3 + d = dist.Normal(loc, scale) + + a = define_vars("a", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + nodes, weights = _gauss_hermite_nodes_weights(loc, scale, n, log_space=False) + + lhs = Sum.reduce(body(a()), {a: weighted(d)}) + rhs = Sum.plus( + *[ + Product.plus(jax_getitem(weights, (i,)), body(jax_getitem(nodes, (i,)))) + for i in range(n) + ] + ) + + # Full pipeline (rule + NormalizeIntp) for syntactic comparison so the + # opaque weight closure inside the WeightedStream gets reduced away. + full = coproduct(NormalizeIntp, NumpyroGaussHermite(n_nodes=n)) + check_numpyro_rewrite( + lhs=lhs, + rhs=rhs, + rule=NumpyroGaussHermite(n_nodes=n), + syntactic_rule=full, + backend=backend, + free_vars=[body], + ) + + +def test_gauss_hermite_logsumexp(backend): + """Under ``LogSumExp.reduce``, weights are log-space and combined via + ``Sum`` (log-multiplication). The lowered form is + ``LogSumExp.plus(Sum.plus(log_wᵢ, log_body(xᵢ)) for i)``. + """ + n = 8 + loc, scale = 0.0, 1.0 + d = dist.Normal(loc, scale) + + a = define_vars("a", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + nodes, log_weights = _gauss_hermite_nodes_weights(loc, scale, n, log_space=True) + + lhs = LogSumExp.reduce(body(a()), {a: weighted(d)}) + rhs = LogSumExp.plus( + *[ + Sum.plus(jax_getitem(log_weights, (i,)), body(jax_getitem(nodes, (i,)))) + for i in range(n) + ] + ) + + full = coproduct(NormalizeIntp, NumpyroGaussHermite(n_nodes=n)) + check_numpyro_rewrite( + lhs=lhs, + rhs=rhs, + rule=NumpyroGaussHermite(n_nodes=n), + syntactic_rule=full, + backend=backend, + free_vars=[body], + ) + + +# --------------------------------------------------------------------------- +# NumpyroCategorical — replace weighted(Categorical) with explicit K-term sum +# --------------------------------------------------------------------------- + + +def test_categorical_probs_linear(backend): + """Under ``Sum.reduce``, ``NumpyroCategorical`` lowers + ``weighted(CategoricalProbs(probs))`` to ``Σᵢ probs[i] · body(i)``. + """ + probs = jnp.array([0.1, 0.2, 0.3, 0.4]) + d = dist.CategoricalProbs(probs=probs) + k = probs.shape[-1] + + i_op = define_vars("i_op", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + indices = jnp.arange(k) + + lhs = Sum.reduce(body(i_op()), {i_op: weighted(d)}) + rhs = Sum.plus( + *[ + Product.plus(jax_getitem(probs, (i,)), body(jax_getitem(indices, (i,)))) + for i in range(k) + ] + ) + + full = coproduct(NormalizeIntp, NumpyroCategorical()) + check_numpyro_rewrite( + lhs=lhs, + rhs=rhs, + rule=NumpyroCategorical(), + syntactic_rule=full, + backend=backend, + free_vars=[body], + ) + + +def test_categorical_logits_matches_probs(backend): + """``CategoricalLogits(log probs)`` and ``CategoricalProbs(probs)`` must + lower to the same value under the same body. + """ + probs = jnp.array([0.1, 0.2, 0.3, 0.4]) + d_p = dist.CategoricalProbs(probs=probs) + d_l = dist.CategoricalLogits(logits=jnp.log(probs)) + + i_op = define_vars("i_op", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + + expr_p = Sum.reduce(body(i_op()), {i_op: weighted(d_p)}) + expr_l = Sum.reduce(body(i_op()), {i_op: weighted(d_l)}) + + @given(intp=random_interpretation([body])) + @settings( + max_examples=25, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + ) + def _check(intp): + with handler(coproduct(NormalizeIntp, NumpyroCategorical())), handler(intp): + r_p = evaluate(expr_p) + r_l = evaluate(expr_l) + assert backend.eq(r_p, r_l) + + _check() + + +def test_categorical_logsumexp(backend): + """Under ``LogSumExp.reduce`` with ``CategoricalProbs``, weights are + ``log(probs)`` combined via ``Sum`` (log-multiplication). + """ + probs = jnp.array([0.1, 0.2, 0.3, 0.4]) + d = dist.CategoricalProbs(probs=probs) + k = probs.shape[-1] + log_probs = jnp.log(probs) + + i_op = define_vars("i_op", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + indices = jnp.arange(k) + + lhs = LogSumExp.reduce(body(i_op()), {i_op: weighted(d)}) + rhs = LogSumExp.plus( + *[ + Sum.plus(jax_getitem(log_probs, (i,)), body(jax_getitem(indices, (i,)))) + for i in range(k) + ] + ) + + full = coproduct(NormalizeIntp, NumpyroCategorical()) + check_numpyro_rewrite( + lhs=lhs, + rhs=rhs, + rule=NumpyroCategorical(), + syntactic_rule=full, + backend=backend, + free_vars=[body], + ) + + +# --------------------------------------------------------------------------- +# NumpyroSampling — replace weighted(d) with a sample-backed WeightedStream +# --------------------------------------------------------------------------- + + +def test_sampling_linear(backend): + """``NumpyroSampling`` lowers ``weighted(d)`` to a Product-weighted + sample stream; the rewrite is deterministic for a fixed ``rng_key``. + """ + n = 64 + key = jax.random.key(0) + d = dist.Normal(0.0, 1.0) + samples = d.sample(key, sample_shape=(n,)) + + a = define_vars("a", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + + lhs = Sum.reduce(body(a()), {a: weighted(d)}) + rhs = Sum.plus( + *[ + Product.plus(1.0 / n, body(jax_getitem(samples, (i,)))) + for i in range(n) + ] + ) + + full = coproduct(NormalizeIntp, NumpyroSampling(rng_key=key, n_samples=n)) + check_numpyro_rewrite( + lhs=lhs, + rhs=rhs, + rule=NumpyroSampling(rng_key=key, n_samples=n), + syntactic_rule=full, + backend=backend, + free_vars=[body], + ) + + +def test_sampling_logsumexp(backend): + """Under ``LogSumExp.reduce``, ``NumpyroSampling`` uses log-uniform weights + ``-log(N)`` combined via ``Sum`` (log-multiplication). + """ + n = 64 + key = jax.random.key(0) + d = dist.Normal(0.0, 1.0) + samples = d.sample(key, sample_shape=(n,)) + log_w = -jnp.log(n) + + a = define_vars("a", typ=backend.scalar_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + + lhs = LogSumExp.reduce(body(a()), {a: weighted(d)}) + rhs = LogSumExp.plus( + *[ + Sum.plus(log_w, body(jax_getitem(samples, (i,)))) + for i in range(n) + ] + ) + + full = coproduct(NormalizeIntp, NumpyroSampling(rng_key=key, n_samples=n)) + check_numpyro_rewrite( + lhs=lhs, + rhs=rhs, + rule=NumpyroSampling(rng_key=key, n_samples=n), + syntactic_rule=full, + backend=backend, + free_vars=[body], + ) From 736b9d5c14da6cd9d59896872d8f1118fa2f3b5a Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 16:50:36 -0400 Subject: [PATCH 08/29] wip --- tests/_monoid_helpers.py | 3 +- tests/test_handlers_numpyro_monoid.py | 250 +++----------------------- 2 files changed, 29 insertions(+), 224 deletions(-) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 76219b16..2ebd97a4 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -350,6 +350,7 @@ def check_rewrite( free_vars=[], max_examples: int = 25, deadline=None, + normalize=NormalizeIntp, ) -> None: with handler(rule): norm = evaluate(lhs) @@ -358,7 +359,7 @@ def check_rewrite( @given(intp=random_interpretation(free_vars)) @settings(max_examples=max_examples, deadline=deadline) def _check_semantics(intp): - with handler(NormalizeIntp), handler(intp): + with handler(normalize), handler(intp): lhs_val = evaluate(lhs) rhs_val = evaluate(rhs) assert backend.eq(lhs_val, rhs_val) diff --git a/tests/test_handlers_numpyro_monoid.py b/tests/test_handlers_numpyro_monoid.py index 4b34a388..097cbc5d 100644 --- a/tests/test_handlers_numpyro_monoid.py +++ b/tests/test_handlers_numpyro_monoid.py @@ -4,38 +4,26 @@ via a symbolic ``lhs`` and the expected post-rewrite ``rhs``. We assert both syntactic equivalence after applying the rule and semantic equivalence under random interpretations of the free body op. -""" -import math +For the categorical rule the lowered form is naturally a length-K explicit +sum, but ``ArrayReduce`` (inside ``NormalizeIntp``) further collapses that +into a single named-dim ``jnp.sum``; the RHS therefore matches that final +form, in the style of ``test_jax_weighted_reduce``. +""" import jax import jax.numpy as jnp -import numpy as np import numpyro.distributions as dist import pytest from hypothesis import HealthCheck, given, settings import effectful.handlers.jax.monoid # noqa: F401 # registers jax monoid handlers -import effectful.handlers.jax.numpy as ejnp -from effectful.handlers.jax import jax_getitem +from effectful.handlers.jax import bind_dims, unbind_dims from effectful.handlers.jax.monoid import LogSumExp -from effectful.handlers.numpyro.monoid import ( - NumpyroCategorical, - NumpyroGaussHermite, - NumpyroLogProb, - NumpyroSampling, -) -from effectful.ops.monoid import ( - NormalizeIntp, - Product, - Sum, - WeightedStream, - stream, - weighted, -) +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.handlers.numpyro.monoid import NumpyroCategorical +from effectful.ops.monoid import NormalizeIntp, Sum, weighted from effectful.ops.semantics import coproduct, evaluate, handler -from effectful.ops.syntax import deffn -from effectful.ops.types import Operation from tests._monoid_helpers import ( JAX_BACKEND, Backend, @@ -87,140 +75,25 @@ def _check_semantics(intp): # --------------------------------------------------------------------------- -# NumpyroLogProb — pure structural rewrite -# --------------------------------------------------------------------------- - - -def test_logprob_lowering(backend): - """``NumpyroLogProb`` replaces ``weighted(d)`` with - ``WeightedStream(stream(d.support), d.log_prob, Sum)``. - """ - a = define_vars("a", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - d = dist.Normal(0.0, 1.0) - - lhs = Sum.reduce(body(a()), {a: weighted(d)}) - rhs = Sum.reduce( - body(a()), - { - a: WeightedStream( - stream=stream(d.support), weight=d.log_prob, monoid=Sum - ) - }, - ) - - check_numpyro_rewrite( - lhs=lhs, rhs=rhs, rule=NumpyroLogProb(), backend=backend, free_vars=[body] - ) - - -# --------------------------------------------------------------------------- -# NumpyroGaussHermite — replace weighted(Normal) with explicit n-node sum -# --------------------------------------------------------------------------- - - -def _gauss_hermite_nodes_weights(loc, scale, n, log_space: bool): - u, w_raw = np.polynomial.hermite.hermgauss(n) - u_jax = jnp.asarray(u, dtype=jnp.float32) - w_jax = jnp.asarray(w_raw, dtype=jnp.float32) - nodes = loc + jnp.sqrt(2.0) * scale * u_jax - if log_space: - weights = jnp.log(w_jax) - 0.5 * jnp.log(jnp.pi) - else: - weights = w_jax / jnp.sqrt(jnp.pi) - return nodes, weights - - -def test_gauss_hermite_linear(backend): - """Under ``Sum.reduce``, ``NumpyroGaussHermite`` lowers - ``weighted(Normal(μ, σ))`` to a Product-weighted stream of ``n_nodes`` - nodes, which then reduces (via ``ReduceWeightedStream`` and the default - rule) to the explicit weighted sum ``Σᵢ wᵢ · body(xᵢ)``. - """ - n = 8 - loc, scale = 0.5, 1.3 - d = dist.Normal(loc, scale) - - a = define_vars("a", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - nodes, weights = _gauss_hermite_nodes_weights(loc, scale, n, log_space=False) - - lhs = Sum.reduce(body(a()), {a: weighted(d)}) - rhs = Sum.plus( - *[ - Product.plus(jax_getitem(weights, (i,)), body(jax_getitem(nodes, (i,)))) - for i in range(n) - ] - ) - - # Full pipeline (rule + NormalizeIntp) for syntactic comparison so the - # opaque weight closure inside the WeightedStream gets reduced away. - full = coproduct(NormalizeIntp, NumpyroGaussHermite(n_nodes=n)) - check_numpyro_rewrite( - lhs=lhs, - rhs=rhs, - rule=NumpyroGaussHermite(n_nodes=n), - syntactic_rule=full, - backend=backend, - free_vars=[body], - ) - - -def test_gauss_hermite_logsumexp(backend): - """Under ``LogSumExp.reduce``, weights are log-space and combined via - ``Sum`` (log-multiplication). The lowered form is - ``LogSumExp.plus(Sum.plus(log_wᵢ, log_body(xᵢ)) for i)``. - """ - n = 8 - loc, scale = 0.0, 1.0 - d = dist.Normal(loc, scale) - - a = define_vars("a", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - nodes, log_weights = _gauss_hermite_nodes_weights(loc, scale, n, log_space=True) - - lhs = LogSumExp.reduce(body(a()), {a: weighted(d)}) - rhs = LogSumExp.plus( - *[ - Sum.plus(jax_getitem(log_weights, (i,)), body(jax_getitem(nodes, (i,)))) - for i in range(n) - ] - ) - - full = coproduct(NormalizeIntp, NumpyroGaussHermite(n_nodes=n)) - check_numpyro_rewrite( - lhs=lhs, - rhs=rhs, - rule=NumpyroGaussHermite(n_nodes=n), - syntactic_rule=full, - backend=backend, - free_vars=[body], - ) - - -# --------------------------------------------------------------------------- -# NumpyroCategorical — replace weighted(Categorical) with explicit K-term sum +# NumpyroCategorical # --------------------------------------------------------------------------- def test_categorical_probs_linear(backend): - """Under ``Sum.reduce``, ``NumpyroCategorical`` lowers - ``weighted(CategoricalProbs(probs))`` to ``Σᵢ probs[i] · body(i)``. + """Under ``Sum.reduce`` with ``CategoricalProbs``, the per-index weight + is ``probs[i]`` (linear) and the lowered form is the named-dim + ``jnp.sum(probs[k] * body(indices[k]))``. """ probs = jnp.array([0.1, 0.2, 0.3, 0.4]) d = dist.CategoricalProbs(probs=probs) - k = probs.shape[-1] + indices = jnp.arange(probs.shape[-1]) - i_op = define_vars("i_op", typ=backend.scalar_typ) + i_op, k = define_vars("i_op", "k", typ=backend.scalar_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") - indices = jnp.arange(k) lhs = Sum.reduce(body(i_op()), {i_op: weighted(d)}) - rhs = Sum.plus( - *[ - Product.plus(jax_getitem(probs, (i,)), body(jax_getitem(indices, (i,)))) - for i in range(k) - ] + rhs = jnp.sum( + bind_dims(unbind_dims(probs, k) * body(unbind_dims(indices, k)), k), axis=0 ) full = coproduct(NormalizeIntp, NumpyroCategorical()) @@ -230,7 +103,7 @@ def test_categorical_probs_linear(backend): rule=NumpyroCategorical(), syntactic_rule=full, backend=backend, - free_vars=[body], + free_vars=[k, body], ) @@ -265,23 +138,22 @@ def _check(intp): def test_categorical_logsumexp(backend): """Under ``LogSumExp.reduce`` with ``CategoricalProbs``, weights are - ``log(probs)`` combined via ``Sum`` (log-multiplication). + ``log(probs)`` combined via ``Sum`` (log-multiplication); the lowered + form is ``logsumexp(log_probs[k] + body(indices[k]))`` along the + named dim. """ probs = jnp.array([0.1, 0.2, 0.3, 0.4]) d = dist.CategoricalProbs(probs=probs) - k = probs.shape[-1] log_probs = jnp.log(probs) + indices = jnp.arange(probs.shape[-1]) - i_op = define_vars("i_op", typ=backend.scalar_typ) + i_op, k = define_vars("i_op", "k", typ=backend.scalar_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") - indices = jnp.arange(k) lhs = LogSumExp.reduce(body(i_op()), {i_op: weighted(d)}) - rhs = LogSumExp.plus( - *[ - Sum.plus(jax_getitem(log_probs, (i,)), body(jax_getitem(indices, (i,)))) - for i in range(k) - ] + rhs = logsumexp( + bind_dims(unbind_dims(log_probs, k) + body(unbind_dims(indices, k)), k), + axis=0, ) full = coproduct(NormalizeIntp, NumpyroCategorical()) @@ -291,73 +163,5 @@ def test_categorical_logsumexp(backend): rule=NumpyroCategorical(), syntactic_rule=full, backend=backend, - free_vars=[body], - ) - - -# --------------------------------------------------------------------------- -# NumpyroSampling — replace weighted(d) with a sample-backed WeightedStream -# --------------------------------------------------------------------------- - - -def test_sampling_linear(backend): - """``NumpyroSampling`` lowers ``weighted(d)`` to a Product-weighted - sample stream; the rewrite is deterministic for a fixed ``rng_key``. - """ - n = 64 - key = jax.random.key(0) - d = dist.Normal(0.0, 1.0) - samples = d.sample(key, sample_shape=(n,)) - - a = define_vars("a", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - - lhs = Sum.reduce(body(a()), {a: weighted(d)}) - rhs = Sum.plus( - *[ - Product.plus(1.0 / n, body(jax_getitem(samples, (i,)))) - for i in range(n) - ] - ) - - full = coproduct(NormalizeIntp, NumpyroSampling(rng_key=key, n_samples=n)) - check_numpyro_rewrite( - lhs=lhs, - rhs=rhs, - rule=NumpyroSampling(rng_key=key, n_samples=n), - syntactic_rule=full, - backend=backend, - free_vars=[body], - ) - - -def test_sampling_logsumexp(backend): - """Under ``LogSumExp.reduce``, ``NumpyroSampling`` uses log-uniform weights - ``-log(N)`` combined via ``Sum`` (log-multiplication). - """ - n = 64 - key = jax.random.key(0) - d = dist.Normal(0.0, 1.0) - samples = d.sample(key, sample_shape=(n,)) - log_w = -jnp.log(n) - - a = define_vars("a", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - - lhs = LogSumExp.reduce(body(a()), {a: weighted(d)}) - rhs = LogSumExp.plus( - *[ - Sum.plus(log_w, body(jax_getitem(samples, (i,)))) - for i in range(n) - ] - ) - - full = coproduct(NormalizeIntp, NumpyroSampling(rng_key=key, n_samples=n)) - check_numpyro_rewrite( - lhs=lhs, - rhs=rhs, - rule=NumpyroSampling(rng_key=key, n_samples=n), - syntactic_rule=full, - backend=backend, - free_vars=[body], + free_vars=[k, body], ) From 62cced670bb08db568c8ed5abfc7dbd847e3f4cf Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 17:51:31 -0400 Subject: [PATCH 09/29] refactor tests --- effectful/handlers/numpyro/monoid.py | 13 ++- tests/test_handlers_numpyro_monoid.py | 123 ++++++-------------------- 2 files changed, 38 insertions(+), 98 deletions(-) diff --git a/effectful/handlers/numpyro/monoid.py b/effectful/handlers/numpyro/monoid.py index f28402e1..256c3d61 100644 --- a/effectful/handlers/numpyro/monoid.py +++ b/effectful/handlers/numpyro/monoid.py @@ -23,7 +23,15 @@ CategoricalProbsTerm, NormalTerm, ) -from effectful.ops.monoid import Monoid, Product, Sum, WeightedStream, stream, weighted +from effectful.ops.monoid import ( + Monoid, + NormalizeIntp, + Product, + Sum, + WeightedStream, + stream, + weighted, +) from effectful.ops.semantics import fwd from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import NotHandled, Operation, Term @@ -243,3 +251,6 @@ def reduce(self, monoid, body, streams): if progress: return monoid.reduce(body, new_streams) return fwd() + + +NormalizeIntp.extend(NumpyroCategorical()) diff --git a/tests/test_handlers_numpyro_monoid.py b/tests/test_handlers_numpyro_monoid.py index 1510b9f4..ccc2ca1d 100644 --- a/tests/test_handlers_numpyro_monoid.py +++ b/tests/test_handlers_numpyro_monoid.py @@ -1,35 +1,23 @@ """Unit tests for the rewrite rules in ``effectful.handlers.numpyro.monoid``. Tests follow the conventions in ``test_ops_monoid.py``: each rule is verified -via a symbolic ``lhs`` and the expected post-rewrite ``rhs``. We assert both -syntactic equivalence after applying the rule and semantic equivalence under -random interpretations of the free body op. - -For the categorical rule the lowered form is naturally a length-K explicit -sum, but ``ArrayReduce`` (inside ``NormalizeIntp``) further collapses that -into a single named-dim ``jnp.sum``; the RHS therefore matches that final -form, in the style of ``test_jax_weighted_reduce``. +via a symbolic ``lhs`` and the expected post-rewrite ``rhs``. The numpyro +categorical rule is part of :data:`NormalizeIntp`, so plain ``check_rewrite`` +suffices. """ -import jax.numpy as jnp +import jax import numpyro.distributions as dist import pytest -from hypothesis import HealthCheck, given, settings import effectful.handlers.jax.monoid # noqa: F401 # registers jax monoid handlers +import effectful.handlers.jax.numpy as jnp +import effectful.handlers.numpyro.monoid # noqa: F401 # registers numpyro monoid handlers from effectful.handlers.jax import bind_dims, unbind_dims from effectful.handlers.jax.monoid import LogSumExp from effectful.handlers.jax.scipy.special import logsumexp -from effectful.handlers.numpyro.monoid import NumpyroCategorical from effectful.ops.monoid import NormalizeIntp, Sum, weighted -from effectful.ops.semantics import coproduct, evaluate, handler -from tests._monoid_helpers import ( - JAX_BACKEND, - Backend, - define_vars, - random_interpretation, - syntactic_eq_alpha, -) +from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars @pytest.fixture @@ -37,47 +25,6 @@ def backend() -> Backend: return JAX_BACKEND -def check_numpyro_rewrite( - lhs, - rhs, - *, - rule, - backend: Backend, - syntactic_rule=None, - free_vars=(), - max_examples: int = 25, -) -> None: - """``check_rewrite`` variant for numpyro rules. - - ``syntactic_rule`` (default ``rule``) is installed for the syntactic - step; ``rule`` itself is installed alongside :data:`NormalizeIntp` for - the property-based semantic step so both sides can reduce to a value. - """ - syn = syntactic_rule if syntactic_rule is not None else rule - with handler(syn): - norm = evaluate(lhs) - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(free_vars)) - @settings( - max_examples=max_examples, - deadline=None, - suppress_health_check=[HealthCheck.function_scoped_fixture], - ) - def _check_semantics(intp): - with handler(coproduct(NormalizeIntp, rule)), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - -# --------------------------------------------------------------------------- -# NumpyroCategorical -# --------------------------------------------------------------------------- - - def test_categorical_probs_linear(backend): """Under ``Sum.reduce`` with ``CategoricalProbs``, the per-index weight is ``probs[i]`` (linear) and the lowered form is the named-dim @@ -95,44 +42,32 @@ def test_categorical_probs_linear(backend): bind_dims(unbind_dims(probs, k) * body(unbind_dims(indices, k)), k), axis=0 ) - full = coproduct(NormalizeIntp, NumpyroCategorical()) - check_numpyro_rewrite( - lhs=lhs, - rhs=rhs, - rule=NumpyroCategorical(), - syntactic_rule=full, - backend=backend, - free_vars=[k, body], + check_rewrite( + lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[k, body] ) -def test_categorical_logits_matches_probs(backend): - """``CategoricalLogits(log probs)`` and ``CategoricalProbs(probs)`` must - lower to the same value under the same body. +def test_categorical_logits_linear(backend): + """Under ``Sum.reduce`` with ``CategoricalLogits``, the per-index weight + is ``softmax(logits)[i]`` and the lowered form matches the same + named-dim ``jnp.sum`` shape as the probs case. """ - probs = jnp.array([0.1, 0.2, 0.3, 0.4]) - d_p = dist.CategoricalProbs(probs=probs) - d_l = dist.CategoricalLogits(logits=jnp.log(probs)) + logits = jnp.array([0.5, -1.0, 2.0, 0.1]) + d = dist.CategoricalLogits(logits=logits) + probs = jax.nn.softmax(logits, axis=-1) + indices = jnp.arange(logits.shape[-1]) - i_op = define_vars("i_op", typ=backend.scalar_typ) + i_op, k = define_vars("i_op", "k", typ=backend.scalar_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") - expr_p = Sum.reduce(body(i_op()), {i_op: weighted(d_p)}) - expr_l = Sum.reduce(body(i_op()), {i_op: weighted(d_l)}) - - @given(intp=random_interpretation([body])) - @settings( - max_examples=25, - deadline=None, - suppress_health_check=[HealthCheck.function_scoped_fixture], + lhs = Sum.reduce(body(i_op()), {i_op: weighted(d)}) + rhs = jnp.sum( + bind_dims(unbind_dims(probs, k) * body(unbind_dims(indices, k)), k), axis=0 ) - def _check(intp): - with handler(coproduct(NormalizeIntp, NumpyroCategorical())), handler(intp): - r_p = evaluate(expr_p) - r_l = evaluate(expr_l) - assert backend.eq(r_p, r_l) - _check() + check_rewrite( + lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[k, body] + ) def test_categorical_logsumexp(backend): @@ -155,12 +90,6 @@ def test_categorical_logsumexp(backend): axis=0, ) - full = coproduct(NormalizeIntp, NumpyroCategorical()) - check_numpyro_rewrite( - lhs=lhs, - rhs=rhs, - rule=NumpyroCategorical(), - syntactic_rule=full, - backend=backend, - free_vars=[k, body], + check_rewrite( + lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[k, body] ) From afdb8109a1b6db50e54d6328e4015b245b45e479 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 11:11:46 -0400 Subject: [PATCH 10/29] wip --- effectful/handlers/numpyro/monoid.py | 12 ++-- effectful/ops/monoid.py | 91 ++++++++++++++++++++++------ tests/_monoid_helpers.py | 40 +++++++++++- tests/test_ops_monoid.py | 55 +++++++++++++++-- 4 files changed, 164 insertions(+), 34 deletions(-) diff --git a/effectful/handlers/numpyro/monoid.py b/effectful/handlers/numpyro/monoid.py index 256c3d61..0564c103 100644 --- a/effectful/handlers/numpyro/monoid.py +++ b/effectful/handlers/numpyro/monoid.py @@ -29,7 +29,7 @@ Product, Sum, WeightedStream, - stream, + to_stream, weighted, ) from effectful.ops.semantics import fwd @@ -42,12 +42,12 @@ def _weighted_dist(_d): raise NotHandled -@stream.register(dist.Distribution) +@to_stream.register(dist.Distribution) def _stream_dist(_d): raise NotHandled -@stream.register(constraints.Constraint) +@to_stream.register(constraints.Constraint) def _stream_constraint(_c): raise NotHandled @@ -221,7 +221,7 @@ def _categorical(self, d, monoid: Monoid) -> WeightedStream | None: def weight_fn(x, _w=weights): return jax_getitem(_w, (x,)) - return WeightedStream(stream=indices, weight=weight_fn, monoid=w_monoid) + return weighted(stream=indices, weight=weight_fn, monoid=w_monoid) class NumpyroLogProb(ObjectInterpretation): @@ -244,8 +244,8 @@ def reduce(self, monoid, body, streams): d = _weighted_dist_arg(v) if d is None: continue - new_streams[k] = WeightedStream( - stream=stream(d.support), weight=d.log_prob, monoid=Sum + new_streams[k] = weighted( + stream=to_stream(d.support), weight=d.log_prob, monoid=Sum ) progress = True if progress: diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 3198e0c3..12662723 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -14,10 +14,9 @@ from effectful.ops.syntax import ( ObjectInterpretation, Scoped, - defdata, deffn, + defop, implements, - iter_, syntactic_eq, syntactic_hash, ) @@ -120,30 +119,24 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero -@dataclass -class WeightedStream[T, W](Iterable[T]): - stream: Stream[T] - weight: Callable[[T], W] - monoid: Monoid[W] - - def __iter__(self): - return defdata(iter_, self) - - @Operation.define @functools.singledispatch -def weighted(x) -> WeightedStream: +def to_weighted(x) -> Iterable: + """Smart constructor lifting a value (e.g. a distribution) into a + weighted stream. Backends register single-dispatch impls that return a + :func:`weighted` term. + """ raise NotImplementedError("Unsupported type", type(x)) @Operation.define @functools.singledispatch -def stream(x) -> Iterable: +def to_stream(x) -> Iterable: """Smart constructor lifting a value into the :data:`Stream` type. Used to wrap opaque ``support``-like values (e.g. numpyro distributions or constraints) that aren't structurally iterable but - should appear in the stream slot of a :class:`WeightedStream`. Concrete + should appear in the stream slot of a :func:`weighted` term. Concrete iterables can be registered to pass through unchanged; symbolic sources register impls that ``raise NotHandled`` so the call stays a Term and downstream rules can pattern-match on the wrapped value. @@ -151,6 +144,22 @@ def stream(x) -> Iterable: raise NotImplementedError("Unsupported stream source", type(x)) +@defop +def weighted[T, W, A, B]( + stream: Annotated[Iterable, Scoped[B]], + var: Annotated[Operation[[], T], Scoped[A]], + weight: Annotated[W, Scoped[A | B]], + monoid: Monoid[W], +) -> Annotated[Iterable, Scoped[B]]: + """A stream paired with a per-element weight. ``var`` is an + :class:`Operation` standing for "an element of ``stream``"; ``weight`` is + an expression that uses ``var`` and evaluates to the weight of that + element. Always stays as a :class:`Term`; consumers pattern-match on + ``term.op is weighted`` and read ``term.args``. + """ + raise NotHandled + + Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) @@ -589,23 +598,64 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): class ReduceWeightedStream(ObjectInterpretation): - """reduce(M, body, {x: WeightedStream(s, w, WM), ...}) = reduce(M, - WM.plus(w(x), body), {x: s, ...}) + """reduce(M, body, {x: weighted(s, v, w, WM), ...}) = reduce(M, + WM.plus(w[v:=x()], body), {x: s, ...}) requires distributes_over(WM, M). + The substitution ``v -> x`` is done by beta-reducing ``deffn(w, v)`` on + ``x()`` — symbolic, no Python dispatch on the weight expression. """ @implements(Monoid.reduce) def reduce(self, monoid, body, streams): for k, v in streams.items(): - if isinstance(v, WeightedStream) and distributes_over(v.monoid, monoid): - weighted_body = v.monoid.plus(v.weight(k()), body) - new_streams = {**streams, k: v.stream} + if isinstance(v, Term) and v.op is weighted: + v_stream, v_var, v_weight, v_monoid = v.args + if not distributes_over(v_monoid, monoid): + continue + w_at_k = deffn(v_weight, v_var)(k()) + weighted_body = v_monoid.plus(w_at_k, body) + new_streams = {**streams, k: v_stream} return monoid.reduce(weighted_body, new_streams) return fwd() +class ReduceCartesianWeightedStream(ObjectInterpretation): + """``CartesianProduct.reduce`` over a :func:`weighted` body whose + ``weight`` is independent of the plate (product-index) streams:: + + CartesianProduct.reduce(weighted(s, e, w, M), plates) + = weighted( + CartesianProduct.reduce(s, plates), + row, + M.reduce(w, {e: row()}), + M, + ) + + Reuses ``body``'s element binder ``e`` (already typed by construction); + introduces a fresh ``row`` binder typed as ``Iterable[elem_type]``. + + Only fires when ``w`` is independent of the plate vars. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid is not CartesianProduct: + return fwd() + if not (isinstance(body, Term) and body.op is weighted): + return fwd() + s, e_op, w, weight_monoid = body.args + if set(streams.keys()) & fvsof(w): + return fwd() + + joint_stream = CartesianProduct.reduce(s, streams) + row_op = Operation.define(Iterable, name="row") + joint_weight = weight_monoid.reduce(w, {e_op: row_op()}) + + return weighted(joint_stream, row_op, joint_weight, weight_monoid) + + class MonoidOverCallable(ObjectInterpretation): """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" @@ -801,6 +851,7 @@ def extend(self, *intps: Interpretation) -> typing.Self: ReduceFactorization(), ReduceDistributeCartesianProduct(), ReduceWeightedStream(), + ReduceCartesianWeightedStream(), PlusEmpty(), PlusSingle(), PlusIdentity(), diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 2ebd97a4..d3a4789e 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -9,7 +9,7 @@ import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter -from effectful.ops.monoid import NormalizeIntp +from effectful.ops.monoid import NormalizeIntp, weighted from effectful.ops.semantics import apply, evaluate, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term @@ -326,11 +326,45 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation return Operation.define(fn, name=name) +def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: + a_stream, a_var, a_weight, a_monoid = a.args + b_stream, b_var, b_weight, b_monoid = b.args + if a_monoid is not b_monoid: + return False + a_elems = list(a_stream) + b_elems = list(b_stream) + if len(a_elems) != len(b_elems): + return False + a_weight_fn = deffn(a_weight, a_var) + b_weight_fn = deffn(b_weight, b_var) + for ea, eb in zip(a_elems, b_elems): + if not leaf_eq(ea, eb): + return False + if not leaf_eq(a_weight_fn(ea), b_weight_fn(eb)): + return False + return True + + def _int_eq(a: Any, b: Any) -> bool: + if ( + isinstance(a, Term) + and a.op is weighted + and isinstance(b, Term) + and b.op is weighted + ): + return _weighted_stream_eq(a, b, _int_eq) return not isinstance(a, Term) and not isinstance(b, Term) and a == b def _jax_eq(a: Any, b: Any) -> bool: + if ( + isinstance(a, Term) + and a.op is weighted + and isinstance(b, Term) + and b.op is weighted + ): + return _weighted_stream_eq(a, b, _jax_eq) + def _leaf_eq(x: Any, y: Any) -> bool: return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) @@ -357,12 +391,12 @@ def check_rewrite( assert syntactic_eq_alpha(norm, rhs) @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=deadline) + @settings(max_examples=max_examples, deadline=deadline, report_multiple_bugs=False) def _check_semantics(intp): with handler(normalize), handler(intp): lhs_val = evaluate(lhs) rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) + assert backend.eq(lhs_val, rhs_val) _check_semantics() diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 8296cdab..270fcb70 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -23,6 +23,7 @@ PlusSingle, PlusZero, Product, + ReduceCartesianWeightedStream, ReduceDistributeCartesianProduct, ReduceFactorization, ReduceFusion, @@ -30,8 +31,8 @@ ReduceSplit, ReduceWeightedStream, Sum, - WeightedStream, distributes_over, + weighted, ) from effectful.ops.semantics import coproduct, evaluate, fvsof, handler from effectful.ops.types import NotHandled, Operation, Term @@ -686,7 +687,8 @@ def test_reduce_single_weighted_stream(backend): body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = WeightedStream(stream=A(), weight=w, monoid=Product) + e = define_vars("e", typ=backend.scalar_typ) + ws = weighted(A(), e, w(e()), Product) lhs = Sum.reduce(body(a()), {a: ws}) rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) @@ -714,8 +716,10 @@ def test_reduce_weighted_factorization(backend): w_a = backend.fresh_op("w_a", n_args=1, ret="scalar") w_b = backend.fresh_op("w_b", n_args=1, ret="scalar") - ws_a = WeightedStream(stream=A(), weight=w_a, monoid=Product) - ws_b = WeightedStream(stream=B(), weight=w_b, monoid=Product) + e_a = define_vars("e_a", typ=backend.scalar_typ) + e_b = define_vars("e_b", typ=backend.scalar_typ) + ws_a = weighted(A(), e_a, w_a(e_a()), Product) + ws_b = weighted(B(), e_b, w_b(e_b()), Product) lhs = Sum.reduce(Product.plus(f(a()), g(b())), {a: ws_a, b: ws_b}) rhs = Product.plus( @@ -732,6 +736,46 @@ def test_reduce_weighted_factorization(backend): ) +def test_reduce_cartesian_weighted_stream(backend): + """``CartesianProduct.reduce`` over a ``WeightedStream`` body whose weight + is independent of the plate var rewrites to a single joint + ``WeightedStream``: + + CartesianProduct.reduce(WeightedStream(s, e, w(e), M), {p: P}) + = WeightedStream( + stream = CartesianProduct.reduce(s, {p: P}), + var = row, + weight = M.reduce(w(e), {e: row()}), + monoid = M, + ) + """ + from collections.abc import Iterable + + p = define_vars("p", typ=backend.scalar_typ) + S, P = define_vars("S", "P", typ=backend.stream_typ) + w = backend.fresh_op("w", n_args=1, ret="scalar") + + e_var = define_vars("e", typ=backend.scalar_typ) + ws = weighted(S(), e_var, w(e_var()), Product) + lhs = CartesianProduct.reduce(ws, {p: P()}) + + row_var = Operation.define(Iterable[backend.scalar_typ], name="row") + rhs = weighted( + CartesianProduct.reduce(S(), {p: P()}), + row_var, + Product.reduce(w(e_var()), {e_var: row_var()}), + Product, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceCartesianWeightedStream(), + backend=backend, + free_vars=[S, P, w], + ) + + def test_weighted_expectation_demo(): """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. @@ -755,7 +799,8 @@ def _f(v: int) -> float: w = Operation.define(_w, name="w") f = Operation.define(_f, name="f") - ws = WeightedStream(stream=[1, 2, 3, 4], weight=w, monoid=Product) + e = define_vars("e", typ=int) + ws = weighted([1, 2, 3, 4], e, w(e()), Product) with handler(NormalizeIntp): result = evaluate(Sum.reduce(f(a()), {a: ws})) From f0e39a3a646e22e4e0a57c1a96f2cd65285f640e Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 11:18:41 -0400 Subject: [PATCH 11/29] test composition of lifting and weighting --- tests/test_ops_monoid.py | 46 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 270fcb70..878408d7 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -776,6 +776,52 @@ def test_reduce_cartesian_weighted_stream(backend): ) +def test_lift_weighted_cartesian(backend): + """Compose ``ReduceCartesianWeightedStream`` + ``ReduceWeightedStream`` + + ``ReduceDistributeCartesianProduct`` on a Sum-of-Product-of-weighted shape: + + Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(weighted(S, e, w(e), Product), {p: P})}, + ) + + The inner ``weighted`` becomes a joint ``weighted`` (rule 1), lifts its + per-element weight into the outer Sum body (rule 2), and the lifted form + matches the inversion pattern (rule 3), yielding:: + + Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S}), + {p: P}, + ) + """ + a = define_vars("a", typ=backend.scalar_typ) + e, p = define_vars("e", "p", typ=backend.scalar_typ) + A, S, P = define_vars("A", "S", "P", typ=backend.stream_typ) + body = backend.fresh_op("body", n_args=1, ret="scalar") + w = backend.fresh_op("w", n_args=1, ret="scalar") + + ws = weighted(S(), e, w(e()), Product) + lhs = Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(ws, {p: P()})}, + ) + rhs = Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), + {p: P()}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct( + coproduct(ReduceWeightedStream(), ReduceCartesianWeightedStream()), + ReduceDistributeCartesianProduct(), + ), + backend=backend, + free_vars=[S, P, body, w], + ) + + def test_weighted_expectation_demo(): """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. From bf0a34cb599e6fa38804ebe65979530e9e8b35c2 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 11:24:17 -0400 Subject: [PATCH 12/29] drop numpyro changes --- .../{numpyro/_distributions.py => numpyro.py} | 0 effectful/handlers/numpyro/__init__.py | 1 - effectful/handlers/numpyro/monoid.py | 256 ------------------ effectful/internals/unification.py | 1 - tests/test_handlers_numpyro_monoid.py | 95 ------- 5 files changed, 353 deletions(-) rename effectful/handlers/{numpyro/_distributions.py => numpyro.py} (100%) delete mode 100644 effectful/handlers/numpyro/__init__.py delete mode 100644 effectful/handlers/numpyro/monoid.py delete mode 100644 tests/test_handlers_numpyro_monoid.py diff --git a/effectful/handlers/numpyro/_distributions.py b/effectful/handlers/numpyro.py similarity index 100% rename from effectful/handlers/numpyro/_distributions.py rename to effectful/handlers/numpyro.py diff --git a/effectful/handlers/numpyro/__init__.py b/effectful/handlers/numpyro/__init__.py deleted file mode 100644 index 607151df..00000000 --- a/effectful/handlers/numpyro/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from effectful.handlers.numpyro._distributions import * # noqa: F401, F403 diff --git a/effectful/handlers/numpyro/monoid.py b/effectful/handlers/numpyro/monoid.py deleted file mode 100644 index 0564c103..00000000 --- a/effectful/handlers/numpyro/monoid.py +++ /dev/null @@ -1,256 +0,0 @@ -"""NumPyro distribution support for weighted streams. - -``weighted(dist)`` is the smart constructor for treating a numpyro -distribution as a weighted stream. By default it stays symbolic — i.e. -``weighted(d)`` returns a ``Term`` whose ``args[0]`` is ``d`` — so that -specialized reduction rules (closed-form expectations, quadrature, etc.) -can pattern-match on the distribution. -""" - -from dataclasses import dataclass - -import jax -import jax.numpy as jnp -import numpy as np -import numpyro.distributions as dist -import numpyro.distributions.constraints as constraints - -import effectful.handlers.jax.numpy as ejnp -from effectful.handlers.jax import jax_getitem -from effectful.handlers.jax.monoid import LogSumExp -from effectful.handlers.numpyro import ( - CategoricalLogitsTerm, - CategoricalProbsTerm, - NormalTerm, -) -from effectful.ops.monoid import ( - Monoid, - NormalizeIntp, - Product, - Sum, - WeightedStream, - to_stream, - weighted, -) -from effectful.ops.semantics import fwd -from effectful.ops.syntax import ObjectInterpretation, deffn, implements -from effectful.ops.types import NotHandled, Operation, Term - - -@weighted.register(dist.Distribution) -def _weighted_dist(_d): - raise NotHandled - - -@to_stream.register(dist.Distribution) -def _stream_dist(_d): - raise NotHandled - - -@to_stream.register(constraints.Constraint) -def _stream_constraint(_c): - raise NotHandled - - -def _weighted_dist_arg(v) -> dist.Distribution | None: - """If ``v`` is ``Term(weighted, [d])`` with ``d`` a numpyro Distribution, - return ``d``; otherwise ``None``. - """ - if not (isinstance(v, Term) and v.op is weighted): - return None - (d,) = v.args - return d if isinstance(d, dist.Distribution) else None - - -@dataclass -class NumpyroSampling(ObjectInterpretation): - """Replace ``weighted(d)`` with a sample-backed :class:`WeightedStream`. - - Draws ``n_samples`` i.i.d. samples from ``d`` and attaches a uniform weight - ``1/n_samples`` (linear space) or ``-log(n_samples)`` (log space, when the - outer monoid is :data:`LogSumExp`). The resulting :class:`WeightedStream` is - then handled by the standard :class:`ReduceWeightedStream` rewrite. - - """ - - rng_key: jax.Array - n_samples: int = 1000 - - @implements(Monoid.reduce) - def reduce(self, monoid, body, streams): - new_streams = dict(streams) - progress = False - for k, v in streams.items(): - d = _weighted_dist_arg(v) - if d is None: - continue - samples = d.sample(self.rng_key, sample_shape=(self.n_samples,)) - if monoid is LogSumExp: - w_val = -jnp.log(self.n_samples) - w_monoid: Monoid = Sum - else: - w_val = 1.0 / self.n_samples - w_monoid = Product - new_streams[k] = WeightedStream( - stream=samples, - weight=deffn(w_val, Operation.define(k)), - monoid=w_monoid, - ) - progress = True - if progress: - return monoid.reduce(body, new_streams) - return fwd() - - -@dataclass -class NumpyroGaussHermite(ObjectInterpretation): - """Gauss–Hermite quadrature for ``weighted(Normal(μ, σ))``. - - For ``X ∼ Normal(μ, σ²)``, the change of variable ``u = (x-μ)/(σ√2)`` gives - :: - - E[f(X)] = (1/√π) ∫ f(μ + σ√2 · u) e^{-u²} du - ≈ Σᵢ (wᵢ/√π) · f(μ + σ√2 · uᵢ) - - where ``{uᵢ, wᵢ}`` are the physicists' Hermite nodes/weights from - :func:`numpy.polynomial.hermite.hermgauss`. The rule replaces - ``weighted(d)`` with a :class:`WeightedStream` of length ``n_nodes`` and - lets the standard :class:`ReduceWeightedStream` machinery finish. - - Weight monoid is :data:`Product` for linear-space bodies (e.g. - ``Sum.reduce``) and :data:`Sum` for log-space bodies (e.g. - ``LogSumExp.reduce``); both pairs distribute correctly. - - """ - - n_nodes: int = 20 - - @implements(Monoid.reduce) - def reduce(self, monoid, body, streams): - new_streams = dict(streams) - progress = False - for k, v in streams.items(): - d = _weighted_dist_arg(v) - if not isinstance(d, dist.Normal | NormalTerm): - continue - new_streams[k] = self._gauss_hermite(d, monoid) - progress = True - if progress: - return monoid.reduce(body, new_streams) - return fwd() - - def _gauss_hermite(self, d, monoid: Monoid) -> WeightedStream: - u, w = np.polynomial.hermite.hermgauss(self.n_nodes) - u_jax = jnp.asarray(u, dtype=jnp.float32) - w_jax = jnp.asarray(w, dtype=jnp.float32) - - nodes = d.loc + jnp.sqrt(2.0) * d.scale * u_jax - if monoid is LogSumExp: - weights = jnp.log(w_jax) - 0.5 * jnp.log(jnp.pi) - w_monoid: Monoid = Sum - else: - weights = w_jax / jnp.sqrt(jnp.pi) - w_monoid = Product - - # Position-match the node value back to its weight via argmin. The - # weight function is invoked symbolically by ``ReduceWeightedStream`` - # (with a Term arg), so we use the effectful-wrapped jnp so the - # lookup becomes a Term that evaluates to the right scalar once the - # default reduce binds the stream variable to a concrete node. - def weight_fn(x, _nodes=nodes, _w=weights): - idx = ejnp.argmin(ejnp.abs(_nodes - x)) - return jax_getitem(_w, (idx,)) - - return WeightedStream(stream=nodes, weight=weight_fn, monoid=w_monoid) - - -@dataclass -class NumpyroCategorical(ObjectInterpretation): - """Exact enumeration ('quadrature') for ``weighted(Categorical(...))``. - - A categorical with ``K`` outcomes has finite integer support - ``{0, …, K-1}``; integration reduces to an exact finite sum. The rule - replaces ``weighted(d)`` with a :class:`WeightedStream` whose stream is - ``jnp.arange(K)`` and whose weight indexes into the per-outcome - probability vector. - - Weight monoid is :data:`Product` for linear-space bodies and :data:`Sum` - for log-space bodies (under :data:`LogSumExp`), matching the - distributes-over pairs used by :class:`ReduceWeightedStream`. - - """ - - @implements(Monoid.reduce) - def reduce(self, monoid, body, streams): - new_streams = dict(streams) - progress = False - for k, v in streams.items(): - d = _weighted_dist_arg(v) - ws = self._categorical(d, monoid) - if ws is None: - continue - new_streams[k] = ws - progress = True - if progress: - return monoid.reduce(body, new_streams) - return fwd() - - def _categorical(self, d, monoid: Monoid) -> WeightedStream | None: - # Pick the natural representation for the target weight monoid so we - # don't go probs→log or logits→probs→log unnecessarily. - if monoid is LogSumExp: - w_monoid: Monoid = Sum - if isinstance(d, dist.CategoricalLogits | CategoricalLogitsTerm): - weights = jax.nn.log_softmax(jnp.asarray(d.logits), axis=-1) - elif isinstance(d, dist.CategoricalProbs | CategoricalProbsTerm): - weights = jnp.log(jnp.asarray(d.probs)) - else: - return None - else: - w_monoid = Product - if isinstance(d, dist.CategoricalProbs | CategoricalProbsTerm): - weights = jnp.asarray(d.probs) - elif isinstance(d, dist.CategoricalLogits | CategoricalLogitsTerm): - weights = jax.nn.softmax(jnp.asarray(d.logits), axis=-1) - else: - return None - - indices = jnp.arange(weights.shape[-1]) - - # The support value *is* the index, so the lookup is direct. - def weight_fn(x, _w=weights): - return jax_getitem(_w, (x,)) - - return weighted(stream=indices, weight=weight_fn, monoid=w_monoid) - - -class NumpyroLogProb(ObjectInterpretation): - """Lower ``weighted(d)`` to its symbolic log-prob form. - - Generic fallback: produces a :class:`WeightedStream` whose stream is the - symbolic ``stream(d.support)``, weight is ``d.log_prob``, and weight monoid - is :data:`Sum` (log-space multiplication). With ``distributes_over(Sum, - LogSumExp)`` registered, a surrounding ``LogSumExp.reduce`` will then - desugar via :class:`ReduceWeightedStream` into the standard expectation - integrand. - - """ - - @implements(Monoid.reduce) - def reduce(self, monoid, body, streams): - new_streams = dict(streams) - progress = False - for k, v in streams.items(): - d = _weighted_dist_arg(v) - if d is None: - continue - new_streams[k] = weighted( - stream=to_stream(d.support), weight=d.log_prob, monoid=Sum - ) - progress = True - if progress: - return monoid.reduce(body, new_streams) - return fwd() - - -NormalizeIntp.extend(NumpyroCategorical()) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 36a4a4e9..e425bba6 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -801,7 +801,6 @@ def _(typ: typing._ConcatenateGenericAlias): # type: ignore @canonicalize.register def _(typ: typing._AnyMeta): # type: ignore - return typing.Any diff --git a/tests/test_handlers_numpyro_monoid.py b/tests/test_handlers_numpyro_monoid.py deleted file mode 100644 index ccc2ca1d..00000000 --- a/tests/test_handlers_numpyro_monoid.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Unit tests for the rewrite rules in ``effectful.handlers.numpyro.monoid``. - -Tests follow the conventions in ``test_ops_monoid.py``: each rule is verified -via a symbolic ``lhs`` and the expected post-rewrite ``rhs``. The numpyro -categorical rule is part of :data:`NormalizeIntp`, so plain ``check_rewrite`` -suffices. -""" - -import jax -import numpyro.distributions as dist -import pytest - -import effectful.handlers.jax.monoid # noqa: F401 # registers jax monoid handlers -import effectful.handlers.jax.numpy as jnp -import effectful.handlers.numpyro.monoid # noqa: F401 # registers numpyro monoid handlers -from effectful.handlers.jax import bind_dims, unbind_dims -from effectful.handlers.jax.monoid import LogSumExp -from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import NormalizeIntp, Sum, weighted -from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars - - -@pytest.fixture -def backend() -> Backend: - return JAX_BACKEND - - -def test_categorical_probs_linear(backend): - """Under ``Sum.reduce`` with ``CategoricalProbs``, the per-index weight - is ``probs[i]`` (linear) and the lowered form is the named-dim - ``jnp.sum(probs[k] * body(indices[k]))``. - """ - probs = jnp.array([0.1, 0.2, 0.3, 0.4]) - d = dist.CategoricalProbs(probs=probs) - indices = jnp.arange(probs.shape[-1]) - - i_op, k = define_vars("i_op", "k", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - - lhs = Sum.reduce(body(i_op()), {i_op: weighted(d)}) - rhs = jnp.sum( - bind_dims(unbind_dims(probs, k) * body(unbind_dims(indices, k)), k), axis=0 - ) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[k, body] - ) - - -def test_categorical_logits_linear(backend): - """Under ``Sum.reduce`` with ``CategoricalLogits``, the per-index weight - is ``softmax(logits)[i]`` and the lowered form matches the same - named-dim ``jnp.sum`` shape as the probs case. - """ - logits = jnp.array([0.5, -1.0, 2.0, 0.1]) - d = dist.CategoricalLogits(logits=logits) - probs = jax.nn.softmax(logits, axis=-1) - indices = jnp.arange(logits.shape[-1]) - - i_op, k = define_vars("i_op", "k", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - - lhs = Sum.reduce(body(i_op()), {i_op: weighted(d)}) - rhs = jnp.sum( - bind_dims(unbind_dims(probs, k) * body(unbind_dims(indices, k)), k), axis=0 - ) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[k, body] - ) - - -def test_categorical_logsumexp(backend): - """Under ``LogSumExp.reduce`` with ``CategoricalProbs``, weights are - ``log(probs)`` combined via ``Sum`` (log-multiplication); the lowered - form is ``logsumexp(log_probs[k] + body(indices[k]))`` along the - named dim. - """ - probs = jnp.array([0.1, 0.2, 0.3, 0.4]) - d = dist.CategoricalProbs(probs=probs) - log_probs = jnp.log(probs) - indices = jnp.arange(probs.shape[-1]) - - i_op, k = define_vars("i_op", "k", typ=backend.scalar_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - - lhs = LogSumExp.reduce(body(i_op()), {i_op: weighted(d)}) - rhs = logsumexp( - bind_dims(unbind_dims(log_probs, k) + body(unbind_dims(indices, k)), k), - axis=0, - ) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=NormalizeIntp, backend=backend, free_vars=[k, body] - ) From 60da70509b2a91d1cfc88689975be52db5c01117 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 11:25:18 -0400 Subject: [PATCH 13/29] drop unused ops --- effectful/ops/monoid.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 12662723..9a314d6f 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -119,31 +119,6 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero -@Operation.define -@functools.singledispatch -def to_weighted(x) -> Iterable: - """Smart constructor lifting a value (e.g. a distribution) into a - weighted stream. Backends register single-dispatch impls that return a - :func:`weighted` term. - """ - raise NotImplementedError("Unsupported type", type(x)) - - -@Operation.define -@functools.singledispatch -def to_stream(x) -> Iterable: - """Smart constructor lifting a value into the :data:`Stream` type. - - Used to wrap opaque ``support``-like values (e.g. numpyro - distributions or constraints) that aren't structurally iterable but - should appear in the stream slot of a :func:`weighted` term. Concrete - iterables can be registered to pass through unchanged; symbolic sources - register impls that ``raise NotHandled`` so the call stays a Term and - downstream rules can pattern-match on the wrapped value. - """ - raise NotImplementedError("Unsupported stream source", type(x)) - - @defop def weighted[T, W, A, B]( stream: Annotated[Iterable, Scoped[B]], From 9661c452fe3bdd1292a19e93487b81149f9a4062 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 11:28:01 -0400 Subject: [PATCH 14/29] lint --- tests/test_handlers_jax_monoid.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 81ef8583..259a1417 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,4 +1,5 @@ import functools +import typing import jax import pytest @@ -23,9 +24,10 @@ Product, ReduceWeightedStream, Sum, - WeightedStream, + weighted, ) from effectful.ops.semantics import coproduct, handler +from effectful.ops.types import Interpretation from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars MONOIDS = [ @@ -128,7 +130,7 @@ def test_jax_weighted_reduce(backend: Backend): body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = WeightedStream(stream=X(), weight=w, monoid=Product) + ws = weighted(X(), x, w(x()), monoid=Product) lhs = Sum.reduce(body(x()), {x: ws}) rhs = jnp.sum( bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 @@ -138,7 +140,11 @@ def test_jax_weighted_reduce(backend: Backend): lhs=lhs, rhs=rhs, rule=functools.reduce( - coproduct, [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()] + coproduct, + typing.cast( + list[Interpretation], + [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()], + ), ), backend=backend, free_vars=[x, k, X, body, w], From f65fe37a9687fc94c116b57f02bf2532ebe311c5 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 13:27:04 -0400 Subject: [PATCH 15/29] make weighted a Monoid method --- effectful/ops/monoid.py | 55 +++++++++++++++------------ tests/_monoid_helpers.py | 14 +++---- tests/test_handlers_jax_monoid.py | 6 +-- tests/test_ops_monoid.py | 62 ++++++++++++------------------- 4 files changed, 65 insertions(+), 72 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 9a314d6f..9409f5e6 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -15,7 +15,6 @@ ObjectInterpretation, Scoped, deffn, - defop, implements, syntactic_eq, syntactic_hash, @@ -110,6 +109,21 @@ def reduce[A, B, U: Body]( return self.plus(*new_reduces) raise NotHandled + @Operation.define + def weighted[Elem, A, B]( + self, + stream: Annotated[Iterable, Scoped[B]], + var: Annotated[Operation[[], Elem], Scoped[A]], + weight: Annotated[T, Scoped[A | B]], + ) -> Annotated[Iterable, Scoped[B]]: + """A stream paired with a per-element weight. ``var`` is an + :class:`Operation` standing for "an element of ``stream``"; ``weight`` + is an expression that uses ``var`` and evaluates to the weight of that + element. + + """ + raise NotHandled + class MonoidWithZero[T](Monoid[T]): zero: T @@ -119,22 +133,6 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero -@defop -def weighted[T, W, A, B]( - stream: Annotated[Iterable, Scoped[B]], - var: Annotated[Operation[[], T], Scoped[A]], - weight: Annotated[W, Scoped[A | B]], - monoid: Monoid[W], -) -> Annotated[Iterable, Scoped[B]]: - """A stream paired with a per-element weight. ``var`` is an - :class:`Operation` standing for "an element of ``stream``"; ``weight`` is - an expression that uses ``var`` and evaluates to the weight of that - element. Always stays as a :class:`Term`; consumers pattern-match on - ``term.op is weighted`` and read ``term.args``. - """ - raise NotHandled - - Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) @@ -190,6 +188,12 @@ def _is_monoid_reduce(op: Operation) -> bool: return isinstance(owner, Monoid) and op is owner.reduce +def _is_monoid_weighted(op: Operation) -> bool: + """True if ``op`` is the ``weighted`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.weighted + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -573,8 +577,7 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): class ReduceWeightedStream(ObjectInterpretation): - """reduce(M, body, {x: weighted(s, v, w, WM), ...}) = reduce(M, - WM.plus(w[v:=x()], body), {x: s, ...}) + """reduce(M, body, {x: WM.weighted(s, v, w), ...}) = reduce(M, WM.plus(w[v:=x()], body), {x: s, ...}) requires distributes_over(WM, M). @@ -585,8 +588,9 @@ class ReduceWeightedStream(ObjectInterpretation): @implements(Monoid.reduce) def reduce(self, monoid, body, streams): for k, v in streams.items(): - if isinstance(v, Term) and v.op is weighted: - v_stream, v_var, v_weight, v_monoid = v.args + if isinstance(v, Term) and _is_monoid_weighted(v.op): + v_stream, v_var, v_weight = v.args + v_monoid = v.op.__self__ if not distributes_over(v_monoid, monoid): continue w_at_k = deffn(v_weight, v_var)(k()) @@ -618,9 +622,12 @@ class ReduceCartesianWeightedStream(ObjectInterpretation): def reduce(self, monoid, body, streams): if monoid is not CartesianProduct: return fwd() - if not (isinstance(body, Term) and body.op is weighted): + if not (isinstance(body, Term) and _is_monoid_weighted(body.op)): return fwd() - s, e_op, w, weight_monoid = body.args + + s, e_op, w = body.args + weight_monoid = body.op.__self__ + if set(streams.keys()) & fvsof(w): return fwd() @@ -628,7 +635,7 @@ def reduce(self, monoid, body, streams): row_op = Operation.define(Iterable, name="row") joint_weight = weight_monoid.reduce(w, {e_op: row_op()}) - return weighted(joint_stream, row_op, joint_weight, weight_monoid) + return weight_monoid.weighted(joint_stream, row_op, joint_weight) class MonoidOverCallable(ObjectInterpretation): diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index d3a4789e..08684dfb 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -9,7 +9,7 @@ import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter -from effectful.ops.monoid import NormalizeIntp, weighted +from effectful.ops.monoid import NormalizeIntp, _is_monoid_weighted from effectful.ops.semantics import apply, evaluate, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term @@ -327,8 +327,8 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: - a_stream, a_var, a_weight, a_monoid = a.args - b_stream, b_var, b_weight, b_monoid = b.args + a_monoid, a_stream, a_var, a_weight = a.args + b_monoid, b_stream, b_var, b_weight = b.args if a_monoid is not b_monoid: return False a_elems = list(a_stream) @@ -348,9 +348,9 @@ def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: def _int_eq(a: Any, b: Any) -> bool: if ( isinstance(a, Term) - and a.op is weighted + and _is_monoid_weighted(a.op) and isinstance(b, Term) - and b.op is weighted + and _is_monoid_weighted(b.op) ): return _weighted_stream_eq(a, b, _int_eq) return not isinstance(a, Term) and not isinstance(b, Term) and a == b @@ -359,9 +359,9 @@ def _int_eq(a: Any, b: Any) -> bool: def _jax_eq(a: Any, b: Any) -> bool: if ( isinstance(a, Term) - and a.op is weighted + and _is_monoid_weighted(a.op) and isinstance(b, Term) - and b.op is weighted + and _is_monoid_weighted(b.op) ): return _weighted_stream_eq(a, b, _jax_eq) diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 259a1417..c94e1e79 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -24,7 +24,6 @@ Product, ReduceWeightedStream, Sum, - weighted, ) from effectful.ops.semantics import coproduct, handler from effectful.ops.types import Interpretation @@ -119,18 +118,19 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): def test_jax_weighted_reduce(backend: Backend): - """Sum over a single ``WeightedStream`` with ``Product`` weights lowers to + """Sum over a single stream with ``Product`` weights lowers to ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. Verifies that the desugaring rule composes cleanly with the JAX lowering so existing handlers need no changes to support weighted streams. + """ (x, k) = define_vars("x", "k", typ=jax.Array) X = define_vars("X", typ=backend.stream_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = weighted(X(), x, w(x()), monoid=Product) + ws = Product.weighted(X(), x, w(x())) lhs = Sum.reduce(body(x()), {x: ws}) rhs = jnp.sum( bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 878408d7..57168295 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,5 +1,6 @@ import math import typing +from collections.abc import Iterable import pytest from hypothesis import HealthCheck, given, settings @@ -32,7 +33,6 @@ ReduceWeightedStream, Sum, distributes_over, - weighted, ) from effectful.ops.semantics import coproduct, evaluate, fvsof, handler from effectful.ops.types import NotHandled, Operation, Term @@ -682,14 +682,12 @@ def test_reduce_single_weighted_stream(backend): Sum.reduce(body, {a: WS(A, w, Product)}) = Sum.reduce(Product.plus(w(a), body), {a: A}) """ - a = define_vars("a", typ=backend.scalar_typ) + a, e = define_vars("a", "e", typ=backend.scalar_typ) A = define_vars("A", typ=backend.stream_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - e = define_vars("e", typ=backend.scalar_typ) - ws = weighted(A(), e, w(e()), Product) - lhs = Sum.reduce(body(a()), {a: ws}) + lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), e, w(e()))}) rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) check_rewrite( @@ -703,25 +701,26 @@ def test_reduce_single_weighted_stream(backend): def test_reduce_weighted_factorization(backend): """Two independent weighted streams under Sum with Product weights factor: - Sum.reduce(f(a)*g(b), {a: WS(A, w_a, Product), b: WS(B, w_b, Product)}) + Sum.reduce(f(a)*g(b), {a: Product.weighted(A, a, w_a), b: Product.weighted(B, b, w_b)}) = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` inside ``NormalizeIntp``. """ - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b, e_a, e_b = define_vars("a", "b", "e_a", "e_b", typ=backend.scalar_typ) A, B = define_vars("A", "B", typ=backend.stream_typ) f = backend.fresh_op("f", n_args=1, ret="scalar") g = backend.fresh_op("g", n_args=1, ret="scalar") w_a = backend.fresh_op("w_a", n_args=1, ret="scalar") w_b = backend.fresh_op("w_b", n_args=1, ret="scalar") - e_a = define_vars("e_a", typ=backend.scalar_typ) - e_b = define_vars("e_b", typ=backend.scalar_typ) - ws_a = weighted(A(), e_a, w_a(e_a()), Product) - ws_b = weighted(B(), e_b, w_b(e_b()), Product) - - lhs = Sum.reduce(Product.plus(f(a()), g(b())), {a: ws_a, b: ws_b}) + lhs = Sum.reduce( + Product.plus(f(a()), g(b())), + { + a: Product.weighted(A(), e_a, w_a(e_a())), + b: Product.weighted(B(), e_b, w_b(e_b())), + }, + ) rhs = Product.plus( Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), @@ -741,30 +740,20 @@ def test_reduce_cartesian_weighted_stream(backend): is independent of the plate var rewrites to a single joint ``WeightedStream``: - CartesianProduct.reduce(WeightedStream(s, e, w(e), M), {p: P}) - = WeightedStream( - stream = CartesianProduct.reduce(s, {p: P}), - var = row, - weight = M.reduce(w(e), {e: row()}), - monoid = M, - ) + CartesianProduct.reduce(M.weighted(s, e, w(e)), {p: P}) + = M.weighted(CartesianProduct.reduce(s, {p: P}), row, M.reduce(w(e), {e: row()})) """ - from collections.abc import Iterable - - p = define_vars("p", typ=backend.scalar_typ) + p, e_var = define_vars("p", "e_var", typ=backend.scalar_typ) S, P = define_vars("S", "P", typ=backend.stream_typ) w = backend.fresh_op("w", n_args=1, ret="scalar") - e_var = define_vars("e", typ=backend.scalar_typ) - ws = weighted(S(), e_var, w(e_var()), Product) - lhs = CartesianProduct.reduce(ws, {p: P()}) + lhs = CartesianProduct.reduce(Product.weighted(S(), e_var, w(e_var())), {p: P()}) row_var = Operation.define(Iterable[backend.scalar_typ], name="row") - rhs = weighted( + rhs = Product.weighted( CartesianProduct.reduce(S(), {p: P()}), row_var, Product.reduce(w(e_var()), {e_var: row_var()}), - Product, ) check_rewrite( @@ -782,7 +771,7 @@ def test_lift_weighted_cartesian(backend): Sum.reduce( Product.reduce(body(a()), {a: A()}), - {A: CartesianProduct.reduce(weighted(S, e, w(e), Product), {p: P})}, + {A: CartesianProduct.reduce(Product.weighted(S, e, w(e)), {p: P})}, ) The inner ``weighted`` becomes a joint ``weighted`` (rule 1), lifts its @@ -800,14 +789,12 @@ def test_lift_weighted_cartesian(backend): body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = weighted(S(), e, w(e()), Product) lhs = Sum.reduce( Product.reduce(body(a()), {a: A()}), - {A: CartesianProduct.reduce(ws, {p: P()})}, + {A: CartesianProduct.reduce(Product.weighted(S(), e, w(e())), {p: P()})}, ) rhs = Product.reduce( - Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), - {p: P()}, + Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} ) check_rewrite( @@ -841,14 +828,13 @@ def _f(v: int) -> float: raise NotHandled return float(v * v) - a = define_vars("a", typ=int) + a, e = define_vars("a", "e", typ=int) w = Operation.define(_w, name="w") f = Operation.define(_f, name="f") - e = define_vars("e", typ=int) - ws = weighted([1, 2, 3, 4], e, w(e()), Product) - with handler(NormalizeIntp): - result = evaluate(Sum.reduce(f(a()), {a: ws})) + result = evaluate( + Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], e, w(e()))}) + ) assert math.isclose(result, 10.0) From 9a459200e3296607dab8b7c70c81fec9452dd79d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 26 May 2026 15:12:03 -0400 Subject: [PATCH 16/29] fix typing of jax arrays --- effectful/handlers/jax/_terms.py | 7 +++++++ tests/_monoid_helpers.py | 6 ++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index c88fe934..05a5390e 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -14,10 +14,17 @@ unbind_dims, ) from effectful.internals.tensor_utils import _desugar_tensor_index +from effectful.internals.unification import Box, nested_type from effectful.ops.syntax import defdata from effectful.ops.types import Expr, NotHandled, Operation, Term +@nested_type.register(jax.Array) +@nested_type.register(jax._src.core.Tracer) +def _(value): + return Box(jax.Array) + + class _IndexUpdateHelper: """Helper class to implement array-style .at[index].set() updates for effectful arrays.""" diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 08684dfb..8e886e70 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -327,8 +327,10 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: - a_monoid, a_stream, a_var, a_weight = a.args - b_monoid, b_stream, b_var, b_weight = b.args + a_monoid = a.op.__self__ + a_stream, a_var, a_weight = a.args + b_monoid = b.op.__self__ + b_stream, b_var, b_weight = b.args if a_monoid is not b_monoid: return False a_elems = list(a_stream) From e109d984d676047ff4b72c5d38ed90d106b80bd2 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 27 May 2026 11:47:29 -0400 Subject: [PATCH 17/29] change weighted typing to take callable --- effectful/handlers/jax/monoid.py | 2 + effectful/ops/monoid.py | 94 ++++++++++++++++++++++++-------- effectful/ops/semantics.py | 29 ++++++---- tests/_monoid_helpers.py | 63 ++++++++++++--------- tests/test_ops_monoid.py | 21 +++---- 5 files changed, 135 insertions(+), 74 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 634a8cd7..5ef06770 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -18,12 +18,14 @@ Sum, distributes_over, outer_stream, + stream_element_type, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Interpretation, NotHandled, Operation, Term Iterable.register(jax.Array) # required to make jax arrays compatible with Stream[T] +stream_element_type.register(jax.Array, lambda t: jax.Array) def cartesian_prod(x, y): diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 9409f5e6..ee51584c 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,15 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.internals.unification import nested_type +from effectful.ops.semantics import ( + coproduct, + evaluate, + fvsof, + fwd, + handler, + typeof_full, +) from effectful.ops.syntax import ( ObjectInterpretation, Scoped, @@ -19,7 +27,14 @@ syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, + _CustomSingleDispatchCallable, +) type Stream[T] = Iterable[T] @@ -49,13 +64,13 @@ def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams] ) -class Monoid[T]: +class Monoid[W]: """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" _name: str - identity: T + identity: W - def __init__(self, identity: T, name: str): + def __init__(self, identity: W, name: str): self._name = name self.identity = identity @@ -110,12 +125,9 @@ def reduce[A, B, U: Body]( raise NotHandled @Operation.define - def weighted[Elem, A, B]( - self, - stream: Annotated[Iterable, Scoped[B]], - var: Annotated[Operation[[], Elem], Scoped[A]], - weight: Annotated[T, Scoped[A | B]], - ) -> Annotated[Iterable, Scoped[B]]: + def weighted[T]( + self, stream: Iterable[T], weight: Callable[[T], W] | Operation[[T], W] + ) -> Iterable[T]: """A stream paired with a per-element weight. ``var`` is an :class:`Operation` standing for "an element of ``stream``"; ``weight`` is an expression that uses ``var`` and evaluates to the weight of that @@ -176,6 +188,35 @@ def __call__(self, s: S, t: T) -> bool: ) +@_CustomSingleDispatchCallable +def stream_element_type(__dispatch: Callable[[type], Callable[[Any], Any]], typ: Any): + """Maps a stream type to the type of its elements. + + Extensible per backend via :meth:`register`. Dispatch is on the type's + *origin* (e.g. ``list`` for ``list[int]``), so a handler registered for a + base class or ABC also covers its subclasses, most-specific match winning. + """ + if typing.get_origin(typ) is Annotated: + typ = typing.get_args(typ)[0] + origin = typing.get_origin(typ) or typ + return __dispatch(origin)(typ) + + +@stream_element_type.register(object) +def _stream_element_type_default(typ: Any) -> Any: + raise NotImplementedError( + f"No element-type handler registered for stream type {typ!r}" + ) + + +# ``Iterable`` covers ``list``, ``tuple``, ``MutableSequence``, ... via origin +# subclass dispatch. +@stream_element_type.register(Iterable) +def _stream_element_type_iterable(typ: Any) -> Any: + """Element type of a single-parameter generic stream like ``list[T]``.""" + return typing.get_args(typ)[0] + + def _is_monoid_plus(op: Operation) -> bool: """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" owner = getattr(op, "__self__", None) @@ -589,11 +630,11 @@ class ReduceWeightedStream(ObjectInterpretation): def reduce(self, monoid, body, streams): for k, v in streams.items(): if isinstance(v, Term) and _is_monoid_weighted(v.op): - v_stream, v_var, v_weight = v.args + v_stream, v_weight = v.args v_monoid = v.op.__self__ if not distributes_over(v_monoid, monoid): continue - w_at_k = deffn(v_weight, v_var)(k()) + w_at_k = v_weight(k()) weighted_body = v_monoid.plus(w_at_k, body) new_streams = {**streams, k: v_stream} return monoid.reduce(weighted_body, new_streams) @@ -604,12 +645,10 @@ class ReduceCartesianWeightedStream(ObjectInterpretation): """``CartesianProduct.reduce`` over a :func:`weighted` body whose ``weight`` is independent of the plate (product-index) streams:: - CartesianProduct.reduce(weighted(s, e, w, M), plates) - = weighted( + CartesianProduct.reduce(M.weighted(s, w), plates) + = M.weighted( CartesianProduct.reduce(s, plates), - row, - M.reduce(w, {e: row()}), - M, + deffn(M.reduce(w, {e: row()}), row), ) Reuses ``body``'s element binder ``e`` (already typed by construction); @@ -625,17 +664,24 @@ def reduce(self, monoid, body, streams): if not (isinstance(body, Term) and _is_monoid_weighted(body.op)): return fwd() - s, e_op, w = body.args - weight_monoid = body.op.__self__ + s, w = body.args + if not isinstance(s, Term) and len(s) == 0: + return CartesianProduct.reduce([], streams) if set(streams.keys()) & fvsof(w): return fwd() - joint_stream = CartesianProduct.reduce(s, streams) - row_op = Operation.define(Iterable, name="row") - joint_weight = weight_monoid.reduce(w, {e_op: row_op()}) + stream_type = typeof_full(s) if isinstance(s, Term) else nested_type(s).value + elem_typ = stream_element_type(stream_type) + elem_op = Operation.define(elem_typ, name="elem") + row_op = Operation.define(Iterable[elem_typ], name="row") - return weight_monoid.weighted(joint_stream, row_op, joint_weight) + weight_monoid = body.op.__self__ + joint_weight = deffn( + weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op + ) + joint_stream = CartesianProduct.reduce(s, streams) + return weight_monoid.weighted(joint_stream, joint_weight) class MonoidOverCallable(ObjectInterpretation): diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 8fd62bcd..77b3a231 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -307,8 +307,23 @@ def _simple_type(tp: type) -> type: return typing.get_origin(tp) or tp +def typeof_full[T](term: Expr[T]) -> type[T]: + """Return the type of an expression, including any type parameters.""" + from effectful.internals.runtime import interpreter + from effectful.internals.unification import Box + + def _apply(op, *args, **kwargs): + return Box(op.__type_rule__(*args, **kwargs)) + + with interpreter({apply: _apply}): + type_or_value = evaluate(term) + if isinstance(type_or_value, Box): + return type_or_value.value + return typing.cast(type[T], type(type_or_value)) + + def typeof[T](term: Expr[T]) -> type[T]: - """Return the type of an expression. + """Return the type of an expression, with type parameters stripped. **Example usage**: @@ -329,17 +344,7 @@ def typeof[T](term: Expr[T]) -> type[T]: """ - from effectful.internals.runtime import interpreter - from effectful.internals.unification import Box - - def _apply(op, *args, **kwargs): - return Box(op.__type_rule__(*args, **kwargs)) - - with interpreter({apply: _apply}): - type_or_value = evaluate(term) - if isinstance(type_or_value, Box): - return _simple_type(type_or_value.value) - return typing.cast(type[T], type(type_or_value)) + return _simple_type(typeof_full(term)) def fvsof[S](term: Expr[S]) -> collections.abc.Set[Operation]: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 8e886e70..b052218f 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,4 +1,5 @@ import itertools +import typing from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from typing import Annotated, Any, get_args, get_origin @@ -326,45 +327,55 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation return Operation.define(fn, name=name) +def _is_weighted(x: Any) -> bool: + return isinstance(x, Term) and _is_monoid_weighted(x.op) + + +def _weight_pairs(x: Any, monoid: Any) -> list[tuple[Any, Any]] | None: + """Return ``(element, weight)`` pairs for a stream. + + A weighted-monoid Term yields each element paired with its weight. A plain + (unweighted) stream yields each element paired with ``monoid.identity`` -- + the no-op weight -- so an unweighted stream compares equal to a weighted one + exactly when every weight reduces to the identity (e.g. ``[()]`` vs a + weighted ``[()]`` whose single empty row reduces to the identity, and, more + generally, whenever both streams are empty). Returns ``None`` for a + non-stream Term, which never compares equal to a weighted stream. + """ + if isinstance(x, Term): + if not _is_monoid_weighted(x.op): + return None + stream, weight = x.args + assert not isinstance(stream, Term) + return [(e, typing.cast(Callable, weight)(e)) for e in stream] + return [(e, monoid.identity) for e in x] + + def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: - a_monoid = a.op.__self__ - a_stream, a_var, a_weight = a.args - b_monoid = b.op.__self__ - b_stream, b_var, b_weight = b.args - if a_monoid is not b_monoid: + monoids = {x.op.__self__ for x in (a, b) if _is_weighted(x)} + # distinct weight monoids can never be equal + if len(monoids) != 1: return False - a_elems = list(a_stream) - b_elems = list(b_stream) - if len(a_elems) != len(b_elems): + monoid = next(iter(monoids)) + + a_pairs = _weight_pairs(a, monoid) + b_pairs = _weight_pairs(b, monoid) + if a_pairs is None or b_pairs is None or len(a_pairs) != len(b_pairs): return False - a_weight_fn = deffn(a_weight, a_var) - b_weight_fn = deffn(b_weight, b_var) - for ea, eb in zip(a_elems, b_elems): - if not leaf_eq(ea, eb): - return False - if not leaf_eq(a_weight_fn(ea), b_weight_fn(eb)): + for (ea, wa), (eb, wb) in zip(a_pairs, b_pairs): + if not leaf_eq(ea, eb) or not leaf_eq(wa, wb): return False return True def _int_eq(a: Any, b: Any) -> bool: - if ( - isinstance(a, Term) - and _is_monoid_weighted(a.op) - and isinstance(b, Term) - and _is_monoid_weighted(b.op) - ): + if _is_weighted(a) or _is_weighted(b): return _weighted_stream_eq(a, b, _int_eq) return not isinstance(a, Term) and not isinstance(b, Term) and a == b def _jax_eq(a: Any, b: Any) -> bool: - if ( - isinstance(a, Term) - and _is_monoid_weighted(a.op) - and isinstance(b, Term) - and _is_monoid_weighted(b.op) - ): + if _is_weighted(a) or _is_weighted(b): return _weighted_stream_eq(a, b, _jax_eq) def _leaf_eq(x: Any, y: Any) -> bool: diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 57168295..ff639d4e 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -35,6 +35,7 @@ distributes_over, ) from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.syntax import deffn from effectful.ops.types import NotHandled, Operation, Term from tests._monoid_helpers import ( INT_BACKEND, @@ -682,12 +683,12 @@ def test_reduce_single_weighted_stream(backend): Sum.reduce(body, {a: WS(A, w, Product)}) = Sum.reduce(Product.plus(w(a), body), {a: A}) """ - a, e = define_vars("a", "e", typ=backend.scalar_typ) + a = define_vars("a", typ=backend.scalar_typ) A = define_vars("A", typ=backend.stream_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), e, w(e()))}) + lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), w)}) rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) check_rewrite( @@ -707,7 +708,7 @@ def test_reduce_weighted_factorization(backend): Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` inside ``NormalizeIntp``. """ - a, b, e_a, e_b = define_vars("a", "b", "e_a", "e_b", typ=backend.scalar_typ) + a, b = define_vars("a", "b", typ=backend.scalar_typ) A, B = define_vars("A", "B", typ=backend.stream_typ) f = backend.fresh_op("f", n_args=1, ret="scalar") g = backend.fresh_op("g", n_args=1, ret="scalar") @@ -716,10 +717,7 @@ def test_reduce_weighted_factorization(backend): lhs = Sum.reduce( Product.plus(f(a()), g(b())), - { - a: Product.weighted(A(), e_a, w_a(e_a())), - b: Product.weighted(B(), e_b, w_b(e_b())), - }, + {a: Product.weighted(A(), w_a), b: Product.weighted(B(), w_b)}, ) rhs = Product.plus( Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), @@ -747,13 +745,12 @@ def test_reduce_cartesian_weighted_stream(backend): S, P = define_vars("S", "P", typ=backend.stream_typ) w = backend.fresh_op("w", n_args=1, ret="scalar") - lhs = CartesianProduct.reduce(Product.weighted(S(), e_var, w(e_var())), {p: P()}) + lhs = CartesianProduct.reduce(Product.weighted(S(), w), {p: P()}) row_var = Operation.define(Iterable[backend.scalar_typ], name="row") rhs = Product.weighted( CartesianProduct.reduce(S(), {p: P()}), - row_var, - Product.reduce(w(e_var()), {e_var: row_var()}), + deffn(Product.reduce(w(e_var()), {e_var: row_var()}), row_var), ) check_rewrite( @@ -784,14 +781,14 @@ def test_lift_weighted_cartesian(backend): ) """ a = define_vars("a", typ=backend.scalar_typ) - e, p = define_vars("e", "p", typ=backend.scalar_typ) + p = define_vars("p", typ=backend.scalar_typ) A, S, P = define_vars("A", "S", "P", typ=backend.stream_typ) body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") lhs = Sum.reduce( Product.reduce(body(a()), {a: A()}), - {A: CartesianProduct.reduce(Product.weighted(S(), e, w(e())), {p: P()})}, + {A: CartesianProduct.reduce(Product.weighted(S(), w), {p: P()})}, ) rhs = Product.reduce( Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} From 9937e49fa837317dc0eb987014ae28a9d989959f Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 27 May 2026 11:49:07 -0400 Subject: [PATCH 18/29] fix test --- tests/test_handlers_jax_monoid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index c94e1e79..0d342962 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -130,7 +130,7 @@ def test_jax_weighted_reduce(backend: Backend): body = backend.fresh_op("body", n_args=1, ret="scalar") w = backend.fresh_op("w", n_args=1, ret="scalar") - ws = Product.weighted(X(), x, w(x())) + ws = Product.weighted(X(), w) lhs = Sum.reduce(body(x()), {x: ws}) rhs = jnp.sum( bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 From 9926de523382e568925d51038114cd6b7fcda839 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 27 May 2026 11:55:33 -0400 Subject: [PATCH 19/29] fix test --- tests/test_ops_monoid.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index ff639d4e..83a12ab7 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -825,13 +825,11 @@ def _f(v: int) -> float: raise NotHandled return float(v * v) - a, e = define_vars("a", "e", typ=int) + a = define_vars("a", typ=int) w = Operation.define(_w, name="w") f = Operation.define(_f, name="f") with handler(NormalizeIntp): - result = evaluate( - Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], e, w(e()))}) - ) + result = evaluate(Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], w)})) assert math.isclose(result, 10.0) From b65777b547fba507f3f2fa55e6f5bb771fea4901 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 27 May 2026 13:52:03 -0400 Subject: [PATCH 20/29] resolve type aliases before dispatching --- effectful/ops/monoid.py | 4 ++-- effectful/ops/semantics.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index ee51584c..f42fdf82 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -126,8 +126,8 @@ def reduce[A, B, U: Body]( @Operation.define def weighted[T]( - self, stream: Iterable[T], weight: Callable[[T], W] | Operation[[T], W] - ) -> Iterable[T]: + self, stream: Stream[T], weight: Callable[[T], W] | Operation[[T], W] + ) -> Stream[T]: """A stream paired with a per-element weight. ``var`` is an :class:`Operation` standing for "an element of ``stream``"; ``weight`` is an expression that uses ``var`` and evaluates to the weight of that diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 77b3a231..49653c8d 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -287,6 +287,13 @@ def _evaluate_list_view(expr, **kwargs): def _simple_type(tp: type) -> type: """Convert a type object into a type that can be dispatched on.""" + + def _resolve_aliases(tp: type) -> type: + tp = typing.get_origin(tp) or tp + if isinstance(tp, typing.TypeAliasType): + return _resolve_aliases(tp.__value__) + return tp + if isinstance(tp, typing.TypeVar): tp = ( tp.__bound__ @@ -304,7 +311,7 @@ def _simple_type(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return typing.get_origin(tp) or tp + return _resolve_aliases(tp) def typeof_full[T](term: Expr[T]) -> type[T]: From e2ecbf6dd450b5806b1af2ac31b82a7e2c5e25aa Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 08:39:58 -0400 Subject: [PATCH 21/29] wip --- effectful/ops/monoid.py | 40 ++------- tests/_monoid_helpers.py | 176 +++++++++++++++++---------------------- 2 files changed, 87 insertions(+), 129 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index f42fdf82..36fe3ccb 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -2,6 +2,7 @@ import functools import itertools import operator +import types import typing from collections import Counter, UserDict, defaultdict from collections.abc import Callable, Generator, Iterable, Mapping @@ -33,7 +34,6 @@ NotHandled, Operation, Term, - _CustomSingleDispatchCallable, ) type Stream[T] = Iterable[T] @@ -188,35 +188,6 @@ def __call__(self, s: S, t: T) -> bool: ) -@_CustomSingleDispatchCallable -def stream_element_type(__dispatch: Callable[[type], Callable[[Any], Any]], typ: Any): - """Maps a stream type to the type of its elements. - - Extensible per backend via :meth:`register`. Dispatch is on the type's - *origin* (e.g. ``list`` for ``list[int]``), so a handler registered for a - base class or ABC also covers its subclasses, most-specific match winning. - """ - if typing.get_origin(typ) is Annotated: - typ = typing.get_args(typ)[0] - origin = typing.get_origin(typ) or typ - return __dispatch(origin)(typ) - - -@stream_element_type.register(object) -def _stream_element_type_default(typ: Any) -> Any: - raise NotImplementedError( - f"No element-type handler registered for stream type {typ!r}" - ) - - -# ``Iterable`` covers ``list``, ``tuple``, ``MutableSequence``, ... via origin -# subclass dispatch. -@stream_element_type.register(Iterable) -def _stream_element_type_iterable(typ: Any) -> Any: - """Element type of a single-parameter generic stream like ``list[T]``.""" - return typing.get_args(typ)[0] - - def _is_monoid_plus(op: Operation) -> bool: """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" owner = getattr(op, "__self__", None) @@ -672,7 +643,13 @@ def reduce(self, monoid, body, streams): return fwd() stream_type = typeof_full(s) if isinstance(s, Term) else nested_type(s).value - elem_typ = stream_element_type(stream_type) + if not ( + isinstance(stream_type, types.GenericAlias) + and typing.get_origin(stream_type) == Stream + ): + return fwd() + + elem_typ = typing.get_args(stream_type)[0] elem_op = Operation.define(elem_typ, name="elem") row_op = Operation.define(Iterable[elem_typ], name="row") @@ -681,6 +658,7 @@ def reduce(self, monoid, body, streams): weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op ) joint_stream = CartesianProduct.reduce(s, streams) + return weight_monoid.weighted(joint_stream, joint_weight) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index b052218f..25329c40 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -2,32 +2,39 @@ import typing from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass -from typing import Annotated, Any, get_args, get_origin +from typing import Any import jax from hypothesis import given, settings from hypothesis import strategies as st import effectful.handlers.jax.numpy as _jnp +from effectful.handlers.jax.monoid import array_to_stream from effectful.internals.runtime import interpreter -from effectful.ops.monoid import NormalizeIntp, _is_monoid_weighted +from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted from effectful.ops.semantics import apply, evaluate, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term -_JAX_ARRAY_SHAPE = (2,) - def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: return st.lists( st.integers(min_value=-5, max_value=5), - min_size=_JAX_ARRAY_SHAPE[0], - max_size=_JAX_ARRAY_SHAPE[0], + min_size=2, + max_size=2, + ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + + +def _jax_array_stream_strategy() -> st.SearchStrategy[jax.Array]: + return st.lists( + st.integers(min_value=-5, max_value=5), + min_size=1, + max_size=2, ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) # Shape-preserving unary jax fns: scalar → scalar (counterpart of -# ``_UNARY_NUM_FNS`` for ints). Used for ops declared with ``ret="scalar"``. +# ``_UNARY_NUM_FNS`` for ints). Used for scalar-returning ops. _UNARY_JAX_SCALAR_FNS: list[Callable[[jax.Array], jax.Array]] = [ lambda a: a, lambda a: a + 1, @@ -36,40 +43,18 @@ def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: lambda a: 2 * a, ] -# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` -# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. -# Used for ops declared with ``ret="stream"``. -_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ - lambda a: _jnp.stack([a, a + 1]), - lambda a: _jnp.stack([a, -a]), - lambda a: _jnp.stack([a, a + 1, 2 * a]), +_UNARY_JAX_STREAM_FNS: list[Callable[[jax.Array], Stream[jax.Array]]] = [ + lambda a: array_to_stream(_jnp.stack([a, a + 1])), + lambda a: array_to_stream(_jnp.stack([a, -a])), + lambda a: array_to_stream(_jnp.stack([a, a + 1, 2 * a])), ] -_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ +_BINARY_JAX_SCALAR_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ lambda a, b: a + b, lambda a, b: a - b, lambda a, b: a * b, ] - -def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: - """Strategy for the value an *0-arg* Operation should return.""" - if annotation is int: - return st.integers(min_value=-100, max_value=100) - if annotation is float: - return st.floats(allow_nan=False) - if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) - if annotation is jax.Array: - return _jax_array_value_strategy() - if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): - return st.lists(_jax_array_value_strategy(), max_size=2) - raise NotImplementedError( - f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int], jax.Array, list[jax.Array]" - ) - - _UNARY_NUM_FNS: list[Callable[[int], int]] = [ lambda x: x, lambda x: x + 1, @@ -95,77 +80,77 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: [0, x, x + 1], ] -_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], -] - -def _is_stream(annotation: Any) -> bool: - """True if ``annotation`` carries the ``"stream"`` Annotated marker. +def _int_strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: + """Strategy producing a callable to bind ``op`` on the int backend. - On the JAX backend ``scalar_typ`` and ``stream_typ`` are both ``jax.Array``, - so :meth:`Backend.fresh_op` tags stream returns as - ``Annotated[jax.Array, "stream"]`` to keep them distinguishable here. + A 0-arg op stands for a value (a scalar, or a stream represented as a + ``list[int]``); an n-arg op stands for a scalar- or stream-returning + function. Scalar and stream returns are told apart by the operation's + return annotation (``int`` vs ``Stream[int]``). """ - return get_origin(annotation) is Annotated and "stream" in annotation.__metadata__ - + sig = op.__signature__ + n_args = len(sig.parameters) + ret = sig.return_annotation + + if n_args == 0: + if ret == int: + return st.integers(min_value=-100, max_value=100).map(deffn) + if ret == Stream[int]: + scalars = st.integers(min_value=-100, max_value=100) + return st.lists(scalars, max_size=2).map(deffn) + elif ret == int: + if n_args == 1: + return st.sampled_from(_UNARY_NUM_FNS) + if n_args == 2: + return st.sampled_from(_BINARY_NUM_FNS) + elif ret == Stream[int] and n_args == 1: + return st.sampled_from(_UNARY_LIST_FNS) + raise NotImplementedError( + f"No int strategy for op with return {ret!r} and {n_args} args" + ) -def _strip(annotation: Any) -> Any: - """Strip an ``Annotated`` wrapper to its underlying type.""" - if get_origin(annotation) is Annotated: - return get_args(annotation)[0] - return annotation +def _jax_strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: + """Strategy producing a callable to bind ``op`` on the jax backend. -def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: - """Pick a strategy producing a callable suitable for binding `op` in an - interpretation. Inspects the operation's signature. + The jax counterpart of :func:`_int_strategy_for_op`: scalars are + ``jax.Array``, streams are ``Stream[jax.Array]`` (a stacked 1-D array), + and the return annotation distinguishes the two. """ sig = op.__signature__ - params = list(sig.parameters.values()) - ret_annot = sig.return_annotation - ret = _strip(ret_annot) - ret_is_stream = _is_stream(ret_annot) - param_types = tuple(_strip(p.annotation) for p in params) - - if not params: - return _value_strategy_for(ret).map(deffn) - if ret in (int, float) and param_types == (int,): - return st.sampled_from(_UNARY_NUM_FNS) - if ret in (int, float) and param_types == (int, int): - return st.sampled_from(_BINARY_NUM_FNS) - if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): - return st.sampled_from(_UNARY_LIST_FNS) - if ret is jax.Array and param_types == (jax.Array,): - if ret_is_stream: - return st.sampled_from(_UNARY_JAX_FNS) - return st.sampled_from(_UNARY_JAX_SCALAR_FNS) - if ret is jax.Array and param_types == (jax.Array, jax.Array): - return st.sampled_from(_BINARY_JAX_FNS) - if ( - get_origin(ret) is list - and get_args(ret) == (jax.Array,) - and param_types == (jax.Array,) - ): - return st.sampled_from(_UNARY_JAX_LIST_FNS) + n_args = len(sig.parameters) + ret = sig.return_annotation + + if n_args == 0: + if ret == jax.Array: + return _jax_array_value_strategy().map(deffn) + if ret == Stream[jax.Array]: + return _jax_array_stream_strategy().map( + lambda arr: deffn(array_to_stream(arr)) + ) + elif ret == jax.Array: + if n_args == 1: + return st.sampled_from(_UNARY_JAX_SCALAR_FNS) + if n_args == 2: + return st.sampled_from(_BINARY_JAX_SCALAR_FNS) + elif ret == Stream[jax.Array] and n_args == 1: + return st.sampled_from(_UNARY_JAX_STREAM_FNS) raise NotImplementedError( - f"No callable strategy for free var with return {ret!r}, params {param_types!r}" + f"No jax strategy for op with return {ret!r} and {n_args} args" ) @st.composite def random_interpretation( - draw: st.DrawFn, free_vars: Sequence[Operation] + draw: st.DrawFn, backend: "Backend", free_vars: Sequence[Operation] ) -> Mapping[Operation, Callable[..., Any]]: - """Draw an Interpretation binding every Operation in `case.free_vars` to + """Draw an Interpretation binding every Operation in `free_vars` to a randomly chosen value/callable. Keys are Operation identities. """ intp: dict[Operation, Callable[..., Any]] = {} for op in free_vars: - intp[op] = draw(_strategy_for_op(op)) + intp[op] = draw(backend.strategy_for_op(op)) return intp @@ -297,6 +282,7 @@ class Backend: stream_typ: Any scalar_strategy: st.SearchStrategy[Any] eq: Callable[[Any, Any], bool] + strategy_for_op: Callable[[Operation], st.SearchStrategy[Callable[..., Any]]] def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: """Build a fresh, unhandled Operation whose parameter and return @@ -307,15 +293,7 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation each of type ``scalar_typ``. """ scalar = self.scalar_typ - if ret == "stream": - out = self.stream_typ - # When scalar_typ == stream_typ (e.g. jax backend), tag the return - # with an Annotated marker so ``_strategy_for_op`` can pick the - # right (shape-changing) function family. - if scalar is out: - out = Annotated[out, "stream"] - else: - out = scalar + out = self.stream_typ if ret == "stream" else scalar params = ", ".join(f"_a{i}" for i in range(n_args)) ns: dict[str, Any] = {"NotHandled": NotHandled} exec(f"def _fn({params}):\n raise NotHandled\n", ns) @@ -403,7 +381,7 @@ def check_rewrite( norm = evaluate(lhs) assert syntactic_eq_alpha(norm, rhs) - @given(intp=random_interpretation(free_vars)) + @given(intp=random_interpretation(backend, free_vars)) @settings(max_examples=max_examples, deadline=deadline, report_multiple_bugs=False) def _check_semantics(intp): with handler(normalize), handler(intp): @@ -417,18 +395,20 @@ def _check_semantics(intp): INT_BACKEND = Backend( name="int", scalar_typ=int, - stream_typ=list[int], + stream_typ=Stream[int], scalar_strategy=st.integers(min_value=-100, max_value=100), eq=_int_eq, + strategy_for_op=_int_strategy_for_op, ) JAX_BACKEND = Backend( name="jax", scalar_typ=jax.Array, - stream_typ=jax.Array, + stream_typ=Stream[jax.Array], scalar_strategy=_jax_array_value_strategy(), eq=_jax_eq, + strategy_for_op=_jax_strategy_for_op, ) From 4c03e115b59b7fa68a9a100e8cbf910c790b363c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 08:40:04 -0400 Subject: [PATCH 22/29] wip --- effectful/internals/unification.py | 11 +++++++++++ effectful/ops/monoid.py | 15 +++++++-------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e425bba6..4bc4a005 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -556,6 +556,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: and issubclass(subtyp, typing.get_origin(typ)) ): return subs # implicit expansion to subtyp[Any] + elif isinstance(typ, GenericAlias) and isinstance(subtyp, type): + # Special case for treating arrays as iterables of arrays + try: + import jax + + if typing.get_origin(typ) is collections.abc.Iterable and issubclass( + subtyp, jax.Array + ): + return unify(typing.get_args(typ)[0], jax.Array, subs) + except ImportError: + pass raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 36fe3ccb..52ea5ce6 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -18,6 +18,7 @@ fvsof, fwd, handler, + typeof, typeof_full, ) from effectful.ops.syntax import ( @@ -628,6 +629,11 @@ class ReduceCartesianWeightedStream(ObjectInterpretation): Only fires when ``w`` is independent of the plate vars. """ + @Operation.define + @staticmethod + def _iterable_elem[T](iter: Iterable[T]) -> T: + raise NotHandled + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): if monoid is not CartesianProduct: @@ -642,14 +648,7 @@ def reduce(self, monoid, body, streams): if set(streams.keys()) & fvsof(w): return fwd() - stream_type = typeof_full(s) if isinstance(s, Term) else nested_type(s).value - if not ( - isinstance(stream_type, types.GenericAlias) - and typing.get_origin(stream_type) == Stream - ): - return fwd() - - elem_typ = typing.get_args(stream_type)[0] + elem_typ = typeof(self._iterable_elem(s)) elem_op = Operation.define(elem_typ, name="elem") row_op = Operation.define(Iterable[elem_typ], name="row") From 6a7f7a5d70736476e323841d786a8e537055800f Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 08:40:35 -0400 Subject: [PATCH 23/29] remove typeof_full --- effectful/ops/monoid.py | 3 --- effectful/ops/semantics.py | 31 +++++++++++++------------------ 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 52ea5ce6..76351fa6 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -2,7 +2,6 @@ import functools import itertools import operator -import types import typing from collections import Counter, UserDict, defaultdict from collections.abc import Callable, Generator, Iterable, Mapping @@ -11,7 +10,6 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.internals.unification import nested_type from effectful.ops.semantics import ( coproduct, evaluate, @@ -19,7 +17,6 @@ fwd, handler, typeof, - typeof_full, ) from effectful.ops.syntax import ( ObjectInterpretation, diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 49653c8d..1b3b273d 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -311,26 +311,11 @@ def _resolve_aliases(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return _resolve_aliases(tp) - - -def typeof_full[T](term: Expr[T]) -> type[T]: - """Return the type of an expression, including any type parameters.""" - from effectful.internals.runtime import interpreter - from effectful.internals.unification import Box - - def _apply(op, *args, **kwargs): - return Box(op.__type_rule__(*args, **kwargs)) - - with interpreter({apply: _apply}): - type_or_value = evaluate(term) - if isinstance(type_or_value, Box): - return type_or_value.value - return typing.cast(type[T], type(type_or_value)) + return typing.get_origin(tp) or tp def typeof[T](term: Expr[T]) -> type[T]: - """Return the type of an expression, with type parameters stripped. + """Return the type of an expression. **Example usage**: @@ -351,7 +336,17 @@ def typeof[T](term: Expr[T]) -> type[T]: """ - return _simple_type(typeof_full(term)) + from effectful.internals.runtime import interpreter + from effectful.internals.unification import Box + + def _apply(op, *args, **kwargs): + return Box(op.__type_rule__(*args, **kwargs)) + + with interpreter({apply: _apply}): + type_or_value = evaluate(term) + if isinstance(type_or_value, Box): + return _simple_type(type_or_value.value) + return typing.cast(type[T], type(type_or_value)) def fvsof[S](term: Expr[S]) -> collections.abc.Set[Operation]: From 7060b6384e649cca88bba8bf85f57b3c9b121612 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 08:45:19 -0400 Subject: [PATCH 24/29] wip --- effectful/handlers/jax/monoid.py | 2 -- tests/_monoid_helpers.py | 11 ++++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 5ef06770..634a8cd7 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -18,14 +18,12 @@ Sum, distributes_over, outer_stream, - stream_element_type, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Interpretation, NotHandled, Operation, Term Iterable.register(jax.Array) # required to make jax arrays compatible with Stream[T] -stream_element_type.register(jax.Array, lambda t: jax.Array) def cartesian_prod(x, y): diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 25329c40..213c0fdc 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -9,7 +9,6 @@ from hypothesis import strategies as st import effectful.handlers.jax.numpy as _jnp -from effectful.handlers.jax.monoid import array_to_stream from effectful.internals.runtime import interpreter from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted from effectful.ops.semantics import apply, evaluate, handler @@ -44,9 +43,9 @@ def _jax_array_stream_strategy() -> st.SearchStrategy[jax.Array]: ] _UNARY_JAX_STREAM_FNS: list[Callable[[jax.Array], Stream[jax.Array]]] = [ - lambda a: array_to_stream(_jnp.stack([a, a + 1])), - lambda a: array_to_stream(_jnp.stack([a, -a])), - lambda a: array_to_stream(_jnp.stack([a, a + 1, 2 * a])), + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), ] _BINARY_JAX_SCALAR_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ @@ -126,9 +125,7 @@ def _jax_strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]] if ret == jax.Array: return _jax_array_value_strategy().map(deffn) if ret == Stream[jax.Array]: - return _jax_array_stream_strategy().map( - lambda arr: deffn(array_to_stream(arr)) - ) + return _jax_array_stream_strategy().map(deffn) elif ret == jax.Array: if n_args == 1: return st.sampled_from(_UNARY_JAX_SCALAR_FNS) From b405f29ab38a6fb6b9d6228e35ac298d17c324c3 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 11:00:51 -0400 Subject: [PATCH 25/29] wip --- effectful/ops/semantics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 1b3b273d..acfdf9fd 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -311,7 +311,7 @@ def _resolve_aliases(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return typing.get_origin(tp) or tp + return _resolve_aliases(tp) def typeof[T](term: Expr[T]) -> type[T]: From 8f1936ae725729cbc407207694f62dd7ff39a2ab Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 11:53:36 -0400 Subject: [PATCH 26/29] wip --- effectful/handlers/jax/monoid.py | 12 +++++++++--- effectful/internals/unification.py | 2 +- tests/test_internals_unification.py | 13 +++++++++++++ tests/test_ops_monoid.py | 19 +++++++++++++++++++ 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 634a8cd7..3f6273be 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -23,8 +23,6 @@ from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Interpretation, NotHandled, Operation, Term -Iterable.register(jax.Array) # required to make jax arrays compatible with Stream[T] - def cartesian_prod(x, y): if x.ndim == 1: @@ -115,7 +113,15 @@ def plus(self, *args): if not isinstance(a, jax.Array): return fwd() result = a if result is None else cartesian_prod(result, a) - return result if result is not None else CartesianProduct.identity + if result is None: + return CartesianProduct.identity + # CartesianProduct values are streams of rows. ``cartesian_prod`` + # already lifts 1D inputs to 2D, but a single-array call seeds + # ``result = a`` unchanged — promote so the rank invariant holds for + # every array-path return. + if result.ndim == 1: + result = result[:, None] + return result ARRAY_REDUCTORS = { diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 4bc4a005..2eadaeab 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -556,7 +556,7 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: and issubclass(subtyp, typing.get_origin(typ)) ): return subs # implicit expansion to subtyp[Any] - elif isinstance(typ, GenericAlias) and isinstance(subtyp, type): + elif isinstance(typ, GenericAlias): # Special case for treating arrays as iterables of arrays try: import jax diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 8b93976f..7318a93a 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1900,3 +1900,16 @@ class Info(typing.TypedDict): subs = unify(collections.abc.Mapping, Info) assert subs == {} + + +def test_unify_jax_array_iterable(): + import jax + + subs = unify(collections.abc.Iterable[T], jax.Array) + assert subs == {T: jax.Array} + + +def test_nested_type_jax_array(): + import jax + + assert nested_type(jax.numpy.array([0, 1, 2])) == jax.Array diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 83a12ab7..d08e8921 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -2,11 +2,14 @@ import typing from collections.abc import Iterable +import jax import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st import effectful.handlers.jax.monoid # noqa: F401 +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import jax_getitem from effectful.ops.monoid import ( CartesianProduct, Max, @@ -567,6 +570,22 @@ def test_reduce_independent_4(backend): ) +def test_reduce_cartesian_3(): + i = define_vars("i", typ=jax.Array) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(3)}) + assert value.shape == (2**3, 3) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(1)}) + assert value.shape == (2**1, 1) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(1), {i: jnp.arange(3)}) + assert value.shape == (1**3, 3) + + @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) def test_reduce_lifted_1(outer, inner, backend): a, i = define_vars("a", "i", typ=backend.scalar_typ) From aa87f526c3a9f96ba6134e37cd0c0bff41765dee Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 11:53:47 -0400 Subject: [PATCH 27/29] format --- tests/test_ops_monoid.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index d08e8921..b4fb629f 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -9,7 +9,6 @@ import effectful.handlers.jax.monoid # noqa: F401 import effectful.handlers.jax.numpy as jnp -from effectful.handlers.jax import jax_getitem from effectful.ops.monoid import ( CartesianProduct, Max, From 3d681a0c0fc183e73a1062e161ed923d242ca9f4 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 28 May 2026 13:24:04 -0400 Subject: [PATCH 28/29] refactor test harness --- tests/_monoid_helpers.py | 445 +++++++++++++------------- tests/test_handlers_jax_monoid.py | 160 ++++------ tests/test_ops_monoid.py | 510 ++++++++++++------------------ 3 files changed, 475 insertions(+), 640 deletions(-) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 213c0fdc..f8089bec 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,162 +1,23 @@ +import builtins import itertools import typing -from collections.abc import Callable, Mapping, Sequence -from dataclasses import dataclass -from typing import Any +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping +from typing import Any, Literal, overload import jax from hypothesis import given, settings from hypothesis import strategies as st +from hypothesis.strategies import SearchStrategy import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted -from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term -def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: - return st.lists( - st.integers(min_value=-5, max_value=5), - min_size=2, - max_size=2, - ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) - - -def _jax_array_stream_strategy() -> st.SearchStrategy[jax.Array]: - return st.lists( - st.integers(min_value=-5, max_value=5), - min_size=1, - max_size=2, - ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) - - -# Shape-preserving unary jax fns: scalar → scalar (counterpart of -# ``_UNARY_NUM_FNS`` for ints). Used for scalar-returning ops. -_UNARY_JAX_SCALAR_FNS: list[Callable[[jax.Array], jax.Array]] = [ - lambda a: a, - lambda a: a + 1, - lambda a: a - 1, - lambda a: -a, - lambda a: 2 * a, -] - -_UNARY_JAX_STREAM_FNS: list[Callable[[jax.Array], Stream[jax.Array]]] = [ - lambda a: _jnp.stack([a, a + 1]), - lambda a: _jnp.stack([a, -a]), - lambda a: _jnp.stack([a, a + 1, 2 * a]), -] - -_BINARY_JAX_SCALAR_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ - lambda a, b: a + b, - lambda a, b: a - b, - lambda a, b: a * b, -] - -_UNARY_NUM_FNS: list[Callable[[int], int]] = [ - lambda x: x, - lambda x: x + 1, - lambda x: x - 1, - lambda x: -x, - lambda x: 2 * x, - lambda x: 3 * x + 1, -] - -_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ - lambda x, y: x + y, - lambda x, y: x - y, - lambda x, y: x * y, - lambda x, y: x + 2 * y, - lambda x, y: 2 * x - y, -] - -_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], - lambda x: [0, x, x + 1], -] - - -def _int_strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: - """Strategy producing a callable to bind ``op`` on the int backend. - - A 0-arg op stands for a value (a scalar, or a stream represented as a - ``list[int]``); an n-arg op stands for a scalar- or stream-returning - function. Scalar and stream returns are told apart by the operation's - return annotation (``int`` vs ``Stream[int]``). - """ - sig = op.__signature__ - n_args = len(sig.parameters) - ret = sig.return_annotation - - if n_args == 0: - if ret == int: - return st.integers(min_value=-100, max_value=100).map(deffn) - if ret == Stream[int]: - scalars = st.integers(min_value=-100, max_value=100) - return st.lists(scalars, max_size=2).map(deffn) - elif ret == int: - if n_args == 1: - return st.sampled_from(_UNARY_NUM_FNS) - if n_args == 2: - return st.sampled_from(_BINARY_NUM_FNS) - elif ret == Stream[int] and n_args == 1: - return st.sampled_from(_UNARY_LIST_FNS) - raise NotImplementedError( - f"No int strategy for op with return {ret!r} and {n_args} args" - ) - - -def _jax_strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: - """Strategy producing a callable to bind ``op`` on the jax backend. - - The jax counterpart of :func:`_int_strategy_for_op`: scalars are - ``jax.Array``, streams are ``Stream[jax.Array]`` (a stacked 1-D array), - and the return annotation distinguishes the two. - """ - sig = op.__signature__ - n_args = len(sig.parameters) - ret = sig.return_annotation - - if n_args == 0: - if ret == jax.Array: - return _jax_array_value_strategy().map(deffn) - if ret == Stream[jax.Array]: - return _jax_array_stream_strategy().map(deffn) - elif ret == jax.Array: - if n_args == 1: - return st.sampled_from(_UNARY_JAX_SCALAR_FNS) - if n_args == 2: - return st.sampled_from(_BINARY_JAX_SCALAR_FNS) - elif ret == Stream[jax.Array] and n_args == 1: - return st.sampled_from(_UNARY_JAX_STREAM_FNS) - raise NotImplementedError( - f"No jax strategy for op with return {ret!r} and {n_args} args" - ) - - -@st.composite -def random_interpretation( - draw: st.DrawFn, backend: "Backend", free_vars: Sequence[Operation] -) -> Mapping[Operation, Callable[..., Any]]: - """Draw an Interpretation binding every Operation in `free_vars` to - a randomly chosen value/callable. Keys are Operation identities. - """ - intp: dict[Operation, Callable[..., Any]] = {} - for op in free_vars: - intp[op] = draw(backend.strategy_for_op(op)) - return intp - - -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - - def syntactic_eq_alpha(x, y) -> bool: """Alpha-equivalence-respecting variant of ``syntactic_eq``. @@ -266,8 +127,7 @@ def _apply_canonical(op, *args, **kwargs) -> Term: return evaluate(expr) -@dataclass(frozen=True) -class Backend: +class Backend(ABC): """A value-domain spec used to share monoid tests across int and jax.Array backends. Provides the concrete value type, the hypothesis strategy for drawing scalars in property tests, and an equality predicate that works @@ -277,11 +137,29 @@ class Backend: name: str scalar_typ: Any stream_typ: Any - scalar_strategy: st.SearchStrategy[Any] - eq: Callable[[Any, Any], bool] - strategy_for_op: Callable[[Operation], st.SearchStrategy[Callable[..., Any]]] - - def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + strategy_for_op: dict[Operation, st.SearchStrategy[Callable[..., Any]]] + + def __init__(self): + self.strategy_for_op = {} + + @abstractmethod + def eq(self, a: Any, b: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + raise NotImplementedError + + def _fresh_op( + self, + name: str, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> Operation: """Build a fresh, unhandled Operation whose parameter and return annotations are derived from this backend. @@ -291,15 +169,71 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation """ scalar = self.scalar_typ out = self.stream_typ if ret == "stream" else scalar - params = ", ".join(f"_a{i}" for i in range(n_args)) + params = ", ".join(f"_a{i}" for i in range(len(arg_types))) ns: dict[str, Any] = {"NotHandled": NotHandled} exec(f"def _fn({params}):\n raise NotHandled\n", ns) fn = ns["_fn"] fn.__annotations__ = { - **{f"_a{i}": scalar for i in range(n_args)}, + **{f"_a{i}": t for i, t in enumerate(arg_types)}, "return": out, } - return Operation.define(fn, name=name) + op = Operation.define(fn, name=name) + self.strategy_for_op[op] = self.strategy(arg_types, ret) + return op + + @overload + def define_vars(self, name: str, /, **kwargs) -> Operation: ... + + @overload + def define_vars( + self, n1: str, n2: str, /, *names: str, **kwargs + ) -> tuple[Operation, ...]: ... + + def define_vars(self, *names: str, **kwargs) -> Operation | tuple[Operation, ...]: # type: ignore[misc] + if len(names) == 1: + return self._fresh_op(names[0], **kwargs) + return tuple(self._fresh_op(n, **kwargs) for n in names) + + def check_rewrite( + self, + lhs, + rhs, + rule, + *, + max_examples: int = 25, + deadline=None, + normalize=NormalizeIntp, + ) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + fvs = fvsof(lhs) | fvsof(rhs) + + @st.composite + def random_interpretation( + draw: st.DrawFn, + ) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op, strategy in self.strategy_for_op.items(): + if op in fvs: + intp[op] = draw(strategy) + return intp + + @given(intp=random_interpretation()) + @settings( + max_examples=max_examples, deadline=deadline, report_multiple_bugs=False + ) + def _check_semantics(intp): + with handler(normalize), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert self.eq(lhs_val, rhs_val) + + _check_semantics() def _is_weighted(x: Any) -> bool: @@ -343,78 +277,137 @@ def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: return True -def _int_eq(a: Any, b: Any) -> bool: - if _is_weighted(a) or _is_weighted(b): - return _weighted_stream_eq(a, b, _int_eq) - return not isinstance(a, Term) and not isinstance(b, Term) and a == b +class IntBackend(Backend): + name = "int" + scalar_typ = int + stream_typ = Stream[int] + + _unary_num_fns: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, + ] + + _binary_num_fns: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, + ] + + _unary_list_fns: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + match arg_types, ret: + case (), "scalar": + return st.integers(min_value=-100, max_value=100).map(deffn) + case (), "stream": + scalars = st.integers(min_value=-100, max_value=100) + return st.lists(scalars, max_size=2).map(deffn) + case (builtins.int,), "scalar": + return st.sampled_from(self._unary_num_fns) + case (builtins.int, builtins.int), "scalar": + return st.sampled_from(self._binary_num_fns) + case (builtins.int,), "stream": + return st.sampled_from(self._unary_list_fns) + raise NotImplementedError( + f"No int strategy for op with return {ret!r} and {arg_types} args" + ) + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +class JaxBackend(Backend): + name = "jax" + scalar_typ = jax.Array + stream_typ = jax.Array + + _unary_jax_scalar_fns: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: a, + lambda a: a + 1, + lambda a: a - 1, + lambda a: -a, + lambda a: 2 * a, + ] + + _unary_jax_stream_fns: list[Callable[[jax.Array], Stream[jax.Array]]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), + ] + + _binary_jax_scalar_fns: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> st.SearchStrategy[Callable]: + match arg_types, ret: + case (), "scalar": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=2, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (), "stream": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=1, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (jax.Array,), "scalar": + return st.sampled_from(self._unary_jax_scalar_fns) + case (jax.Array, jax.Array), "scalar": + return st.sampled_from(self._binary_jax_scalar_fns) + case (jax.Array,), "stream": + return st.sampled_from(self._unary_jax_stream_fns) + + raise NotImplementedError( + f"No jax strategy for op with return {ret!r} and {arg_types} args" + ) -def _jax_eq(a: Any, b: Any) -> bool: - if _is_weighted(a) or _is_weighted(b): - return _weighted_stream_eq(a, b, _jax_eq) + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) - def _leaf_eq(x: Any, y: Any) -> bool: - return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) - try: - leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) - except (ValueError, TypeError): - return False - return all(leaves) - - -def check_rewrite( - lhs, - rhs, - rule, - *, - backend: Backend, - free_vars=[], - max_examples: int = 25, - deadline=None, - normalize=NormalizeIntp, -) -> None: - with handler(rule): - norm = evaluate(lhs) - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(backend, free_vars)) - @settings(max_examples=max_examples, deadline=deadline, report_multiple_bugs=False) - def _check_semantics(intp): - with handler(normalize), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - -INT_BACKEND = Backend( - name="int", - scalar_typ=int, - stream_typ=Stream[int], - scalar_strategy=st.integers(min_value=-100, max_value=100), - eq=_int_eq, - strategy_for_op=_int_strategy_for_op, -) - - -JAX_BACKEND = Backend( - name="jax", - scalar_typ=jax.Array, - stream_typ=Stream[jax.Array], - scalar_strategy=_jax_array_value_strategy(), - eq=_jax_eq, - strategy_for_op=_jax_strategy_for_op, -) - - -__all__ = [ - "Backend", - "INT_BACKEND", - "JAX_BACKEND", - "random_interpretation", - "define_vars", - "syntactic_eq_alpha", - "check_rewrite", -] + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +__all__ = ["Backend", "IntBackend", "JaxBackend", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 0d342962..18df8401 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -27,7 +27,7 @@ ) from effectful.ops.semantics import coproduct, handler from effectful.ops.types import Interpretation -from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars +from tests._monoid_helpers import JaxBackend MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), @@ -39,28 +39,27 @@ @pytest.fixture -def backend() -> Backend: - return JAX_BACKEND +def backend() -> JaxBackend: + return JaxBackend() @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_1(monoid, reductor, backend: Backend): - (x, k) = define_vars("x", "k", typ=jax.Array) - X = define_vars("X", typ=backend.stream_typ) +def test_reduce_array_1(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(x(), {x: X()}) rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_2(monoid, reductor, backend: Backend): - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_array_2(monoid, reductor, backend: JaxBackend): + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + (X, Y) = backend.define_vars("X", "Y", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) rhs = reductor( @@ -73,25 +72,20 @@ def test_reduce_array_2(monoid, reductor, backend: Backend): ), axis=0, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_3(monoid, reductor, backend: Backend): +def test_reduce_array_3(monoid, reductor, backend: JaxBackend): """Stream `y` is `g(x())` — depends on the bound element of X. The reducer must inline ``g`` along the same named dim used to unbind `x`.""" - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + X = backend.define_vars("X", ret="stream") - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + g = backend.define_vars("g", arg_types=[backend.scalar_typ], ret="stream") lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) rhs = reductor( @@ -107,17 +101,10 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): ), axis=0, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) -def test_jax_weighted_reduce(backend: Backend): +def test_jax_weighted_reduce(backend: JaxBackend): """Sum over a single stream with ``Product`` weights lowers to ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. @@ -125,18 +112,17 @@ def test_jax_weighted_reduce(backend: Backend): so existing handlers need no changes to support weighted streams. """ - (x, k) = define_vars("x", "k", typ=jax.Array) - X = define_vars("X", typ=backend.stream_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - w = backend.fresh_op("w", n_args=1, ret="scalar") + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") + body = backend.define_vars("body", arg_types=[backend.scalar_typ], ret="scalar") + w = backend.define_vars("w", arg_types=[backend.scalar_typ], ret="scalar") ws = Product.weighted(X(), w) lhs = Sum.reduce(body(x()), {x: ws}) rhs = jnp.sum( bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 ) - - check_rewrite( + backend.check_rewrite( lhs=lhs, rhs=rhs, rule=functools.reduce( @@ -146,8 +132,6 @@ def test_jax_weighted_reduce(backend: Backend): [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()], ), ), - backend=backend, - free_vars=[x, k, X, body, w], ) @@ -158,62 +142,51 @@ def test_jax_weighted_reduce(backend: Backend): @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_empty(monoid, reductor, backend: Backend): +def test_reduce_delta_empty(monoid, reductor, backend: JaxBackend): """An empty-index delta unwraps to its body. reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) """ - x = define_vars("x", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(delta((), x()), {x: X()}) rhs = monoid.reduce(x(), {x: X()}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[x, X], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_one(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): """One R1 step: peel the final preserved index off a delta. reduce(M, {y: Y()}, delta((y(),), f(y()))) ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) """ - (y, k) = define_vars("y", "k", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") + (y, k) = backend.define_vars("y", "k", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") # We use a concrete range here instead of an abstract one, because # unbind_dims is undefined on empty arrays (and the rewrite produces a # different rhs in this case) lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[y, k, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_preserves_others( + monoid, reductor, backend: JaxBackend +): """R1 peels only the final index. Streams not matching the peeled index op stay untouched, as do earlier entries in the index tuple. reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) """ - (x, y, k) = define_vars("x", "y", "k", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") + (x, y, k) = backend.define_vars("x", "y", "k", ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) rhs = monoid.reduce( @@ -225,27 +198,22 @@ def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Ba ), {}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask(monoid, reductor, backend: JaxBackend): """A dependent range stream gets rewritten to the referent's bbox stream, with the original constraint folded into the body as a where-guard. reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) body = f(u(), v()) @@ -254,18 +222,11 @@ def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): jnp.where(v() < u(), body, monoid.identity), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: JaxBackend): """When the body is a delta term, R4 folds the constraint into the delta's weight while leaving its index tuple untouched. @@ -273,9 +234,11 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe ≡ reduce(M, {u: range(N), v: range(N)}, delta((u(), v()), where(v() < u(), w, M.identity))) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) weight = f(u(), v()) idx = (u(), v()) @@ -285,17 +248,10 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe delta(idx, jnp.where(v() < u(), weight, monoid.identity)), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) -def test_reduce_matmul(): +def test_reduce_matmul(backend: JaxBackend): key = jax.random.PRNGKey(0) # Define dimensions B, I, J, K = 2, 3, 4, 5 @@ -303,7 +259,7 @@ def test_reduce_matmul(): # Create sample matrices X = random.normal(key, (B, I, J)) Y = random.normal(key, (B, J, K)) - (b, i, j, k) = define_vars("b", "i", "j", "k", typ=jax.Array) + (b, i, j, k) = backend.define_vars("b", "i", "j", "k", ret="scalar") with handler(NormalizeIntp): actual = Sum.reduce( diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index b4fb629f..fcd72f06 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -2,7 +2,6 @@ import typing from collections.abc import Iterable -import jax import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st @@ -39,19 +38,12 @@ from effectful.ops.semantics import coproduct, evaluate, fvsof, handler from effectful.ops.syntax import deffn from effectful.ops.types import NotHandled, Operation, Term -from tests._monoid_helpers import ( - INT_BACKEND, - JAX_BACKEND, - Backend, - check_rewrite, - define_vars, - syntactic_eq_alpha, -) +from tests._monoid_helpers import Backend, IntBackend, JaxBackend, syntactic_eq_alpha -@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +@pytest.fixture(params=[IntBackend, JaxBackend], ids=["int", "jax"]) def backend(request) -> Backend: - return request.param + return request.param() ALL_MONOIDS = [ @@ -97,10 +89,10 @@ def backend(request) -> Backend: deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_associativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) - c = data.draw(backend.scalar_strategy) +def test_associativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() + c = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): left = monoid.plus(monoid.plus(a, b), c) right = monoid.plus(a, monoid.plus(b, c)) @@ -114,8 +106,8 @@ def test_associativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_identity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_identity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.identity, a), a) assert backend.eq(monoid.plus(a, monoid.identity), a) @@ -128,9 +120,9 @@ def test_identity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_commutativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) +def test_commutativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @@ -142,8 +134,8 @@ def test_commutativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_idempotence(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_idempotence(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, a), a) @@ -155,102 +147,86 @@ def test_idempotence(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_zero_absorbs(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_zero_absorbs(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid, backend): - check_rewrite( - lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend - ) +def test_plus_empty(monoid, backend: Backend): + backend.check_rewrite(lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) - check_rewrite( - lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] - ) +def test_plus_single(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + backend.check_rewrite(lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_right(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(x(), monoid.identity) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_left(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(monoid.identity, x()) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_right(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_left(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - check_rewrite( +def test_plus_sequence(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + backend.check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), rule=MonoidOverSequence(), - backend=backend, - free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_mapping(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[a, b, c, d], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) -def test_plus_distributes(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) rhs = Product.plus( Sum.plus( @@ -260,13 +236,11 @@ def test_plus_distributes(backend): Product.plus(b(), d()), ) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_constant(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_constant(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -277,13 +251,11 @@ def test_plus_distributes_constant(backend): Product.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_multiple(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_multiple(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -304,238 +276,195 @@ def test_plus_distributes_multiple(backend): Sum.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid, backend): +def test_plus_idempotent_consecutive(monoid, backend: Backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), a(), b()) - return check_rewrite( - lhs=lhs, - rhs=monoid.plus(a(), b()), - rule=PlusConsecutiveDups(), - backend=backend, - free_vars=[a, b], + return backend.check_rewrite( + lhs=lhs, rhs=monoid.plus(a(), b()), rule=PlusConsecutiveDups() ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid, backend): +def test_plus_idempotent_non_consecutive(monoid, backend: Backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative PlusDups.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", [Min, Max]) -def test_plus_commutative_idempotent_long(monoid, backend): +def test_plus_commutative_idempotent_long(monoid, backend: Backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_plus_zero(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) rhs = monoid.zero - check_rewrite( - lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) - check_rewrite( - lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule=PlusZero()) + backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule=PlusZero()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_1(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) +def test_partial_1(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_3(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_3(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_4(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=1, ret="stream") +def test_partial_4(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - X, Y = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, Y, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_mapping(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_reduce_no_streams(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") + lhs = monoid.reduce(a(), {}) rhs = monoid.identity - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceNoStreams()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_reduce(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFusion()) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_plus(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSplit()) -def test_reduce_independent_1(backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_independent_1(backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_2(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_2(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_3_negative(backend): +def test_reduce_independent_3_negative(backend: Backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, C = define_vars("A", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, C = backend.define_vars("A", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + g = backend.define_vars("g", arg_types=(backend.scalar_typ,), ret="stream") with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( @@ -549,10 +478,12 @@ def test_reduce_independent_3_negative(backend): assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_4(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( @@ -560,17 +491,12 @@ def test_reduce_independent_4(backend): Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) def test_reduce_cartesian_3(): - i = define_vars("i", typ=jax.Array) + backend = JaxBackend() + i = backend.define_vars("i", ret="scalar") with handler(NormalizeIntp): value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(3)}) @@ -586,29 +512,23 @@ def test_reduce_cartesian_3(): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner, backend): - a, i = define_vars("a", "i", typ=backend.scalar_typ) - A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_1(outer, inner, backend: Backend): + a, i = backend.define_vars("a", "i", ret="scalar") + A, N, A_domain = backend.define_vars("A", "N", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) - - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, A_domain, f], - ) + rhs = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -620,8 +540,9 @@ def test_reduce_cartesian_1(): def test_reduce_cartesian_2(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -633,46 +554,41 @@ def test_reduce_cartesian_2(): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner, backend): - a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_multi_index(outer, inner, backend: Backend): + a, i, j = backend.define_vars("a", "i", "j", ret="scalar") + A, N, M, A_domain = backend.define_vars("A", "N", "M", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) - term2 = inner.reduce( - outer.reduce(inner.plus(f(a())), {a: A_domain()}), - {i: N(), j: M()}, - ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, M, A_domain, f], + rhs = inner.reduce( + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()} ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner, backend): +def test_reduce_lifted_2(outer, inner, backend: Backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) - A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) - A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") - f1 = backend.fresh_op("f1", n_args=2, ret="scalar") - f2 = backend.fresh_op("f2", n_args=2, ret="scalar") + a, i, s, t = backend.define_vars("a", "i", "s", "t", ret="scalar") + A, N, T = backend.define_vars("A", "N", "T", ret="stream") + A_domain = backend.define_vars( + "A_domain", arg_types=(backend.scalar_typ,), ret="stream" + ) + f1, f2 = backend.define_vars( + "f1", "f2", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, ) - - term2 = outer.reduce( + rhs = outer.reduce( inner.reduce( outer.reduce( inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} @@ -681,14 +597,7 @@ def test_reduce_lifted_2(outer, inner, backend): ), {t: T()}, ) - - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) # --------------------------------------------------------------------------- @@ -696,29 +605,23 @@ def test_reduce_lifted_2(outer, inner, backend): # --------------------------------------------------------------------------- -def test_reduce_single_weighted_stream(backend): +def test_reduce_single_weighted_stream(backend: Backend): """Single weighted stream desugars: Sum.reduce(body, {a: WS(A, w, Product)}) = Sum.reduce(Product.plus(w(a), body), {a: A}) """ - a = define_vars("a", typ=backend.scalar_typ) - A = define_vars("A", typ=backend.stream_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - w = backend.fresh_op("w", n_args=1, ret="scalar") + a = backend.define_vars("a", ret="scalar") + A = backend.define_vars("A", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), w)}) rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceWeightedStream()) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceWeightedStream(), - backend=backend, - free_vars=[A, body, w], - ) - -def test_reduce_weighted_factorization(backend): +def test_reduce_weighted_factorization(backend: Backend): """Two independent weighted streams under Sum with Product weights factor: Sum.reduce(f(a)*g(b), {a: Product.weighted(A, a, w_a), b: Product.weighted(B, b, w_b)}) = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) @@ -726,12 +629,11 @@ def test_reduce_weighted_factorization(backend): Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` inside ``NormalizeIntp``. """ - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="scalar") - w_a = backend.fresh_op("w_a", n_args=1, ret="scalar") - w_b = backend.fresh_op("w_b", n_args=1, ret="scalar") + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f, g, w_a, w_b = backend.define_vars( + "f", "g", "w_a", "w_b", arg_types=(backend.scalar_typ,), ret="scalar" + ) lhs = Sum.reduce( Product.plus(f(a()), g(b())), @@ -741,17 +643,12 @@ def test_reduce_weighted_factorization(backend): Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=coproduct(ReduceWeightedStream(), ReduceFactorization()), - backend=backend, - free_vars=[A, B, f, g, w_a, w_b], + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceWeightedStream(), ReduceFactorization()) ) -def test_reduce_cartesian_weighted_stream(backend): +def test_reduce_cartesian_weighted_stream(backend: Backend): """``CartesianProduct.reduce`` over a ``WeightedStream`` body whose weight is independent of the plate var rewrites to a single joint ``WeightedStream``: @@ -759,28 +656,20 @@ def test_reduce_cartesian_weighted_stream(backend): CartesianProduct.reduce(M.weighted(s, e, w(e)), {p: P}) = M.weighted(CartesianProduct.reduce(s, {p: P}), row, M.reduce(w(e), {e: row()})) """ - p, e_var = define_vars("p", "e_var", typ=backend.scalar_typ) - S, P = define_vars("S", "P", typ=backend.stream_typ) - w = backend.fresh_op("w", n_args=1, ret="scalar") + p, e_var = backend.define_vars("p", "e_var", ret="scalar") + S, P = backend.define_vars("S", "P", ret="stream") + w = backend.define_vars("w", arg_types=(backend.scalar_typ,), ret="scalar") lhs = CartesianProduct.reduce(Product.weighted(S(), w), {p: P()}) - - row_var = Operation.define(Iterable[backend.scalar_typ], name="row") + row_var = Operation.define(Iterable[backend.scalar_typ], name="row") # type: ignore[name-defined] rhs = Product.weighted( CartesianProduct.reduce(S(), {p: P()}), deffn(Product.reduce(w(e_var()), {e_var: row_var()}), row_var), ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceCartesianWeightedStream(), - backend=backend, - free_vars=[S, P, w], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceCartesianWeightedStream()) -def test_lift_weighted_cartesian(backend): +def test_lift_weighted_cartesian(backend: Backend): """Compose ``ReduceCartesianWeightedStream`` + ``ReduceWeightedStream`` + ``ReduceDistributeCartesianProduct`` on a Sum-of-Product-of-weighted shape: @@ -798,11 +687,11 @@ def test_lift_weighted_cartesian(backend): {p: P}, ) """ - a = define_vars("a", typ=backend.scalar_typ) - p = define_vars("p", typ=backend.scalar_typ) - A, S, P = define_vars("A", "S", "P", typ=backend.stream_typ) - body = backend.fresh_op("body", n_args=1, ret="scalar") - w = backend.fresh_op("w", n_args=1, ret="scalar") + a, p = backend.define_vars("a", "p", ret="scalar") + A, S, P = backend.define_vars("A", "S", "P", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) lhs = Sum.reduce( Product.reduce(body(a()), {a: A()}), @@ -811,16 +700,13 @@ def test_lift_weighted_cartesian(backend): rhs = Product.reduce( Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} ) - - check_rewrite( + backend.check_rewrite( lhs=lhs, rhs=rhs, rule=coproduct( coproduct(ReduceWeightedStream(), ReduceCartesianWeightedStream()), ReduceDistributeCartesianProduct(), ), - backend=backend, - free_vars=[S, P, body, w], ) @@ -843,7 +729,7 @@ def _f(v: int) -> float: raise NotHandled return float(v * v) - a = define_vars("a", typ=int) + a = Operation.define(int, name="a") w = Operation.define(_w, name="w") f = Operation.define(_f, name="f") From 81ccf99f99256ad558a1d878e27567af53aa5e1d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 29 May 2026 18:21:46 -0400 Subject: [PATCH 29/29] drop unused test --- tests/test_internals_unification.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 7318a93a..8abfdc5a 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1907,9 +1907,3 @@ def test_unify_jax_array_iterable(): subs = unify(collections.abc.Iterable[T], jax.Array) assert subs == {T: jax.Array} - - -def test_nested_type_jax_array(): - import jax - - assert nested_type(jax.numpy.array([0, 1, 2])) == jax.Array