Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 130 additions & 45 deletions conformance/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import argparse
import asyncio
import multiprocessing
import ssl
import sys
import time
import traceback
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import httpx
from _util import create_standard_streams
Expand Down Expand Up @@ -106,6 +107,45 @@ def _unpack_request(message: Any, request: T) -> T:
return request


def _build_tls_context(
server_cert: bytes,
client_cert: bytes | None = None,
client_key: bytes | None = None,
) -> ssl.SSLContext:
ctx = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH, cadata=server_cert.decode()
)
if client_cert is None or client_key is None:
return ctx
with NamedTemporaryFile() as cert_file, NamedTemporaryFile() as key_file:
cert_file.write(client_cert)
cert_file.flush()
key_file.write(client_key)
key_file.flush()
ctx.load_cert_chain(certfile=cert_file.name, keyfile=key_file.name)
return ctx


def _schedule_cancel(task: asyncio.Task, delay_s: float) -> asyncio.Handle:
loop = asyncio.get_running_loop()

def _cancel() -> None:
task.cancel()

return loop.call_later(delay_s, _cancel)


def _schedule_cancel_after_close_send(
task: asyncio.Task, delay_s: float, close_send_event: asyncio.Event
) -> asyncio.Task:
async def _run() -> None:
await close_send_event.wait()
await asyncio.sleep(delay_s)
task.cancel()

return asyncio.create_task(_run())


async def _run_test(
mode: Literal["sync", "async"], test_request: ClientCompatRequest
) -> ClientCompatResponse:
Expand All @@ -125,7 +165,8 @@ async def _run_test(
request_headers.add(header.name, value)

payloads: list[ConformancePayload] = []

close_send_event = asyncio.Event()
loop = asyncio.get_running_loop()
with ResponseMetadata() as meta:
try:
task: asyncio.Task
Expand All @@ -140,22 +181,17 @@ async def _run_test(
scheme = "http"
if test_request.server_tls_cert:
scheme = "https"
ctx = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH,
cadata=test_request.server_tls_cert.decode(),
)
if test_request.HasField("client_tls_creds"):
with (
NamedTemporaryFile() as cert_file,
NamedTemporaryFile() as key_file,
):
cert_file.write(test_request.client_tls_creds.cert)
cert_file.flush()
key_file.write(test_request.client_tls_creds.key)
key_file.flush()
ctx.load_cert_chain(
certfile=cert_file.name, keyfile=key_file.name
)
ctx = await asyncio.to_thread(
_build_tls_context,
test_request.server_tls_cert,
test_request.client_tls_creds.cert,
test_request.client_tls_creds.key,
)
else:
ctx = await asyncio.to_thread(
_build_tls_context, test_request.server_tls_cert
)
session_kwargs["verify"] = ctx
match mode:
case "sync":
Expand Down Expand Up @@ -188,7 +224,7 @@ def send_bidi_stream_request_sync(
num
:= test_request.cancel.after_num_responses
) and len(payloads) >= num:
task.cancel()
loop.call_soon_threadsafe(task.cancel)

def bidi_request_stream_sync():
for message in test_request.request_messages:
Expand All @@ -199,6 +235,13 @@ def bidi_request_stream_sync():
yield _unpack_request(
message, BidiStreamRequest()
)
if test_request.cancel.HasField(
"before_close_send"
):
loop.call_soon_threadsafe(task.cancel)
time.sleep(600)
else:
loop.call_soon_threadsafe(close_send_event.set)

task = asyncio.create_task(
asyncio.to_thread(
Expand Down Expand Up @@ -230,6 +273,13 @@ def request_stream_sync():
yield _unpack_request(
message, ClientStreamRequest()
)
if test_request.cancel.HasField(
"before_close_send"
):
loop.call_soon_threadsafe(task.cancel)
time.sleep(600)
else:
loop.call_soon_threadsafe(close_send_event.set)

task = asyncio.create_task(
asyncio.to_thread(
Expand Down Expand Up @@ -262,6 +312,7 @@ def send_idempotent_unary_request_sync(
),
)
)
close_send_event.set()
case "ServerStream":

def send_server_stream_request_sync(
Expand All @@ -278,7 +329,7 @@ def send_server_stream_request_sync(
num
:= test_request.cancel.after_num_responses
) and len(payloads) >= num:
task.cancel()
loop.call_soon_threadsafe(task.cancel)

task = asyncio.create_task(
asyncio.to_thread(
Expand All @@ -290,6 +341,7 @@ def send_server_stream_request_sync(
),
)
)
close_send_event.set()
case "Unary":

def send_unary_request_sync(
Expand All @@ -313,6 +365,7 @@ def send_unary_request_sync(
),
)
)
close_send_event.set()
case "Unimplemented":
task = asyncio.create_task(
asyncio.to_thread(
Expand All @@ -325,15 +378,21 @@ def send_unary_request_sync(
timeout_ms=timeout_ms,
)
)
close_send_event.set()
case _:
msg = f"Unrecognized method: {test_request.method}"
raise ValueError(msg)
cancel_task: asyncio.Task | None = None
if test_request.cancel.after_close_send_ms:
await asyncio.sleep(
test_request.cancel.after_close_send_ms / 1000.0
delay = test_request.cancel.after_close_send_ms / 1000.0
cancel_task = _schedule_cancel_after_close_send(
task, delay, close_send_event
)
task.cancel()
await task
try:
await task
finally:
if cancel_task is not None:
cancel_task.cancel()
case "async":
async with (
httpx.AsyncClient(**session_kwargs) as session,
Expand Down Expand Up @@ -384,6 +443,8 @@ async def bidi_stream_request():
# a long time. We won't end up sleeping for long since we
# cancelled.
await asyncio.sleep(600)
else:
close_send_event.set()

task = asyncio.create_task(
send_bidi_stream_request(
Expand Down Expand Up @@ -421,6 +482,8 @@ async def client_stream_request():
# a long time. We won't end up sleeping for long since we
# cancelled.
await asyncio.sleep(600)
else:
close_send_event.set()

task = asyncio.create_task(
send_client_stream_request(
Expand Down Expand Up @@ -450,6 +513,7 @@ async def send_idempotent_unary_request(
),
)
)
close_send_event.set()
case "ServerStream":

async def send_server_stream_request(
Expand Down Expand Up @@ -477,6 +541,7 @@ async def send_server_stream_request(
),
)
)
close_send_event.set()
case "Unary":

async def send_unary_request(
Expand All @@ -499,6 +564,7 @@ async def send_unary_request(
),
)
)
close_send_event.set()
case "Unimplemented":
task = asyncio.create_task(
client.unimplemented(
Expand All @@ -510,15 +576,21 @@ async def send_unary_request(
timeout_ms=timeout_ms,
)
)
close_send_event.set()
case _:
msg = f"Unrecognized method: {test_request.method}"
raise ValueError(msg)
cancel_task: asyncio.Task | None = None
if test_request.cancel.after_close_send_ms:
await asyncio.sleep(
test_request.cancel.after_close_send_ms / 1000.0
delay = test_request.cancel.after_close_send_ms / 1000.0
cancel_task = _schedule_cancel_after_close_send(
task, delay, close_send_event
)
task.cancel()
await task
try:
await task
finally:
if cancel_task is not None:
cancel_task.cancel()
except ConnectError as e:
test_response.response.error.code = _convert_code(e.code)
test_response.response.error.message = e.message
Expand All @@ -543,32 +615,45 @@ async def send_unary_request(

class Args(argparse.Namespace):
mode: Literal["sync", "async"]
parallel: int


async def main() -> None:
parser = argparse.ArgumentParser(description="Conformance client")
parser.add_argument("--mode", choices=["sync", "async"])
parser.add_argument("--parallel", type=int, default=multiprocessing.cpu_count() * 4)
args = parser.parse_args(namespace=Args())

stdin, stdout = await create_standard_streams()
while True:
try:
size_buf = await stdin.readexactly(4)
except asyncio.IncompleteReadError:
return
size = int.from_bytes(size_buf, byteorder="big")
# Allow to raise even on EOF since we always should have a message
request_buf = await stdin.readexactly(size)
request = ClientCompatRequest()
request.ParseFromString(request_buf)

response = await _run_test(args.mode, request)

response_buf = response.SerializeToString()
size_buf = len(response_buf).to_bytes(4, byteorder="big")
stdout.write(size_buf)
stdout.write(response_buf)
await stdout.drain()
sema = asyncio.Semaphore(args.parallel)
stdout_lock = asyncio.Lock()
tasks: list[asyncio.Task] = []
try:
while True:
try:
size_buf = await stdin.readexactly(4)
except asyncio.IncompleteReadError:
return
size = int.from_bytes(size_buf, byteorder="big")
# Allow to raise even on EOF since we always should have a message
request_buf = await stdin.readexactly(size)
request = ClientCompatRequest()
request.ParseFromString(request_buf)

async def task(request: ClientCompatRequest) -> None:
async with sema:
response = await _run_test(args.mode, request)

response_buf = response.SerializeToString()
size_buf = len(response_buf).to_bytes(4, byteorder="big")
async with stdout_lock:
stdout.write(size_buf)
stdout.write(response_buf)
await stdout.drain()

tasks.append(asyncio.create_task(task(request)))
finally:
await asyncio.gather(*tasks)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion conformance/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_client_async() -> None:
"client",
*_skipped_tests_async,
"--known-flaky",
"Client Cancellation/**",
"**/cancel-after-responses",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still flaky, I suspect it is an issue in the test harness, not code. I will defer debugging it as already did a lot of debugging and want to get the fixes in

"--",
*args,
],
Expand Down
Loading