diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index c88fe934..05a5390e 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -14,10 +14,17 @@ unbind_dims, ) from effectful.internals.tensor_utils import _desugar_tensor_index +from effectful.internals.unification import Box, nested_type from effectful.ops.syntax import defdata from effectful.ops.types import Expr, NotHandled, Operation, Term +@nested_type.register(jax.Array) +@nested_type.register(jax._src.core.Tracer) +def _(value): + return Box(jax.Array) + + class _IndexUpdateHelper: """Helper class to implement array-style .at[index].set() updates for effectful arrays.""" diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 42d7866e..3f6273be 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -16,6 +16,7 @@ Product, Streams, Sum, + distributes_over, outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof @@ -38,6 +39,10 @@ def cartesian_prod(x, y): LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) +# ``Sum`` in log space is multiplication, which distributes over ``LogSumExp``: +# a + logsumexp(b, c) = logsumexp(a + b, a + c) +distributes_over.register(Sum, LogSumExp) + def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete @@ -108,7 +113,15 @@ def plus(self, *args): if not isinstance(a, jax.Array): return fwd() result = a if result is None else cartesian_prod(result, a) - return result if result is not None else CartesianProduct.identity + if result is None: + return CartesianProduct.identity + # CartesianProduct values are streams of rows. ``cartesian_prod`` + # already lifts 1D inputs to 2D, but a single-array call seeds + # ``result = a`` unchanged — promote so the rank invariant holds for + # every array-path return. + if result.ndim == 1: + result = result[:, None] + return result ARRAY_REDUCTORS = { diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e425bba6..2eadaeab 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -556,6 +556,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: and issubclass(subtyp, typing.get_origin(typ)) ): return subs # implicit expansion to subtyp[Any] + elif isinstance(typ, GenericAlias): + # Special case for treating arrays as iterables of arrays + try: + import jax + + if typing.get_origin(typ) is collections.abc.Iterable and issubclass( + subtyp, jax.Array + ): + return unify(typing.get_args(typ)[0], jax.Array, subs) + except ImportError: + pass raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index c9231510..76351fa6 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,14 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.semantics import ( + coproduct, + evaluate, + fvsof, + fwd, + handler, + typeof, +) from effectful.ops.syntax import ( ObjectInterpretation, Scoped, @@ -19,11 +26,17 @@ syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, +) + +type Stream[T] = Iterable[T] -# Note: The streams value type should be something like Iterable[T], but some of -# our target stream types (e.g. jax.Array) are not subtypes of Iterable -type Streams[T] = Mapping[Operation[[], T], Any] +type Streams = Mapping[Operation[[], Any], Stream[Any]] type Body[T] = ( Iterable[T] @@ -34,9 +47,7 @@ ) -def outer_stream( - streams: Streams, -) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: +def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]]: """Returns the streams that can be ordered outermost in the loop nest as well as the remaining streams in the nest. @@ -51,13 +62,13 @@ def outer_stream( ) -class Monoid[T]: +class Monoid[W]: """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" _name: str - identity: T + identity: W - def __init__(self, identity: T, name: str): + def __init__(self, identity: W, name: str): self._name = name self.identity = identity @@ -111,6 +122,18 @@ def reduce[A, B, U: Body]( return self.plus(*new_reduces) raise NotHandled + @Operation.define + def weighted[T]( + self, stream: Stream[T], weight: Callable[[T], W] | Operation[[T], W] + ) -> Stream[T]: + """A stream paired with a per-element weight. ``var`` is an + :class:`Operation` standing for "an element of ``stream``"; ``weight`` + is an expression that uses ``var`` and evaluates to the weight of that + element. + + """ + raise NotHandled + class MonoidWithZero[T](Monoid[T]): zero: T @@ -175,6 +198,12 @@ def _is_monoid_reduce(op: Operation) -> bool: return isinstance(owner, Monoid) and op is owner.reduce +def _is_monoid_weighted(op: Operation) -> bool: + """True if ``op`` is the ``weighted`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.weighted + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -557,6 +586,78 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() +class ReduceWeightedStream(ObjectInterpretation): + """reduce(M, body, {x: WM.weighted(s, v, w), ...}) = reduce(M, WM.plus(w[v:=x()], body), {x: s, ...}) + + requires distributes_over(WM, M). + + The substitution ``v -> x`` is done by beta-reducing ``deffn(w, v)`` on + ``x()`` — symbolic, no Python dispatch on the weight expression. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + for k, v in streams.items(): + if isinstance(v, Term) and _is_monoid_weighted(v.op): + v_stream, v_weight = v.args + v_monoid = v.op.__self__ + if not distributes_over(v_monoid, monoid): + continue + w_at_k = v_weight(k()) + weighted_body = v_monoid.plus(w_at_k, body) + new_streams = {**streams, k: v_stream} + return monoid.reduce(weighted_body, new_streams) + return fwd() + + +class ReduceCartesianWeightedStream(ObjectInterpretation): + """``CartesianProduct.reduce`` over a :func:`weighted` body whose + ``weight`` is independent of the plate (product-index) streams:: + + CartesianProduct.reduce(M.weighted(s, w), plates) + = M.weighted( + CartesianProduct.reduce(s, plates), + deffn(M.reduce(w, {e: row()}), row), + ) + + Reuses ``body``'s element binder ``e`` (already typed by construction); + introduces a fresh ``row`` binder typed as ``Iterable[elem_type]``. + + Only fires when ``w`` is independent of the plate vars. + """ + + @Operation.define + @staticmethod + def _iterable_elem[T](iter: Iterable[T]) -> T: + raise NotHandled + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid is not CartesianProduct: + return fwd() + if not (isinstance(body, Term) and _is_monoid_weighted(body.op)): + return fwd() + + s, w = body.args + if not isinstance(s, Term) and len(s) == 0: + return CartesianProduct.reduce([], streams) + + if set(streams.keys()) & fvsof(w): + return fwd() + + elem_typ = typeof(self._iterable_elem(s)) + elem_op = Operation.define(elem_typ, name="elem") + row_op = Operation.define(Iterable[elem_typ], name="row") + + weight_monoid = body.op.__self__ + joint_weight = deffn( + weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op + ) + joint_stream = CartesianProduct.reduce(s, streams) + + return weight_monoid.weighted(joint_stream, joint_weight) + + class MonoidOverCallable(ObjectInterpretation): """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" @@ -751,6 +852,8 @@ def extend(self, *intps: Interpretation) -> typing.Self: ReduceSplit(), ReduceFactorization(), ReduceDistributeCartesianProduct(), + ReduceWeightedStream(), + ReduceCartesianWeightedStream(), PlusEmpty(), PlusSingle(), PlusIdentity(), diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 8fd62bcd..acfdf9fd 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -287,6 +287,13 @@ def _evaluate_list_view(expr, **kwargs): def _simple_type(tp: type) -> type: """Convert a type object into a type that can be dispatched on.""" + + def _resolve_aliases(tp: type) -> type: + tp = typing.get_origin(tp) or tp + if isinstance(tp, typing.TypeAliasType): + return _resolve_aliases(tp.__value__) + return tp + if isinstance(tp, typing.TypeVar): tp = ( tp.__bound__ @@ -304,7 +311,7 @@ def _simple_type(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return typing.get_origin(tp) or tp + return _resolve_aliases(tp) def typeof[T](term: Expr[T]) -> type[T]: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f15103e3..f8089bec 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,146 +1,22 @@ +import builtins import itertools -from collections.abc import Callable, Mapping, Sequence -from dataclasses import dataclass -from typing import Any, get_args, get_origin +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping +from typing import Any, Literal, overload import jax from hypothesis import given, settings from hypothesis import strategies as st +from hypothesis.strategies import SearchStrategy import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter -from effectful.ops.monoid import NormalizeIntp -from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted +from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term -_JAX_ARRAY_SHAPE = (2,) - - -def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: - return st.lists( - st.integers(min_value=-5, max_value=5), - min_size=_JAX_ARRAY_SHAPE[0], - max_size=_JAX_ARRAY_SHAPE[0], - ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) - - -# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` -# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. -_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ - lambda a: _jnp.stack([a, a + 1]), - lambda a: _jnp.stack([a, -a]), - lambda a: _jnp.stack([a, a + 1, 2 * a]), -] - -_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ - lambda a, b: a + b, - lambda a, b: a - b, - lambda a, b: a * b, -] - - -def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: - """Strategy for the value an *0-arg* Operation should return.""" - if annotation is int: - return st.integers(min_value=-100, max_value=100) - if annotation is float: - return st.floats(allow_nan=False) - if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) - if annotation is jax.Array: - return _jax_array_value_strategy() - if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): - return st.lists(_jax_array_value_strategy(), max_size=2) - raise NotImplementedError( - f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int], jax.Array, list[jax.Array]" - ) - - -_UNARY_NUM_FNS: list[Callable[[int], int]] = [ - lambda x: x, - lambda x: x + 1, - lambda x: x - 1, - lambda x: -x, - lambda x: 2 * x, - lambda x: 3 * x + 1, -] - -_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ - lambda x, y: x + y, - lambda x, y: x - y, - lambda x, y: x * y, - lambda x, y: x + 2 * y, - lambda x, y: 2 * x - y, -] - -_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], - lambda x: [0, x, x + 1], -] - -_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], -] - - -def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: - """Pick a strategy producing a callable suitable for binding `op` in an - interpretation. Inspects the operation's signature. - """ - sig = op.__signature__ - params = list(sig.parameters.values()) - ret = sig.return_annotation - param_types = tuple(p.annotation for p in params) - - if not params: - return _value_strategy_for(ret).map(deffn) - if ret in (int, float) and param_types == (int,): - return st.sampled_from(_UNARY_NUM_FNS) - if ret in (int, float) and param_types == (int, int): - return st.sampled_from(_BINARY_NUM_FNS) - if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): - return st.sampled_from(_UNARY_LIST_FNS) - if ret is jax.Array and param_types == (jax.Array,): - return st.sampled_from(_UNARY_JAX_FNS) - if ret is jax.Array and param_types == (jax.Array, jax.Array): - return st.sampled_from(_BINARY_JAX_FNS) - if ( - get_origin(ret) is list - and get_args(ret) == (jax.Array,) - and param_types == (jax.Array,) - ): - return st.sampled_from(_UNARY_JAX_LIST_FNS) - raise NotImplementedError( - f"No callable strategy for free var with return {ret!r}, params {param_types!r}" - ) - - -@st.composite -def random_interpretation( - draw: st.DrawFn, free_vars: Sequence[Operation] -) -> Mapping[Operation, Callable[..., Any]]: - """Draw an Interpretation binding every Operation in `case.free_vars` to - a randomly chosen value/callable. Keys are Operation identities. - """ - intp: dict[Operation, Callable[..., Any]] = {} - for op in free_vars: - intp[op] = draw(_strategy_for_op(op)) - return intp - - -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - def syntactic_eq_alpha(x, y) -> bool: """Alpha-equivalence-respecting variant of ``syntactic_eq``. @@ -251,8 +127,7 @@ def _apply_canonical(op, *args, **kwargs) -> Term: return evaluate(expr) -@dataclass(frozen=True) -class Backend: +class Backend(ABC): """A value-domain spec used to share monoid tests across int and jax.Array backends. Provides the concrete value type, the hypothesis strategy for drawing scalars in property tests, and an equality predicate that works @@ -262,10 +137,29 @@ class Backend: name: str scalar_typ: Any stream_typ: Any - scalar_strategy: st.SearchStrategy[Any] - eq: Callable[[Any, Any], bool] - - def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + strategy_for_op: dict[Operation, st.SearchStrategy[Callable[..., Any]]] + + def __init__(self): + self.strategy_for_op = {} + + @abstractmethod + def eq(self, a: Any, b: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + raise NotImplementedError + + def _fresh_op( + self, + name: str, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> Operation: """Build a fresh, unhandled Operation whose parameter and return annotations are derived from this backend. @@ -275,81 +169,245 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation """ scalar = self.scalar_typ out = self.stream_typ if ret == "stream" else scalar - params = ", ".join(f"_a{i}" for i in range(n_args)) + params = ", ".join(f"_a{i}" for i in range(len(arg_types))) ns: dict[str, Any] = {"NotHandled": NotHandled} exec(f"def _fn({params}):\n raise NotHandled\n", ns) fn = ns["_fn"] fn.__annotations__ = { - **{f"_a{i}": scalar for i in range(n_args)}, + **{f"_a{i}": t for i, t in enumerate(arg_types)}, "return": out, } - return Operation.define(fn, name=name) + op = Operation.define(fn, name=name) + self.strategy_for_op[op] = self.strategy(arg_types, ret) + return op + + @overload + def define_vars(self, name: str, /, **kwargs) -> Operation: ... + + @overload + def define_vars( + self, n1: str, n2: str, /, *names: str, **kwargs + ) -> tuple[Operation, ...]: ... + + def define_vars(self, *names: str, **kwargs) -> Operation | tuple[Operation, ...]: # type: ignore[misc] + if len(names) == 1: + return self._fresh_op(names[0], **kwargs) + return tuple(self._fresh_op(n, **kwargs) for n in names) + + def check_rewrite( + self, + lhs, + rhs, + rule, + *, + max_examples: int = 25, + deadline=None, + normalize=NormalizeIntp, + ) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + fvs = fvsof(lhs) | fvsof(rhs) + + @st.composite + def random_interpretation( + draw: st.DrawFn, + ) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op, strategy in self.strategy_for_op.items(): + if op in fvs: + intp[op] = draw(strategy) + return intp + + @given(intp=random_interpretation()) + @settings( + max_examples=max_examples, deadline=deadline, report_multiple_bugs=False + ) + def _check_semantics(intp): + with handler(normalize), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert self.eq(lhs_val, rhs_val) + _check_semantics() -def _int_eq(a: Any, b: Any) -> bool: - return not isinstance(a, Term) and not isinstance(b, Term) and a == b +def _is_weighted(x: Any) -> bool: + return isinstance(x, Term) and _is_monoid_weighted(x.op) -def _jax_eq(a: Any, b: Any) -> bool: - def _leaf_eq(x: Any, y: Any) -> bool: - return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) - try: - leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) - except (ValueError, TypeError): +def _weight_pairs(x: Any, monoid: Any) -> list[tuple[Any, Any]] | None: + """Return ``(element, weight)`` pairs for a stream. + + A weighted-monoid Term yields each element paired with its weight. A plain + (unweighted) stream yields each element paired with ``monoid.identity`` -- + the no-op weight -- so an unweighted stream compares equal to a weighted one + exactly when every weight reduces to the identity (e.g. ``[()]`` vs a + weighted ``[()]`` whose single empty row reduces to the identity, and, more + generally, whenever both streams are empty). Returns ``None`` for a + non-stream Term, which never compares equal to a weighted stream. + """ + if isinstance(x, Term): + if not _is_monoid_weighted(x.op): + return None + stream, weight = x.args + assert not isinstance(stream, Term) + return [(e, typing.cast(Callable, weight)(e)) for e in stream] + return [(e, monoid.identity) for e in x] + + +def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: + monoids = {x.op.__self__ for x in (a, b) if _is_weighted(x)} + # distinct weight monoids can never be equal + if len(monoids) != 1: return False - return all(leaves) - - -def check_rewrite( - lhs, - rhs, - rule, - *, - backend: Backend, - free_vars=[], - max_examples: int = 25, - deadline=None, -) -> None: - with handler(rule): - norm = evaluate(lhs) - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=deadline) - def _check_semantics(intp): - with handler(NormalizeIntp), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - -INT_BACKEND = Backend( - name="int", - scalar_typ=int, - stream_typ=list[int], - scalar_strategy=st.integers(min_value=-100, max_value=100), - eq=_int_eq, -) - - -JAX_BACKEND = Backend( - name="jax", - scalar_typ=jax.Array, - stream_typ=jax.Array, - scalar_strategy=_jax_array_value_strategy(), - eq=_jax_eq, -) - - -__all__ = [ - "Backend", - "INT_BACKEND", - "JAX_BACKEND", - "random_interpretation", - "define_vars", - "syntactic_eq_alpha", - "check_rewrite", -] + monoid = next(iter(monoids)) + + a_pairs = _weight_pairs(a, monoid) + b_pairs = _weight_pairs(b, monoid) + if a_pairs is None or b_pairs is None or len(a_pairs) != len(b_pairs): + return False + for (ea, wa), (eb, wb) in zip(a_pairs, b_pairs): + if not leaf_eq(ea, eb) or not leaf_eq(wa, wb): + return False + return True + + +class IntBackend(Backend): + name = "int" + scalar_typ = int + stream_typ = Stream[int] + + _unary_num_fns: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, + ] + + _binary_num_fns: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, + ] + + _unary_list_fns: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + match arg_types, ret: + case (), "scalar": + return st.integers(min_value=-100, max_value=100).map(deffn) + case (), "stream": + scalars = st.integers(min_value=-100, max_value=100) + return st.lists(scalars, max_size=2).map(deffn) + case (builtins.int,), "scalar": + return st.sampled_from(self._unary_num_fns) + case (builtins.int, builtins.int), "scalar": + return st.sampled_from(self._binary_num_fns) + case (builtins.int,), "stream": + return st.sampled_from(self._unary_list_fns) + raise NotImplementedError( + f"No int strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +class JaxBackend(Backend): + name = "jax" + scalar_typ = jax.Array + stream_typ = jax.Array + + _unary_jax_scalar_fns: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: a, + lambda a: a + 1, + lambda a: a - 1, + lambda a: -a, + lambda a: 2 * a, + ] + + _unary_jax_stream_fns: list[Callable[[jax.Array], Stream[jax.Array]]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), + ] + + _binary_jax_scalar_fns: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> st.SearchStrategy[Callable]: + match arg_types, ret: + case (), "scalar": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=2, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (), "stream": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=1, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (jax.Array,), "scalar": + return st.sampled_from(self._unary_jax_scalar_fns) + case (jax.Array, jax.Array), "scalar": + return st.sampled_from(self._binary_jax_scalar_fns) + case (jax.Array,), "stream": + return st.sampled_from(self._unary_jax_stream_fns) + + raise NotImplementedError( + f"No jax strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +__all__ = ["Backend", "IntBackend", "JaxBackend", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index fe888ad4..18df8401 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,3 +1,6 @@ +import functools +import typing + import jax import pytest from jax import random as random @@ -7,15 +10,24 @@ from effectful.handlers.jax.monoid import ( ArrayReduce, LogSumExp, + ProductPlusJax, ReduceDeltaIndependent, ReduceDependentRangeMask, delta, ) from effectful.handlers.jax.monoid import range as Range from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum -from effectful.ops.semantics import handler -from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars +from effectful.ops.monoid import ( + Max, + Min, + NormalizeIntp, + Product, + ReduceWeightedStream, + Sum, +) +from effectful.ops.semantics import coproduct, handler +from effectful.ops.types import Interpretation +from tests._monoid_helpers import JaxBackend MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), @@ -27,28 +39,27 @@ @pytest.fixture -def backend() -> Backend: - return JAX_BACKEND +def backend() -> JaxBackend: + return JaxBackend() @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_1(monoid, reductor, backend: Backend): - (x, k) = define_vars("x", "k", typ=jax.Array) - X = define_vars("X", typ=backend.stream_typ) +def test_reduce_array_1(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(x(), {x: X()}) rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_2(monoid, reductor, backend: Backend): - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_array_2(monoid, reductor, backend: JaxBackend): + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + (X, Y) = backend.define_vars("X", "Y", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) rhs = reductor( @@ -61,25 +72,20 @@ def test_reduce_array_2(monoid, reductor, backend: Backend): ), axis=0, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_3(monoid, reductor, backend: Backend): +def test_reduce_array_3(monoid, reductor, backend: JaxBackend): """Stream `y` is `g(x())` — depends on the bound element of X. The reducer must inline ``g`` along the same named dim used to unbind `x`.""" - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + X = backend.define_vars("X", ret="stream") - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + g = backend.define_vars("g", arg_types=[backend.scalar_typ], ret="stream") lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) rhs = reductor( @@ -95,13 +101,37 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): ), axis=0, ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) - check_rewrite( + +def test_jax_weighted_reduce(backend: JaxBackend): + """Sum over a single stream with ``Product`` weights lowers to + ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. + + Verifies that the desugaring rule composes cleanly with the JAX lowering + so existing handlers need no changes to support weighted streams. + + """ + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") + body = backend.define_vars("body", arg_types=[backend.scalar_typ], ret="scalar") + w = backend.define_vars("w", arg_types=[backend.scalar_typ], ret="scalar") + + ws = Product.weighted(X(), w) + lhs = Sum.reduce(body(x()), {x: ws}) + rhs = jnp.sum( + bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 + ) + backend.check_rewrite( lhs=lhs, rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, f, g], + rule=functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()], + ), + ), ) @@ -112,62 +142,51 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_empty(monoid, reductor, backend: Backend): +def test_reduce_delta_empty(monoid, reductor, backend: JaxBackend): """An empty-index delta unwraps to its body. reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) """ - x = define_vars("x", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(delta((), x()), {x: X()}) rhs = monoid.reduce(x(), {x: X()}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[x, X], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_one(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): """One R1 step: peel the final preserved index off a delta. reduce(M, {y: Y()}, delta((y(),), f(y()))) ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) """ - (y, k) = define_vars("y", "k", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") + (y, k) = backend.define_vars("y", "k", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") # We use a concrete range here instead of an abstract one, because # unbind_dims is undefined on empty arrays (and the rewrite produces a # different rhs in this case) lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[y, k, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_preserves_others( + monoid, reductor, backend: JaxBackend +): """R1 peels only the final index. Streams not matching the peeled index op stay untouched, as do earlier entries in the index tuple. reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) """ - (x, y, k) = define_vars("x", "y", "k", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") + (x, y, k) = backend.define_vars("x", "y", "k", ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) rhs = monoid.reduce( @@ -179,27 +198,22 @@ def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Ba ), {}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask(monoid, reductor, backend: JaxBackend): """A dependent range stream gets rewritten to the referent's bbox stream, with the original constraint folded into the body as a where-guard. reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) body = f(u(), v()) @@ -208,18 +222,11 @@ def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): jnp.where(v() < u(), body, monoid.identity), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: JaxBackend): """When the body is a delta term, R4 folds the constraint into the delta's weight while leaving its index tuple untouched. @@ -227,9 +234,11 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe ≡ reduce(M, {u: range(N), v: range(N)}, delta((u(), v()), where(v() < u(), w, M.identity))) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) weight = f(u(), v()) idx = (u(), v()) @@ -239,17 +248,10 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe delta(idx, jnp.where(v() < u(), weight, monoid.identity)), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) -def test_reduce_matmul(): +def test_reduce_matmul(backend: JaxBackend): key = jax.random.PRNGKey(0) # Define dimensions B, I, J, K = 2, 3, 4, 5 @@ -257,7 +259,7 @@ def test_reduce_matmul(): # Create sample matrices X = random.normal(key, (B, I, J)) Y = random.normal(key, (B, J, K)) - (b, i, j, k) = define_vars("b", "i", "j", "k", typ=jax.Array) + (b, i, j, k) = backend.define_vars("b", "i", "j", "k", ret="scalar") with handler(NormalizeIntp): actual = Sum.reduce( diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 8b93976f..8abfdc5a 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1900,3 +1900,10 @@ class Info(typing.TypedDict): subs = unify(collections.abc.Mapping, Info) assert subs == {} + + +def test_unify_jax_array_iterable(): + import jax + + subs = unify(collections.abc.Iterable[T], jax.Array) + assert subs == {T: jax.Array} diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index c7ee7567..fcd72f06 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,10 +1,13 @@ +import math import typing +from collections.abc import Iterable import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st import effectful.handlers.jax.monoid # noqa: F401 +import effectful.handlers.jax.numpy as jnp from effectful.ops.monoid import ( CartesianProduct, Max, @@ -22,29 +25,25 @@ PlusSingle, PlusZero, Product, + ReduceCartesianWeightedStream, ReduceDistributeCartesianProduct, ReduceFactorization, ReduceFusion, ReduceNoStreams, ReduceSplit, + ReduceWeightedStream, Sum, distributes_over, ) -from effectful.ops.semantics import fvsof, handler -from effectful.ops.types import Operation -from tests._monoid_helpers import ( - INT_BACKEND, - JAX_BACKEND, - Backend, - check_rewrite, - define_vars, - syntactic_eq_alpha, -) +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.syntax import deffn +from effectful.ops.types import NotHandled, Operation, Term +from tests._monoid_helpers import Backend, IntBackend, JaxBackend, syntactic_eq_alpha -@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +@pytest.fixture(params=[IntBackend, JaxBackend], ids=["int", "jax"]) def backend(request) -> Backend: - return request.param + return request.param() ALL_MONOIDS = [ @@ -90,10 +89,10 @@ def backend(request) -> Backend: deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_associativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) - c = data.draw(backend.scalar_strategy) +def test_associativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() + c = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): left = monoid.plus(monoid.plus(a, b), c) right = monoid.plus(a, monoid.plus(b, c)) @@ -107,8 +106,8 @@ def test_associativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_identity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_identity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.identity, a), a) assert backend.eq(monoid.plus(a, monoid.identity), a) @@ -121,9 +120,9 @@ def test_identity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_commutativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) +def test_commutativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @@ -135,8 +134,8 @@ def test_commutativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_idempotence(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_idempotence(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, a), a) @@ -148,102 +147,86 @@ def test_idempotence(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_zero_absorbs(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_zero_absorbs(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid, backend): - check_rewrite( - lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend - ) +def test_plus_empty(monoid, backend: Backend): + backend.check_rewrite(lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) - check_rewrite( - lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] - ) +def test_plus_single(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + backend.check_rewrite(lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_right(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(x(), monoid.identity) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_left(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(monoid.identity, x()) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_right(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_left(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - check_rewrite( +def test_plus_sequence(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + backend.check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), rule=MonoidOverSequence(), - backend=backend, - free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_mapping(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[a, b, c, d], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) -def test_plus_distributes(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) rhs = Product.plus( Sum.plus( @@ -253,13 +236,11 @@ def test_plus_distributes(backend): Product.plus(b(), d()), ) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_constant(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_constant(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -270,13 +251,11 @@ def test_plus_distributes_constant(backend): Product.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_multiple(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_multiple(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -297,238 +276,195 @@ def test_plus_distributes_multiple(backend): Sum.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid, backend): +def test_plus_idempotent_consecutive(monoid, backend: Backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), a(), b()) - return check_rewrite( - lhs=lhs, - rhs=monoid.plus(a(), b()), - rule=PlusConsecutiveDups(), - backend=backend, - free_vars=[a, b], + return backend.check_rewrite( + lhs=lhs, rhs=monoid.plus(a(), b()), rule=PlusConsecutiveDups() ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid, backend): +def test_plus_idempotent_non_consecutive(monoid, backend: Backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative PlusDups.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", [Min, Max]) -def test_plus_commutative_idempotent_long(monoid, backend): +def test_plus_commutative_idempotent_long(monoid, backend: Backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_plus_zero(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) rhs = monoid.zero - check_rewrite( - lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) - check_rewrite( - lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule=PlusZero()) + backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule=PlusZero()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_1(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) +def test_partial_1(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_3(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_3(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_4(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=1, ret="stream") +def test_partial_4(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - X, Y = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, Y, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_mapping(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_reduce_no_streams(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") + lhs = monoid.reduce(a(), {}) rhs = monoid.identity - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceNoStreams()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_reduce(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFusion()) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_plus(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSplit()) -def test_reduce_independent_1(backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_independent_1(backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_2(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_2(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_3_negative(backend): +def test_reduce_independent_3_negative(backend: Backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, C = define_vars("A", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, C = backend.define_vars("A", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + g = backend.define_vars("g", arg_types=(backend.scalar_typ,), ret="stream") with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( @@ -542,10 +478,12 @@ def test_reduce_independent_3_negative(backend): assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_4(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( @@ -553,39 +491,44 @@ def test_reduce_independent_4(backend): Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_cartesian_3(): + backend = JaxBackend() + i = backend.define_vars("i", ret="scalar") + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(3)}) + assert value.shape == (2**3, 3) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(1)}) + assert value.shape == (2**1, 1) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(1), {i: jnp.arange(3)}) + assert value.shape == (1**3, 3) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner, backend): - a, i = define_vars("a", "i", typ=backend.scalar_typ) - A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_1(outer, inner, backend: Backend): + a, i = backend.define_vars("a", "i", ret="scalar") + A, N, A_domain = backend.define_vars("A", "N", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) - - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, A_domain, f], - ) + rhs = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -597,8 +540,9 @@ def test_reduce_cartesian_1(): def test_reduce_cartesian_2(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -610,46 +554,41 @@ def test_reduce_cartesian_2(): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner, backend): - a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_multi_index(outer, inner, backend: Backend): + a, i, j = backend.define_vars("a", "i", "j", ret="scalar") + A, N, M, A_domain = backend.define_vars("A", "N", "M", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) - term2 = inner.reduce( - outer.reduce(inner.plus(f(a())), {a: A_domain()}), - {i: N(), j: M()}, - ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, M, A_domain, f], + rhs = inner.reduce( + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()} ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner, backend): +def test_reduce_lifted_2(outer, inner, backend: Backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) - A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) - A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") - f1 = backend.fresh_op("f1", n_args=2, ret="scalar") - f2 = backend.fresh_op("f2", n_args=2, ret="scalar") + a, i, s, t = backend.define_vars("a", "i", "s", "t", ret="scalar") + A, N, T = backend.define_vars("A", "N", "T", ret="stream") + A_domain = backend.define_vars( + "A_domain", arg_types=(backend.scalar_typ,), ret="stream" + ) + f1, f2 = backend.define_vars( + "f1", "f2", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, ) - - term2 = outer.reduce( + rhs = outer.reduce( inner.reduce( outer.reduce( inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} @@ -658,11 +597,143 @@ def test_reduce_lifted_2(outer, inner, backend): ), {t: T()}, ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) + + +# --------------------------------------------------------------------------- +# Weighted streams +# --------------------------------------------------------------------------- + + +def test_reduce_single_weighted_stream(backend: Backend): + """Single weighted stream desugars: + Sum.reduce(body, {a: WS(A, w, Product)}) + = Sum.reduce(Product.plus(w(a), body), {a: A}) + """ + a = backend.define_vars("a", ret="scalar") + A = backend.define_vars("A", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), w)}) + rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceWeightedStream()) + + +def test_reduce_weighted_factorization(backend: Backend): + """Two independent weighted streams under Sum with Product weights factor: + Sum.reduce(f(a)*g(b), {a: Product.weighted(A, a, w_a), b: Product.weighted(B, b, w_b)}) + = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) + + Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` + inside ``NormalizeIntp``. + """ + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f, g, w_a, w_b = backend.define_vars( + "f", "g", "w_a", "w_b", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce( + Product.plus(f(a()), g(b())), + {a: Product.weighted(A(), w_a), b: Product.weighted(B(), w_b)}, + ) + rhs = Product.plus( + Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), + Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceWeightedStream(), ReduceFactorization()) + ) + + +def test_reduce_cartesian_weighted_stream(backend: Backend): + """``CartesianProduct.reduce`` over a ``WeightedStream`` body whose weight + is independent of the plate var rewrites to a single joint + ``WeightedStream``: + + CartesianProduct.reduce(M.weighted(s, e, w(e)), {p: P}) + = M.weighted(CartesianProduct.reduce(s, {p: P}), row, M.reduce(w(e), {e: row()})) + """ + p, e_var = backend.define_vars("p", "e_var", ret="scalar") + S, P = backend.define_vars("S", "P", ret="stream") + w = backend.define_vars("w", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = CartesianProduct.reduce(Product.weighted(S(), w), {p: P()}) + row_var = Operation.define(Iterable[backend.scalar_typ], name="row") # type: ignore[name-defined] + rhs = Product.weighted( + CartesianProduct.reduce(S(), {p: P()}), + deffn(Product.reduce(w(e_var()), {e_var: row_var()}), row_var), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceCartesianWeightedStream()) + + +def test_lift_weighted_cartesian(backend: Backend): + """Compose ``ReduceCartesianWeightedStream`` + ``ReduceWeightedStream`` + + ``ReduceDistributeCartesianProduct`` on a Sum-of-Product-of-weighted shape: + + Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S, e, w(e)), {p: P})}, + ) + + The inner ``weighted`` becomes a joint ``weighted`` (rule 1), lifts its + per-element weight into the outer Sum body (rule 2), and the lifted form + matches the inversion pattern (rule 3), yielding:: + + Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S}), + {p: P}, + ) + """ + a, p = backend.define_vars("a", "p", ret="scalar") + A, S, P = backend.define_vars("A", "S", "P", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + lhs = Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S(), w), {p: P()})}, + ) + rhs = Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct( + coproduct(ReduceWeightedStream(), ReduceCartesianWeightedStream()), + ReduceDistributeCartesianProduct(), + ), + ) + + +def test_weighted_expectation_demo(): + """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. + + X ranges over [1, 2, 3, 4] with weights w(x) = x/10 (a valid distribution + since the weights sum to 1) and f(x) = x*x. Expected value: + 0.1·1 + 0.2·4 + 0.3·9 + 0.4·16 = 10.0 + """ + weights = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4} + + def _w(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return weights[v] + + def _f(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return float(v * v) + + a = Operation.define(int, name="a") + w = Operation.define(_w, name="w") + f = Operation.define(_f, name="f") + + with handler(NormalizeIntp): + result = evaluate(Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], w)})) + + assert math.isclose(result, 10.0)