diff --git a/docs/source/api.rst b/docs/source/api.rst index c42abbc7..94c9ad8b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -22,6 +22,7 @@ Stream filter flatten map + map_async partition rate_limit scatter diff --git a/docs/source/async.rst b/docs/source/async.rst index 974ae376..8f32cff8 100644 --- a/docs/source/async.rst +++ b/docs/source/async.rst @@ -72,8 +72,8 @@ This would also work with async-await syntax in Python 3 .. code-block:: python + import asyncio from streamz import Stream - from tornado.ioloop import IOLoop async def f(): source = Stream(asynchronous=True) # tell the stream we're working asynchronously @@ -82,7 +82,28 @@ This would also work with async-await syntax in Python 3 for x in range(10): await source.emit(x) - IOLoop().run_sync(f) + asyncio.run(f()) + +When working asynchronously, we can also map asynchronous functions. + +.. code-block:: python + + async def increment_async(x): + """ A "long-running" increment function + + Simulates a function that does real asyncio work. + """ + await asyncio.sleep(0.1) + return x + 1 + + async def f_inc(): + source = Stream(asynchronous=True) # tell the stream we're working asynchronously + source.map_async(increment_async).rate_limit(0.500).sink(write) + + for x in range(10): + await source.emit(x) + + asyncio.run(f_inc()) Event Loop on a Separate Thread diff --git a/streamz/core.py b/streamz/core.py index ae7a27f9..00b5ed4c 100644 --- a/streamz/core.py +++ b/streamz/core.py @@ -718,6 +718,86 @@ def update(self, x, who=None, metadata=None): return self._emit(result, metadata=metadata) +@Stream.register_api() +class map_async(Stream): + """ Apply an async function to every element in the stream, preserving order + even when evaluating multiple inputs in parallel. + + Parameters + ---------- + func: async callable + *args : + The arguments to pass to the function. + parallelism: + The maximum number of parallel Tasks for evaluating func, default value is 1 + **kwargs: + Keyword arguments to pass to func + + Examples + -------- + >>> async def mult(x, factor=1): + ... return factor*x + >>> async def run(): + ... source = Stream(asynchronous=True) + ... source.map_async(mult, factor=2).sink(print) + ... for i in range(5): + ... await source.emit(i) + >>> asyncio.run(run()) + 0 + 2 + 4 + 6 + 8 + """ + def __init__(self, upstream, func, *args, parallelism=1, **kwargs): + self.func = func + stream_name = kwargs.pop('stream_name', None) + self.kwargs = kwargs + self.args = args + self.work_queue = asyncio.Queue(maxsize=parallelism) + + Stream.__init__(self, upstream, stream_name=stream_name, ensure_io_loop=True) + self.work_task = self._create_task(self.work_callback()) + + def update(self, x, who=None, metadata=None): + return self._create_task(self._insert_job(x, metadata)) + + def _create_task(self, coro): + if gen.is_future(coro): + return coro + return self.loop.asyncio_loop.create_task(coro) + + async def work_callback(self): + while True: + try: + task, metadata = await self.work_queue.get() + self.work_queue.task_done() + result = await task + except Exception as e: + logger.exception(e) + raise + else: + results = self._emit(result, metadata=metadata) + if results: + await asyncio.gather(*results) + self._release_refs(metadata) + + async def _wait_for_work_slot(self): + while self.work_queue.full(): + await asyncio.sleep(0) + + async def _insert_job(self, x, metadata): + try: + await self._wait_for_work_slot() + coro = self.func(x, *self.args, **self.kwargs) + task = self._create_task(coro) + await self.work_queue.put((task, metadata)) + self._retain_refs(metadata) + except Exception as e: + logger.exception(e) + raise + + @Stream.register_api() class starmap(Stream): """ Apply a function to every element in the stream, splayed out diff --git a/streamz/dataframe/tests/test_dataframes.py b/streamz/dataframe/tests/test_dataframes.py index f8345263..62029e82 100644 --- a/streamz/dataframe/tests/test_dataframes.py +++ b/streamz/dataframe/tests/test_dataframes.py @@ -8,6 +8,7 @@ from dask.dataframe.utils import assert_eq import numpy as np import pandas as pd +from flaky import flaky from tornado import gen from streamz import Stream @@ -570,6 +571,7 @@ def test_cumulative_aggregations(op, getter, stream): assert_eq(pd.concat(L), expected) +@flaky(max_runs=3, min_passes=1) @gen_test() def test_gc(): sdf = sd.Random(freq='5ms', interval='100ms') diff --git a/streamz/tests/test_core.py b/streamz/tests/test_core.py index 96ab9a24..9245f2e6 100644 --- a/streamz/tests/test_core.py +++ b/streamz/tests/test_core.py @@ -126,6 +126,56 @@ def add(x=0, y=0): assert L[0] == 11 +@gen_test() +def test_map_async_tornado(): + @gen.coroutine + def add_tor(x=0, y=0): + return x + y + + async def add_native(x=0, y=0): + await asyncio.sleep(0.1) + return x + y + + source = Stream(asynchronous=True) + L = source.map_async(add_tor, y=1).map_async(add_native, parallelism=2, y=2).buffer(1).sink_to_list() + + start = time() + yield source.emit(0) + yield source.emit(1) + yield source.emit(2) + + def fail_func(): + assert L == [3, 4, 5] + + yield await_for(lambda: L == [3, 4, 5], 1, fail_func=fail_func) + assert (time() - start) == pytest.approx(0.1, abs=4e-3) + + +@pytest.mark.asyncio +async def test_map_async(): + @gen.coroutine + def add_tor(x=0, y=0): + return x + y + + async def add_native(x=0, y=0): + await asyncio.sleep(0.1) + return x + y + + source = Stream(asynchronous=True) + L = source.map_async(add_tor, y=1).map_async(add_native, parallelism=2, y=2).sink_to_list() + + start = time() + await source.emit(0) + await source.emit(1) + await source.emit(2) + + def fail_func(): + assert L == [3, 4, 5] + + await await_for(lambda: L == [3, 4, 5], 1, fail_func=fail_func) + assert (time() - start) == pytest.approx(0.1, abs=4e-3) + + def test_map_args(): source = Stream() L = source.map(operator.add, 10).sink_to_list()