diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 67b4ca50..5f055a2c 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -17,7 +17,10 @@ AbstractBrownianIncrement as AbstractBrownianIncrement, AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea, AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea, + AbstractWeakSpaceSpaceLevyArea as AbstractWeakSpaceSpaceLevyArea, BrownianIncrement as BrownianIncrement, + DavieFosterWeakSpaceSpaceLevyArea as DavieFosterWeakSpaceSpaceLevyArea, + DavieWeakSpaceSpaceLevyArea as DavieWeakSpaceSpaceLevyArea, SpaceTimeLevyArea as SpaceTimeLevyArea, SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea, ) diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 0333caa5..b2d7f0bc 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -14,6 +14,8 @@ from .._custom_types import ( AbstractBrownianIncrement, BrownianIncrement, + DavieFosterWeakSpaceSpaceLevyArea, + DavieWeakSpaceSpaceLevyArea, levy_tree_transpose, RealScalarLike, SpaceTimeLevyArea, @@ -27,6 +29,15 @@ from .base import AbstractBrownianPath +_Levy_Areas = Union[ + BrownianIncrement, + SpaceTimeLevyArea, + SpaceTimeTimeLevyArea, + DavieWeakSpaceSpaceLevyArea, + DavieFosterWeakSpaceSpaceLevyArea, +] + + class UnsafeBrownianPath(AbstractBrownianPath): """Brownian simulation that is only suitable for certain cases. @@ -62,18 +73,14 @@ class UnsafeBrownianPath(AbstractBrownianPath): """ shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) - levy_area: type[ - Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea] - ] = eqx.field(static=True) + levy_area: type[_Levy_Areas] = eqx.field(static=True) key: PRNGKeyArray def __init__( self, shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]], key: PRNGKeyArray, - levy_area: type[ - Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea] - ] = BrownianIncrement, + levy_area: type[_Levy_Areas] = BrownianIncrement, ): self.shape = ( jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) @@ -141,9 +148,7 @@ def _evaluate_leaf( t1: RealScalarLike, key, shape: jax.ShapeDtypeStruct, - levy_area: type[ - Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea] - ], + levy_area: type[_Levy_Areas], use_levy: bool, ): w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) @@ -158,6 +163,57 @@ def _evaluate_leaf( kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk) + elif levy_area is DavieWeakSpaceSpaceLevyArea: + key_w, key_hh, key_b = jr.split(key, 3) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std + hh_std = w_std / math.sqrt(12) + hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std + if w.ndim == 0 or w.ndim == 1: + a = jnp.zeros_like(w, dtype=shape.dtype) + levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + else: + b_std = (dt / jnp.sqrt(12)).astype(shape.dtype) + b = ( + jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype) + * b_std + ) + b = b - b.transpose(*range(b.ndim - 2), -1, -2) + a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims( + w, -1 + ) * jnp.expand_dims(hh, -2) + a += b + levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + + elif levy_area is DavieFosterWeakSpaceSpaceLevyArea: + key_w, key_hh, key_b = jr.split(key, 3) + w = jr.normal(key_w, shape.shape, shape.dtype) * w_std + hh_std = w_std / math.sqrt(12) + hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std + if w.ndim == 0 or w.ndim == 1: + a = jnp.zeros_like(w, dtype=shape.dtype) + levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + else: + tenth_dt = (0.1 * dt).astype(shape.dtype) + hh_squared = hh**2 + b_std = jnp.sqrt( + tenth_dt + * ( + tenth_dt + + jnp.expand_dims(hh_squared, -1) + + jnp.expand_dims(hh_squared, -2) + ) + ).astype(shape.dtype) + b = ( + jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype) + * b_std + ) + b = b - b.transpose(*range(b.ndim - 2), -1, -2) + a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims( + w, -1 + ) * jnp.expand_dims(hh, -2) + a += b + levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a) + elif levy_area is SpaceTimeLevyArea: key_w, key_hh = jr.split(key, 2) w = jr.normal(key_w, shape.shape, shape.dtype) * w_std diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 70ec5a1a..8bfe16f1 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -48,6 +48,7 @@ Args = PyTree[Any] BM = PyTree[Shaped[ArrayLike, "?*bm"], "BM"] +Area = PyTree[Shaped[ArrayLike, "?*area"], "Area"] DenseInfo = dict[str, PyTree[Array]] DenseInfos = dict[str, PyTree[Shaped[Array, "times-1 ..."]]] @@ -72,6 +73,39 @@ class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement): H: eqx.AbstractVar[BM] +class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement): + """ + Abstract base class for all weak Space Space Levy Areas. + """ + + H: eqx.AbstractVar[BM] + A: eqx.AbstractVar[BM] + + +class DavieWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): + """ + Davie's approximation to weak Space Space Levy Areas. + See (7.4.1) of Foster's thesis. + """ + + dt: PyTree[FloatScalarLike, "BM"] + W: BM + H: BM + A: Area + + +class DavieFosterWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea): + """ + Davie's approximation to weak Space Space Levy Areas. + See (7.4.2) of Foster's thesis. + """ + + dt: PyTree[FloatScalarLike, "BM"] + W: BM + H: BM + A: Area + + class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): """ Abstract base class for all Space Time Time Levy Areas. diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 938eee37..a278a7a0 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -1054,7 +1054,7 @@ def _promote(yi): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): # Specific check to not work even if using HalfSolver(Euler()) if isinstance(solver, Euler): - raise ValueError( + warnings.warn( "An SDE should not be solved with adaptive step sizes with Euler's " "method, as it may not converge to the correct solution." ) diff --git a/test/test_brownian.py b/test/test_brownian.py index 3a265019..978cb49d 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -1,3 +1,7 @@ +import jax + + +jax.config.update("jax_enable_x64", True) import contextlib import math from typing import Literal @@ -36,12 +40,22 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) -@pytest.mark.parametrize("levy_area", _levy_areas) +@pytest.mark.parametrize( + "levy_area", + _levy_areas + + (diffrax.DavieWeakSpaceSpaceLevyArea, diffrax.DavieFosterWeakSpaceSpaceLevyArea), +) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): t0 = 0.0 t1 = 2.0 + if ( + issubclass(levy_area, diffrax.AbstractWeakSpaceSpaceLevyArea) + and ctr is diffrax.VirtualBrownianTree + ): + return + shapes_dtypes1 = ( ((), None), ((0,), None),