Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions effectful/handlers/jax/_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@
unbind_dims,
)
from effectful.internals.tensor_utils import _desugar_tensor_index
from effectful.internals.unification import Box, nested_type
from effectful.ops.syntax import defdata
from effectful.ops.types import Expr, NotHandled, Operation, Term


@nested_type.register(jax.Array)
@nested_type.register(jax._src.core.Tracer)
def _(value):
return Box(jax.Array)


class _IndexUpdateHelper:
"""Helper class to implement array-style .at[index].set() updates for effectful arrays."""

Expand Down
15 changes: 14 additions & 1 deletion effectful/handlers/jax/monoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Product,
Streams,
Sum,
distributes_over,
outer_stream,
)
from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof
Expand All @@ -38,6 +39,10 @@ def cartesian_prod(x, y):

LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf")))

# ``Sum`` in log space is multiplication, which distributes over ``LogSumExp``:
# a + logsumexp(b, c) = logsumexp(a + b, a + c)
distributes_over.register(Sum, LogSumExp)


def _jax_args(args):
"""True iff ``args`` is non-empty and every arg is a concrete
Expand Down Expand Up @@ -108,7 +113,15 @@ def plus(self, *args):
if not isinstance(a, jax.Array):
return fwd()
result = a if result is None else cartesian_prod(result, a)
return result if result is not None else CartesianProduct.identity
if result is None:
return CartesianProduct.identity
# CartesianProduct values are streams of rows. ``cartesian_prod``
# already lifts 1D inputs to 2D, but a single-array call seeds
# ``result = a`` unchanged — promote so the rank invariant holds for
# every array-path return.
if result.ndim == 1:
result = result[:, None]
return result


ARRAY_REDUCTORS = {
Expand Down
11 changes: 11 additions & 0 deletions effectful/internals/unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
and issubclass(subtyp, typing.get_origin(typ))
):
return subs # implicit expansion to subtyp[Any]
elif isinstance(typ, GenericAlias):
# Special case for treating arrays as iterables of arrays
try:
import jax

if typing.get_origin(typ) is collections.abc.Iterable and issubclass(
subtyp, jax.Array
):
return unify(typing.get_args(typ)[0], jax.Array, subs)
except ImportError:
pass
raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.")


Expand Down
125 changes: 114 additions & 11 deletions effectful/ops/monoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@
from typing import Annotated, Any

from effectful.internals.disjoint_set import DisjointSet
from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler
from effectful.ops.semantics import (
coproduct,
evaluate,
fvsof,
fwd,
handler,
typeof,
)
from effectful.ops.syntax import (
ObjectInterpretation,
Scoped,
Expand All @@ -19,11 +26,17 @@
syntactic_eq,
syntactic_hash,
)
from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term
from effectful.ops.types import (
Expr,
Interpretation,
NotHandled,
Operation,
Term,
)

type Stream[T] = Iterable[T]

# Note: The streams value type should be something like Iterable[T], but some of
# our target stream types (e.g. jax.Array) are not subtypes of Iterable
type Streams[T] = Mapping[Operation[[], T], Any]
type Streams = Mapping[Operation[[], Any], Stream[Any]]

type Body[T] = (
Iterable[T]
Expand All @@ -34,9 +47,7 @@
)


def outer_stream(
streams: Streams,
) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]:
def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]]:
"""Returns the streams that can be ordered outermost in the loop nest as
well as the remaining streams in the nest.

Expand All @@ -51,13 +62,13 @@ def outer_stream(
)


class Monoid[T]:
class Monoid[W]:
"""A monoid with ``plus`` and ``reduce`` :class:`Operation` s."""

_name: str
identity: T
identity: W

def __init__(self, identity: T, name: str):
def __init__(self, identity: W, name: str):
self._name = name
self.identity = identity

Expand Down Expand Up @@ -111,6 +122,18 @@ def reduce[A, B, U: Body](
return self.plus(*new_reduces)
raise NotHandled

@Operation.define
def weighted[T](
self, stream: Stream[T], weight: Callable[[T], W] | Operation[[T], W]
) -> Stream[T]:
"""A stream paired with a per-element weight. ``var`` is an
:class:`Operation` standing for "an element of ``stream``"; ``weight``
is an expression that uses ``var`` and evaluates to the weight of that
element.

"""
raise NotHandled


class MonoidWithZero[T](Monoid[T]):
zero: T
Expand Down Expand Up @@ -175,6 +198,12 @@ def _is_monoid_reduce(op: Operation) -> bool:
return isinstance(owner, Monoid) and op is owner.reduce


def _is_monoid_weighted(op: Operation) -> bool:
"""True if ``op`` is the ``weighted`` operation of some :class:`Monoid`."""
owner = getattr(op, "__self__", None)
return isinstance(owner, Monoid) and op is owner.weighted


class PlusEmpty(ObjectInterpretation):
"""plus() = 0"""

Expand Down Expand Up @@ -557,6 +586,78 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams):
return fwd()


class ReduceWeightedStream(ObjectInterpretation):
"""reduce(M, body, {x: WM.weighted(s, v, w), ...}) = reduce(M, WM.plus(w[v:=x()], body), {x: s, ...})

requires distributes_over(WM, M).

The substitution ``v -> x`` is done by beta-reducing ``deffn(w, v)`` on
``x()`` — symbolic, no Python dispatch on the weight expression.
"""

@implements(Monoid.reduce)
def reduce(self, monoid, body, streams):
for k, v in streams.items():
if isinstance(v, Term) and _is_monoid_weighted(v.op):
v_stream, v_weight = v.args
v_monoid = v.op.__self__
if not distributes_over(v_monoid, monoid):
continue
w_at_k = v_weight(k())
weighted_body = v_monoid.plus(w_at_k, body)
new_streams = {**streams, k: v_stream}
return monoid.reduce(weighted_body, new_streams)
return fwd()


class ReduceCartesianWeightedStream(ObjectInterpretation):
"""``CartesianProduct.reduce`` over a :func:`weighted` body whose
``weight`` is independent of the plate (product-index) streams::

CartesianProduct.reduce(M.weighted(s, w), plates)
= M.weighted(
CartesianProduct.reduce(s, plates),
deffn(M.reduce(w, {e: row()}), row),
)

Reuses ``body``'s element binder ``e`` (already typed by construction);
introduces a fresh ``row`` binder typed as ``Iterable[elem_type]``.

Only fires when ``w`` is independent of the plate vars.
"""

@Operation.define
@staticmethod
def _iterable_elem[T](iter: Iterable[T]) -> T:
raise NotHandled

@implements(Monoid.reduce)
def reduce(self, monoid, body, streams):
if monoid is not CartesianProduct:
return fwd()
if not (isinstance(body, Term) and _is_monoid_weighted(body.op)):
return fwd()

s, w = body.args
if not isinstance(s, Term) and len(s) == 0:
return CartesianProduct.reduce([], streams)

if set(streams.keys()) & fvsof(w):
return fwd()

elem_typ = typeof(self._iterable_elem(s))
elem_op = Operation.define(elem_typ, name="elem")
row_op = Operation.define(Iterable[elem_typ], name="row")

weight_monoid = body.op.__self__
joint_weight = deffn(
weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op
)
joint_stream = CartesianProduct.reduce(s, streams)

return weight_monoid.weighted(joint_stream, joint_weight)


class MonoidOverCallable(ObjectInterpretation):
"""``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``."""

Expand Down Expand Up @@ -751,6 +852,8 @@ def extend(self, *intps: Interpretation) -> typing.Self:
ReduceSplit(),
ReduceFactorization(),
ReduceDistributeCartesianProduct(),
ReduceWeightedStream(),
ReduceCartesianWeightedStream(),
PlusEmpty(),
PlusSingle(),
PlusIdentity(),
Expand Down
9 changes: 8 additions & 1 deletion effectful/ops/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,13 @@ def _evaluate_list_view(expr, **kwargs):

def _simple_type(tp: type) -> type:
"""Convert a type object into a type that can be dispatched on."""

def _resolve_aliases(tp: type) -> type:
tp = typing.get_origin(tp) or tp
if isinstance(tp, typing.TypeAliasType):
return _resolve_aliases(tp.__value__)
return tp

if isinstance(tp, typing.TypeVar):
tp = (
tp.__bound__
Expand All @@ -304,7 +311,7 @@ def _simple_type(tp: type) -> type:
tp = functools.reduce(operator.or_, (type(arg) for arg in args))
if isinstance(tp, types.UnionType):
raise TypeError(f"Union types are not supported: {tp}")
return typing.get_origin(tp) or tp
return _resolve_aliases(tp)


def typeof[T](term: Expr[T]) -> type[T]:
Expand Down
Loading
Loading