diff --git a/conformance/test/client.py b/conformance/test/client.py index e621496..cd6d27c 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -2,12 +2,13 @@ import argparse import asyncio +import multiprocessing import ssl import sys import time import traceback from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypeVar import httpx from _util import create_standard_streams @@ -106,6 +107,45 @@ def _unpack_request(message: Any, request: T) -> T: return request +def _build_tls_context( + server_cert: bytes, + client_cert: bytes | None = None, + client_key: bytes | None = None, +) -> ssl.SSLContext: + ctx = ssl.create_default_context( + purpose=ssl.Purpose.SERVER_AUTH, cadata=server_cert.decode() + ) + if client_cert is None or client_key is None: + return ctx + with NamedTemporaryFile() as cert_file, NamedTemporaryFile() as key_file: + cert_file.write(client_cert) + cert_file.flush() + key_file.write(client_key) + key_file.flush() + ctx.load_cert_chain(certfile=cert_file.name, keyfile=key_file.name) + return ctx + + +def _schedule_cancel(task: asyncio.Task, delay_s: float) -> asyncio.Handle: + loop = asyncio.get_running_loop() + + def _cancel() -> None: + task.cancel() + + return loop.call_later(delay_s, _cancel) + + +def _schedule_cancel_after_close_send( + task: asyncio.Task, delay_s: float, close_send_event: asyncio.Event +) -> asyncio.Task: + async def _run() -> None: + await close_send_event.wait() + await asyncio.sleep(delay_s) + task.cancel() + + return asyncio.create_task(_run()) + + async def _run_test( mode: Literal["sync", "async"], test_request: ClientCompatRequest ) -> ClientCompatResponse: @@ -125,7 +165,8 @@ async def _run_test( request_headers.add(header.name, value) payloads: list[ConformancePayload] = [] - + close_send_event = asyncio.Event() + loop = asyncio.get_running_loop() with ResponseMetadata() as meta: try: task: asyncio.Task @@ -140,22 +181,17 @@ async def _run_test( scheme = "http" if test_request.server_tls_cert: scheme = "https" - ctx = ssl.create_default_context( - purpose=ssl.Purpose.SERVER_AUTH, - cadata=test_request.server_tls_cert.decode(), - ) if test_request.HasField("client_tls_creds"): - with ( - NamedTemporaryFile() as cert_file, - NamedTemporaryFile() as key_file, - ): - cert_file.write(test_request.client_tls_creds.cert) - cert_file.flush() - key_file.write(test_request.client_tls_creds.key) - key_file.flush() - ctx.load_cert_chain( - certfile=cert_file.name, keyfile=key_file.name - ) + ctx = await asyncio.to_thread( + _build_tls_context, + test_request.server_tls_cert, + test_request.client_tls_creds.cert, + test_request.client_tls_creds.key, + ) + else: + ctx = await asyncio.to_thread( + _build_tls_context, test_request.server_tls_cert + ) session_kwargs["verify"] = ctx match mode: case "sync": @@ -188,7 +224,7 @@ def send_bidi_stream_request_sync( num := test_request.cancel.after_num_responses ) and len(payloads) >= num: - task.cancel() + loop.call_soon_threadsafe(task.cancel) def bidi_request_stream_sync(): for message in test_request.request_messages: @@ -199,6 +235,13 @@ def bidi_request_stream_sync(): yield _unpack_request( message, BidiStreamRequest() ) + if test_request.cancel.HasField( + "before_close_send" + ): + loop.call_soon_threadsafe(task.cancel) + time.sleep(600) + else: + loop.call_soon_threadsafe(close_send_event.set) task = asyncio.create_task( asyncio.to_thread( @@ -230,6 +273,13 @@ def request_stream_sync(): yield _unpack_request( message, ClientStreamRequest() ) + if test_request.cancel.HasField( + "before_close_send" + ): + loop.call_soon_threadsafe(task.cancel) + time.sleep(600) + else: + loop.call_soon_threadsafe(close_send_event.set) task = asyncio.create_task( asyncio.to_thread( @@ -262,6 +312,7 @@ def send_idempotent_unary_request_sync( ), ) ) + close_send_event.set() case "ServerStream": def send_server_stream_request_sync( @@ -278,7 +329,7 @@ def send_server_stream_request_sync( num := test_request.cancel.after_num_responses ) and len(payloads) >= num: - task.cancel() + loop.call_soon_threadsafe(task.cancel) task = asyncio.create_task( asyncio.to_thread( @@ -290,6 +341,7 @@ def send_server_stream_request_sync( ), ) ) + close_send_event.set() case "Unary": def send_unary_request_sync( @@ -313,6 +365,7 @@ def send_unary_request_sync( ), ) ) + close_send_event.set() case "Unimplemented": task = asyncio.create_task( asyncio.to_thread( @@ -325,15 +378,21 @@ def send_unary_request_sync( timeout_ms=timeout_ms, ) ) + close_send_event.set() case _: msg = f"Unrecognized method: {test_request.method}" raise ValueError(msg) + cancel_task: asyncio.Task | None = None if test_request.cancel.after_close_send_ms: - await asyncio.sleep( - test_request.cancel.after_close_send_ms / 1000.0 + delay = test_request.cancel.after_close_send_ms / 1000.0 + cancel_task = _schedule_cancel_after_close_send( + task, delay, close_send_event ) - task.cancel() - await task + try: + await task + finally: + if cancel_task is not None: + cancel_task.cancel() case "async": async with ( httpx.AsyncClient(**session_kwargs) as session, @@ -384,6 +443,8 @@ async def bidi_stream_request(): # a long time. We won't end up sleeping for long since we # cancelled. await asyncio.sleep(600) + else: + close_send_event.set() task = asyncio.create_task( send_bidi_stream_request( @@ -421,6 +482,8 @@ async def client_stream_request(): # a long time. We won't end up sleeping for long since we # cancelled. await asyncio.sleep(600) + else: + close_send_event.set() task = asyncio.create_task( send_client_stream_request( @@ -450,6 +513,7 @@ async def send_idempotent_unary_request( ), ) ) + close_send_event.set() case "ServerStream": async def send_server_stream_request( @@ -477,6 +541,7 @@ async def send_server_stream_request( ), ) ) + close_send_event.set() case "Unary": async def send_unary_request( @@ -499,6 +564,7 @@ async def send_unary_request( ), ) ) + close_send_event.set() case "Unimplemented": task = asyncio.create_task( client.unimplemented( @@ -510,15 +576,21 @@ async def send_unary_request( timeout_ms=timeout_ms, ) ) + close_send_event.set() case _: msg = f"Unrecognized method: {test_request.method}" raise ValueError(msg) + cancel_task: asyncio.Task | None = None if test_request.cancel.after_close_send_ms: - await asyncio.sleep( - test_request.cancel.after_close_send_ms / 1000.0 + delay = test_request.cancel.after_close_send_ms / 1000.0 + cancel_task = _schedule_cancel_after_close_send( + task, delay, close_send_event ) - task.cancel() - await task + try: + await task + finally: + if cancel_task is not None: + cancel_task.cancel() except ConnectError as e: test_response.response.error.code = _convert_code(e.code) test_response.response.error.message = e.message @@ -543,32 +615,45 @@ async def send_unary_request( class Args(argparse.Namespace): mode: Literal["sync", "async"] + parallel: int async def main() -> None: parser = argparse.ArgumentParser(description="Conformance client") parser.add_argument("--mode", choices=["sync", "async"]) + parser.add_argument("--parallel", type=int, default=multiprocessing.cpu_count() * 4) args = parser.parse_args(namespace=Args()) stdin, stdout = await create_standard_streams() - while True: - try: - size_buf = await stdin.readexactly(4) - except asyncio.IncompleteReadError: - return - size = int.from_bytes(size_buf, byteorder="big") - # Allow to raise even on EOF since we always should have a message - request_buf = await stdin.readexactly(size) - request = ClientCompatRequest() - request.ParseFromString(request_buf) - - response = await _run_test(args.mode, request) - - response_buf = response.SerializeToString() - size_buf = len(response_buf).to_bytes(4, byteorder="big") - stdout.write(size_buf) - stdout.write(response_buf) - await stdout.drain() + sema = asyncio.Semaphore(args.parallel) + stdout_lock = asyncio.Lock() + tasks: list[asyncio.Task] = [] + try: + while True: + try: + size_buf = await stdin.readexactly(4) + except asyncio.IncompleteReadError: + return + size = int.from_bytes(size_buf, byteorder="big") + # Allow to raise even on EOF since we always should have a message + request_buf = await stdin.readexactly(size) + request = ClientCompatRequest() + request.ParseFromString(request_buf) + + async def task(request: ClientCompatRequest) -> None: + async with sema: + response = await _run_test(args.mode, request) + + response_buf = response.SerializeToString() + size_buf = len(response_buf).to_bytes(4, byteorder="big") + async with stdout_lock: + stdout.write(size_buf) + stdout.write(response_buf) + await stdout.drain() + + tasks.append(asyncio.create_task(task(request))) + finally: + await asyncio.gather(*tasks) if __name__ == "__main__": diff --git a/conformance/test/test_client.py b/conformance/test/test_client.py index 22426ca..46012ad 100644 --- a/conformance/test/test_client.py +++ b/conformance/test/test_client.py @@ -82,7 +82,7 @@ def test_client_async() -> None: "client", *_skipped_tests_async, "--known-flaky", - "Client Cancellation/**", + "**/cancel-after-responses", "--", *args, ], diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index d8fcfc3..a520682 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -1,15 +1,15 @@ from __future__ import annotations import asyncio +import contextlib import functools -from asyncio import CancelledError, sleep, wait_for -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from asyncio import CancelledError, wait_for +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast import httpx from httpx import USE_CLIENT_DEFAULT, Timeout from . import _client_shared -from ._asyncio_timeout import timeout as asyncio_timeout from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec from ._envelope import EnvelopeReader from ._interceptor_async import ( @@ -29,13 +29,6 @@ from .code import Code from .errors import ConnectError -try: - from asyncio import ( - timeout as asyncio_timeout, # pyright: ignore[reportAttributeAccessIssue] - ) -except ImportError: - from ._asyncio_timeout import timeout as asyncio_timeout - if TYPE_CHECKING: import sys from collections.abc import AsyncIterator, Iterable, Mapping @@ -276,65 +269,95 @@ async def _send_request_unary( timeout_s = None timeout = USE_CLIENT_DEFAULT - try: + result_queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1) + + async def _do_request() -> None: request_data = self._codec.encode(request) if self._send_compression: request_data = self._send_compression.compress(request_data) - if ctx.http_method() == "GET": - params = _client_shared.prepare_get_params( - self._codec, request_data, request_headers - ) - request_headers.pop("content-type", None) - resp = await wait_for( - self._session.get( - url=url, headers=request_headers, params=params, timeout=timeout - ), - timeout_s, - ) - else: - resp = await wait_for( - self._session.post( + try: + if ctx.http_method() == "GET": + params = _client_shared.prepare_get_params( + self._codec, request_data, request_headers + ) + request_headers.pop("content-type", None) + httpx_request = self._session.build_request( + method="GET", + url=url, + headers=request_headers, + params=params, + timeout=timeout, + ) + else: + httpx_request = self._session.build_request( + method="POST", url=url, headers=request_headers, content=request_data, timeout=timeout, - ), - timeout_s, - ) - - _client_shared.validate_response_content_encoding( - resp.headers.get("content-encoding", "") - ) - _client_shared.validate_response_content_type( - self._codec.name(), - resp.status_code, - resp.headers.get("content-type", ""), - ) - handle_response_headers(resp.headers) - - if resp.status_code == 200: - if ( - self._read_max_bytes is not None - and len(resp.content) > self._read_max_bytes - ): - raise ConnectError( - Code.RESOURCE_EXHAUSTED, - f"message is larger than configured max {self._read_max_bytes}", ) - response = ctx.method().output() - self._codec.decode(resp.content, response) - return response - raise ConnectWireError.from_response(resp).to_exception() + resp = await wait_for(self._session.send(httpx_request), timeout_s) + + _client_shared.validate_response_content_encoding( + resp.headers.get("content-encoding", "") + ) + _client_shared.validate_response_content_type( + self._codec.name(), + resp.status_code, + resp.headers.get("content-type", ""), + ) + handle_response_headers(resp.headers) + + if resp.status_code == 200: + if ( + self._read_max_bytes is not None + and len(resp.content) > self._read_max_bytes + ): + raise ConnectError( + Code.RESOURCE_EXHAUSTED, + f"message is larger than configured max {self._read_max_bytes}", + ) + + response = ctx.method().output() + self._codec.decode(resp.content, response) + result_queue.put_nowait(response) + return + raise ConnectWireError.from_response(resp).to_exception() + except BaseException as exc: + if result_queue.empty(): + result_queue.put_nowait(exc) + raise + + task = asyncio.create_task(_do_request()) + task.add_done_callback(_consume_task_result) + try: + try: + if timeout_s is None: + item = await result_queue.get() + else: + item = await asyncio.wait_for(result_queue.get(), timeout_s) + except asyncio.TimeoutError: + if not task.done(): + task.cancel() + raise + if isinstance(item, BaseException): + raise item + return cast("RES", item) except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e except ConnectError: raise except CancelledError as e: + if not task.done(): + task.cancel() raise ConnectError(Code.CANCELED, "Request was cancelled") from e except Exception as e: raise ConnectError(Code.UNAVAILABLE, str(e)) from e + finally: + if not task.done(): + task.cancel() async def _send_request_client_stream( self, request: AsyncIterator[REQ], ctx: RequestContext[REQ, RES] @@ -360,21 +383,29 @@ async def _send_request_bidi_stream( timeout_s = None timeout = USE_CLIENT_DEFAULT - try: - request_data = _streaming_request_content( - request, self._codec, self._send_compression - ) + loop = asyncio.get_running_loop() + deadline = None if timeout_s is None else loop.time() + timeout_s + + queue: asyncio.Queue[object] = asyncio.Queue() + sentinel = object() + + async def _produce() -> None: + resp = None + try: + request_data = _streaming_request_content( + request, self._codec, self._send_compression + ) - async with ( - asyncio_timeout(timeout_s), - self._session.stream( + # Use build_request + send to avoid AsyncContextManager which + # has issues in cleanup during cancellation. + httpx_request = self._session.build_request( method="POST", url=url, headers=request_headers, content=request_data, timeout=timeout, - ) as resp, - ): + ) + resp = await self._session.send(httpx_request, stream=True) compression = _client_shared.validate_response_content_encoding( resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "") ) @@ -392,20 +423,50 @@ async def _send_request_bidi_stream( ) async for chunk in resp.aiter_bytes(): for message in reader.feed(chunk): - yield message - # Check for cancellation each message. While this seems heavyweight, - # conformance tests require it. - await sleep(0) + await queue.put(message) else: raise ConnectWireError.from_response(resp).to_exception() + except Exception as exc: + queue.put_nowait(exc) + finally: + if resp is not None: + with contextlib.suppress(Exception): + await asyncio.shield(resp.aclose()) + queue.put_nowait(sentinel) + + producer = asyncio.create_task(_produce()) + producer.add_done_callback(_consume_task_result) + try: + while True: + try: + if deadline is None: + item = await queue.get() + else: + remaining = deadline - loop.time() + if remaining <= 0: + raise asyncio.TimeoutError + item = await asyncio.wait_for(queue.get(), remaining) + except asyncio.TimeoutError: + if not producer.done(): + producer.cancel() + raise + if item is sentinel: + break + if isinstance(item, Exception): + raise item + yield cast("RES", item) except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e: raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e except ConnectError: raise except CancelledError as e: + producer.cancel() raise ConnectError(Code.CANCELED, "Request was cancelled") from e except Exception as e: raise ConnectError(Code.UNAVAILABLE, str(e)) from e + finally: + if not producer.done(): + producer.cancel() def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: @@ -418,6 +479,11 @@ def _convert_connect_timeout(timeout_ms: float | None) -> Timeout: return Timeout(None) +def _consume_task_result(task: asyncio.Task[Any]) -> None: + with contextlib.suppress(BaseException): + task.result() + + async def _streaming_request_content( msgs: AsyncIterator[Any], codec: Codec, compression: Compression | None ) -> AsyncIterator[bytes]: diff --git a/test/test_errors.py b/test/test_errors.py index 03b21b3..9b48253 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -353,7 +353,7 @@ def make_hat(self, request, ctx) -> NoReturn: @pytest.mark.parametrize( - ("client_timeout_ms", "call_timeout_ms"), [(200, None), (None, 200)] + ("client_timeout_ms", "call_timeout_ms"), [(50, None), (None, 50)] ) def test_sync_client_timeout( client_timeout_ms, call_timeout_ms, timeout_server: str @@ -385,12 +385,12 @@ def modify_timeout_header(request: Request) -> None: assert exc_info.value.code == Code.DEADLINE_EXCEEDED assert exc_info.value.message == "Request timed out" - assert recorded_timeout_header == "200" + assert recorded_timeout_header == "50" @pytest.mark.asyncio @pytest.mark.parametrize( - ("client_timeout_ms", "call_timeout_ms"), [(200, None), (None, 200)] + ("client_timeout_ms", "call_timeout_ms"), [(50, None), (None, 50)] ) async def test_async_client_timeout( client_timeout_ms, call_timeout_ms, timeout_server: str @@ -416,4 +416,4 @@ async def modify_timeout_header(request: Request) -> None: assert exc_info.value.code == Code.DEADLINE_EXCEEDED assert exc_info.value.message == "Request timed out" - assert recorded_timeout_header == "200" + assert recorded_timeout_header == "50"