Skip to content

Commit d46b70f

Browse files
feat(client): add event handler implementation for websockets
1 parent 41853cb commit d46b70f

2 files changed

Lines changed: 233 additions & 3 deletions

File tree

src/dedalus_sdk/_event_handler.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2+
3+
from __future__ import annotations
4+
5+
import threading
6+
from typing import Any, Callable
7+
8+
EventHandler = Callable[..., Any]
9+
10+
11+
class EventHandlerRegistry:
12+
"""Thread-safe (optional) registry of event handlers."""
13+
14+
def __init__(self, *, use_lock: bool = False) -> None:
15+
self._handlers: dict[str, list[EventHandler]] = {}
16+
self._once_ids: set[int] = set()
17+
self._lock: threading.Lock | None = threading.Lock() if use_lock else None
18+
19+
def _acquire(self) -> None:
20+
if self._lock is not None:
21+
self._lock.acquire()
22+
23+
def _release(self) -> None:
24+
if self._lock is not None:
25+
self._lock.release()
26+
27+
def add(self, event_type: str, handler: EventHandler, *, once: bool = False) -> None:
28+
self._acquire()
29+
try:
30+
handlers = self._handlers.setdefault(event_type, [])
31+
handlers.append(handler)
32+
if once:
33+
self._once_ids.add(id(handler))
34+
finally:
35+
self._release()
36+
37+
def remove(self, event_type: str, handler: EventHandler) -> None:
38+
self._acquire()
39+
try:
40+
handlers = self._handlers.get(event_type)
41+
if handlers is not None:
42+
try:
43+
handlers.remove(handler)
44+
except ValueError:
45+
pass
46+
self._once_ids.discard(id(handler))
47+
finally:
48+
self._release()
49+
50+
def get_handlers(self, event_type: str) -> list[EventHandler]:
51+
"""Return a snapshot of handlers for the given event type, removing once-handlers."""
52+
self._acquire()
53+
try:
54+
handlers = self._handlers.get(event_type)
55+
if not handlers:
56+
return []
57+
result = list(handlers)
58+
to_remove = [h for h in result if id(h) in self._once_ids]
59+
for h in to_remove:
60+
handlers.remove(h)
61+
self._once_ids.discard(id(h))
62+
return result
63+
finally:
64+
self._release()
65+
66+
def has_handlers(self, event_type: str) -> bool:
67+
self._acquire()
68+
try:
69+
handlers = self._handlers.get(event_type)
70+
return bool(handlers)
71+
finally:
72+
self._release()

src/dedalus_sdk/resources/machines/terminals.py

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import random
88
import logging
99
from types import TracebackType
10-
from typing import TYPE_CHECKING, Any, Dict, Callable, Iterator, Awaitable, cast
10+
from typing import TYPE_CHECKING, Any, Dict, Union, Callable, Iterator, Awaitable, cast
1111
from typing_extensions import AsyncIterator
1212

1313
import httpx
@@ -27,10 +27,12 @@
2727
from ...pagination import SyncCursorPage, AsyncCursorPage
2828
from ..._exceptions import DedalusError
2929
from ..._base_client import AsyncPaginator, _merge_mappings, make_request_options
30+
from ..._event_handler import EventHandlerRegistry
3031
from ...types.machines import terminal_list_params, terminal_create_params
3132
from ...types.machines.terminal import Terminal
3233
from ...types.websocket_reconnection import ReconnectingEvent, ReconnectingOverrides, is_recoverable_close
3334
from ...types.websocket_connection_options import WebSocketConnectionOptions
35+
from ...types.machines.terminal_error_event import TerminalErrorEvent
3436
from ...types.machines.terminal_client_event import TerminalClientEvent
3537
from ...types.machines.terminal_server_event import TerminalServerEvent
3638
from ...types.machines.terminal_client_event_param import TerminalClientEventParam
@@ -606,6 +608,7 @@ def __init__(
606608
self._extra_query = extra_query
607609
self._extra_headers = extra_headers
608610
self._intentionally_closed = False
611+
self._event_handler_registry = EventHandlerRegistry(use_lock=False)
609612

610613
async def __aiter__(self) -> AsyncIterator[TerminalServerEvent]:
611614
"""
@@ -736,6 +739,86 @@ async def _reconnect(self, exc: Exception) -> bool:
736739

737740
return False
738741

742+
def on(
743+
self, event_type: str, handler: Callable[..., Any] | None = None
744+
) -> Union[AsyncTerminalsResourceConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
745+
"""Adds the handler to the end of the handlers list for the given event type.
746+
747+
No checks are made to see if the handler has already been added. Multiple calls
748+
passing the same combination of event type and handler will result in the handler
749+
being added, and called, multiple times.
750+
751+
Can be used as a method (returns ``self`` for chaining)::
752+
753+
connection.on("output", my_handler)
754+
755+
Or as a decorator::
756+
757+
@connection.on("output")
758+
async def my_handler(event): ...
759+
"""
760+
if handler is not None:
761+
self._event_handler_registry.add(event_type, handler)
762+
return self
763+
764+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
765+
self._event_handler_registry.add(event_type, fn)
766+
return fn
767+
768+
return decorator
769+
770+
def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncTerminalsResourceConnection:
771+
"""Remove a previously registered event handler."""
772+
self._event_handler_registry.remove(event_type, handler)
773+
return self
774+
775+
def once(
776+
self, event_type: str, handler: Callable[..., Any] | None = None
777+
) -> Union[AsyncTerminalsResourceConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
778+
"""Register a one-time event handler.
779+
780+
Automatically removed after first invocation.
781+
"""
782+
if handler is not None:
783+
self._event_handler_registry.add(event_type, handler, once=True)
784+
return self
785+
786+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
787+
self._event_handler_registry.add(event_type, fn, once=True)
788+
return fn
789+
790+
return decorator
791+
792+
async def dispatch_events(self) -> None:
793+
"""Run the event loop, dispatching received events to registered handlers.
794+
795+
Blocks until the connection is closed. This is the push-based
796+
alternative to iterating with ``async for event in connection``.
797+
798+
If an ``"error"`` event arrives and no handler is registered for
799+
``"error"`` or ``"event"``, an ``DedalusError`` is raised.
800+
"""
801+
import asyncio
802+
803+
async for event in self:
804+
event_type = event.type
805+
specific = self._event_handler_registry.get_handlers(event_type)
806+
generic = self._event_handler_registry.get_handlers("event")
807+
808+
if event_type == "error" and not specific and not generic:
809+
if isinstance(event, TerminalErrorEvent):
810+
raise DedalusError(f"WebSocket error: {event}")
811+
812+
for handler in specific:
813+
result = handler(event)
814+
if asyncio.iscoroutine(result):
815+
await result
816+
817+
for handler in generic:
818+
result = handler(event)
819+
if asyncio.iscoroutine(result):
820+
await result
821+
739822

740823
class AsyncTerminalsResourceConnectionManager:
741824
"""
@@ -785,7 +868,7 @@ def __init__(
785868

786869
async def __aenter__(self) -> AsyncTerminalsResourceConnection:
787870
"""
788-
👋 If your application doesn't work well with the context manager approach then you
871+
If your application doesn't work well with the context manager approach then you
789872
can call this method directly to initiate a connection.
790873
791874
**Warning**: You must remember to close the connection with `.close()`.
@@ -893,6 +976,7 @@ def __init__(
893976
self._extra_query = extra_query
894977
self._extra_headers = extra_headers
895978
self._intentionally_closed = False
979+
self._event_handler_registry = EventHandlerRegistry(use_lock=True)
896980

897981
def __iter__(self) -> Iterator[TerminalServerEvent]:
898982
"""
@@ -1021,6 +1105,80 @@ def _reconnect(self, exc: Exception) -> bool:
10211105

10221106
return False
10231107

1108+
def on(
1109+
self, event_type: str, handler: Callable[..., Any] | None = None
1110+
) -> Union[TerminalsResourceConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
1111+
"""Adds the handler to the end of the handlers list for the given event type.
1112+
1113+
No checks are made to see if the handler has already been added. Multiple calls
1114+
passing the same combination of event type and handler will result in the handler
1115+
being added, and called, multiple times.
1116+
1117+
Can be used as a method (returns ``self`` for chaining)::
1118+
1119+
connection.on("output", my_handler)
1120+
1121+
Or as a decorator::
1122+
1123+
@connection.on("output")
1124+
def my_handler(event): ...
1125+
"""
1126+
if handler is not None:
1127+
self._event_handler_registry.add(event_type, handler)
1128+
return self
1129+
1130+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
1131+
self._event_handler_registry.add(event_type, fn)
1132+
return fn
1133+
1134+
return decorator
1135+
1136+
def off(self, event_type: str, handler: Callable[..., Any]) -> TerminalsResourceConnection:
1137+
"""Remove a previously registered event handler."""
1138+
self._event_handler_registry.remove(event_type, handler)
1139+
return self
1140+
1141+
def once(
1142+
self, event_type: str, handler: Callable[..., Any] | None = None
1143+
) -> Union[TerminalsResourceConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]:
1144+
"""Register a one-time event handler.
1145+
1146+
Automatically removed after first invocation.
1147+
"""
1148+
if handler is not None:
1149+
self._event_handler_registry.add(event_type, handler, once=True)
1150+
return self
1151+
1152+
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
1153+
self._event_handler_registry.add(event_type, fn, once=True)
1154+
return fn
1155+
1156+
return decorator
1157+
1158+
def dispatch_events(self) -> None:
1159+
"""Run the event loop, dispatching received events to registered handlers.
1160+
1161+
Blocks the current thread until the connection is closed. This is the push-based
1162+
alternative to iterating with ``for event in connection``.
1163+
1164+
If an ``"error"`` event arrives and no handler is registered for
1165+
``"error"`` or ``"event"``, an ``DedalusError`` is raised.
1166+
"""
1167+
for event in self:
1168+
event_type = event.type
1169+
specific = self._event_handler_registry.get_handlers(event_type)
1170+
generic = self._event_handler_registry.get_handlers("event")
1171+
1172+
if event_type == "error" and not specific and not generic:
1173+
if isinstance(event, TerminalErrorEvent):
1174+
raise DedalusError(f"WebSocket error: {event}")
1175+
1176+
for handler in specific:
1177+
handler(event)
1178+
1179+
for handler in generic:
1180+
handler(event)
1181+
10241182

10251183
class TerminalsResourceConnectionManager:
10261184
"""
@@ -1070,7 +1228,7 @@ def __init__(
10701228

10711229
def __enter__(self) -> TerminalsResourceConnection:
10721230
"""
1073-
👋 If your application doesn't work well with the context manager approach then you
1231+
If your application doesn't work well with the context manager approach then you
10741232
can call this method directly to initiate a connection.
10751233
10761234
**Warning**: You must remember to close the connection with `.close()`.

0 commit comments

Comments
 (0)