Skip to content

Commit e5652fb

Browse files
authored
feat: 重构事件处理部分的代码,并行处理 get 函数和插件,提升处理性能 (#168)
BREAKING CHANGE: 一个事件将可能会被多个 get 函数捕获,请确保 get 函数没有副作用
1 parent 6d8e340 commit e5652fb

12 files changed

Lines changed: 2150 additions & 1879 deletions

File tree

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
// Pytest
3-
"python.testing.pytestArgs": ["tests"],
3+
"python.testing.pytestArgs": ["-s"],
44
"python.testing.unittestEnabled": false,
55
"python.testing.pytestEnabled": true
66
}

alicebot/adapter/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ async def get(
113113
event_type: None = None,
114114
max_try_times: Optional[int] = None,
115115
timeout: Optional[Union[int, float]] = None,
116+
to_thread: bool = False,
116117
) -> EventT: ...
117118

118119
@overload
@@ -123,6 +124,7 @@ async def get(
123124
event_type: type[_EventT],
124125
max_try_times: Optional[int] = None,
125126
timeout: Optional[Union[int, float]] = None,
127+
to_thread: bool = False,
126128
) -> _EventT: ...
127129

128130
@final
@@ -133,6 +135,7 @@ async def get(
133135
event_type: Any = None,
134136
max_try_times: Optional[int] = None,
135137
timeout: Optional[Union[int, float]] = None,
138+
to_thread: bool = False,
136139
) -> Event[Any]:
137140
"""获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。
138141
@@ -147,6 +150,7 @@ async def get(
147150
event_type: 当指定时,只接受指定类型的事件,先于 func 条件生效。默认为 `None`。
148151
max_try_times: 最大事件数。
149152
timeout: 超时时间。
153+
to_thread: 是否在独立的线程中运行同步函数。仅当 func 为同步函数时生效。默认为 `False`。
150154
151155
Returns:
152156
返回满足 func 条件的事件。
@@ -160,4 +164,5 @@ async def get(
160164
adapter_type=type(self),
161165
max_try_times=max_try_times,
162166
timeout=timeout,
167+
to_thread=to_thread,
163168
)

alicebot/bot.py

Lines changed: 94 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import signal
99
import sys
1010
import threading
11-
import time
1211
from collections import defaultdict
1312
from collections.abc import Awaitable
1413
from contextlib import AsyncExitStack
@@ -18,28 +17,23 @@
1817

1918
import anyio
2019
import structlog
21-
from anyio.abc import TaskStatus
2220
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2321
from pydantic import ValidationError, create_model
2422

2523
from alicebot.adapter import Adapter
2624
from alicebot.config import AdapterConfig, ConfigModel, MainConfig, PluginConfig
2725
from alicebot.dependencies import solve_dependencies
2826
from alicebot.event import Event, EventHandleOption
29-
from alicebot.exceptions import (
30-
GetEventTimeout,
31-
LoadModuleError,
32-
SkipException,
33-
StopException,
34-
)
27+
from alicebot.exceptions import LoadModuleError, SkipException, StopException
28+
from alicebot.matcher import EventMatcher
3529
from alicebot.plugin import Plugin, PluginLoadType
3630
from alicebot.typing import AdapterHook, AdapterT, BotHook, EventHook, EventT
3731
from alicebot.utils import (
3832
ModulePathFinder,
33+
async_map,
3934
get_classes_from_module_name,
4035
is_config_class,
4136
samefile,
42-
wrap_get_func,
4337
)
4438

4539
if sys.version_info >= (3, 11): # pragma: no cover
@@ -79,8 +73,7 @@ class Bot:
7973

8074
_event_send_stream: MemoryObjectSendStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable]
8175
_event_receive_stream: MemoryObjectReceiveStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable]
82-
_condition: anyio.Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable]
83-
_current_event: Optional[Event[Any]] # 当前待处理的 Event
76+
_event_matchers: list[EventMatcher] # pyright: ignore[reportUninitializedInstanceVariable]
8477

8578
_should_exit: anyio.Event # 机器人是否应该进入准备退出状态 # pyright: ignore[reportUninitializedInstanceVariable]
8679
_restart_flag: bool # 重新启动标志
@@ -135,7 +128,6 @@ def __init__(
135128
self.global_state = {}
136129

137130
self.adapters = []
138-
self._current_event = None
139131
self._restart_flag = False
140132
self._module_path_finder = ModulePathFinder()
141133
self._raw_config_dict = {}
@@ -200,7 +192,7 @@ async def run_async(self) -> None:
200192
async def _init(self) -> None:
201193
"""初始化 AliceBot。"""
202194
self._should_exit = anyio.Event()
203-
self._condition = anyio.Condition()
195+
self._event_matchers = []
204196
self._event_send_stream, self._event_receive_stream = (
205197
anyio.create_memory_object_stream(
206198
max_buffer_size=self.config.bot.event_queue_size
@@ -504,74 +496,84 @@ async def handle_event(
504496
async def _handle_event_receive(self) -> None:
505497
async with anyio.create_task_group() as tg, self._event_receive_stream:
506498
async for current_event, handle_get in self._event_receive_stream:
507-
if handle_get:
508-
await tg.start(self._handle_event_wait_condition)
509-
async with self._condition:
510-
self._current_event = current_event
511-
self._condition.notify_all()
512-
else:
513-
tg.start_soon(self._handle_event, current_event)
499+
tg.start_soon(self._handle_event, current_event, handle_get)
500+
501+
async def _handle_event(self, current_event: Event[Any], handle_get: bool) -> None:
502+
async with anyio.create_task_group() as tg:
503+
if handle_get:
504+
event_handled = False
505+
new_event_matchers: list[EventMatcher] = []
506+
async for event_matcher, result in async_map(
507+
tg,
508+
lambda x: x.run(current_event),
509+
self._event_matchers.copy(),
510+
):
511+
if result is None:
512+
# 当前 event_matcher 已经失效,什么都不做
513+
pass
514+
elif result is True:
515+
# 当前 event_matcher 成功匹配事件,设置 event_handled 为 True
516+
event_handled = True
517+
elif result is False:
518+
# 当前 event_matcher 未成功匹配事件,将其放回队列中,等待下次处理
519+
new_event_matchers.append(event_matcher)
520+
self._event_matchers = new_event_matchers
521+
if event_handled:
522+
return
523+
524+
for event_preprocessor_hook_func in self._event_preprocessor_hooks:
525+
await event_preprocessor_hook_func(current_event)
526+
527+
for plugin_priority in sorted(self.plugins_priority_dict.keys()):
528+
logger.debug("Checking for matching plugins", priority=plugin_priority)
529+
stop = False
530+
async for _plugin_class, should_stop in async_map(
531+
tg,
532+
lambda x: self._run_plugin(x, current_event),
533+
self.plugins_priority_dict[plugin_priority],
534+
):
535+
stop = stop or should_stop
536+
if stop:
537+
break
514538

515-
async def _handle_event_wait_condition(
516-
self, *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
517-
) -> None:
518-
async with self._condition:
519-
task_status.started()
520-
await self._condition.wait()
521-
assert self._current_event is not None
522-
current_event = self._current_event
523-
await self._handle_event(current_event)
524-
525-
async def _handle_event(self, current_event: Event[Any]) -> None:
526-
if current_event.__handled__:
527-
return
539+
for event_postprocessor_hook_func in self._event_postprocessor_hooks:
540+
await event_postprocessor_hook_func(current_event)
528541

529-
for _hook_func in self._event_preprocessor_hooks:
530-
await _hook_func(current_event)
542+
logger.info("Event Finished")
531543

532-
for plugin_priority in sorted(self.plugins_priority_dict.keys()):
533-
logger.debug("Checking for matching plugins", priority=plugin_priority)
534-
stop = False
535-
for plugin in self.plugins_priority_dict[plugin_priority]:
544+
async def _run_plugin(
545+
self, plugin_class: type[Plugin[Any, Any, Any]], event: Event[Any]
546+
) -> bool:
547+
try:
548+
async with AsyncExitStack() as stack:
549+
plugin_instance = await solve_dependencies(
550+
plugin_class,
551+
use_cache=True,
552+
stack=stack,
553+
dependency_cache={Bot: self, Event: event},
554+
)
555+
if plugin_instance.name not in self.plugin_state:
556+
plugin_state = plugin_instance.__init_state__()
557+
if plugin_state is not None:
558+
self.plugin_state[plugin_instance.name] = plugin_state
536559
try:
537-
async with AsyncExitStack() as stack:
538-
_plugin = await solve_dependencies(
539-
plugin,
540-
use_cache=True,
541-
stack=stack,
542-
dependency_cache={
543-
Bot: self,
544-
Event: current_event,
545-
},
560+
if await plugin_instance.rule():
561+
logger.info(
562+
"Event will be handled by plugin", plugin=plugin_instance
546563
)
547-
if _plugin.name not in self.plugin_state:
548-
plugin_state = _plugin.__init_state__()
549-
if plugin_state is not None:
550-
self.plugin_state[_plugin.name] = plugin_state
551-
if await _plugin.rule():
552-
logger.info(
553-
"Event will be handled by plugin", plugin=_plugin
554-
)
555-
try:
556-
await _plugin.handle()
557-
finally:
558-
if _plugin.block:
559-
stop = True
560-
except SkipException:
561-
# 插件要求跳过自身继续当前事件传播
562-
continue
563-
except StopException:
564-
# 插件要求停止当前事件传播
565-
stop = True
566-
except Exception:
567-
logger.exception("Exception in plugin", plugin=plugin)
568-
if stop:
569-
break
570-
571-
for _hook_func in self._event_postprocessor_hooks:
572-
await _hook_func(current_event)
573-
574-
logger.info("Event Finished")
564+
await plugin_instance.handle()
565+
finally:
566+
if plugin_instance.block:
567+
raise StopException
568+
except SkipException:
569+
# 插件要求跳过自身继续当前事件传播
570+
pass
571+
except StopException:
572+
# 插件要求停止当前事件传播
573+
return True
574+
except Exception:
575+
logger.exception("Exception in plugin", plugin=plugin_class)
576+
return False
575577

576578
@overload
577579
async def get(
@@ -582,6 +584,7 @@ async def get(
582584
adapter_type: None = None,
583585
max_try_times: Optional[int] = None,
584586
timeout: Optional[Union[int, float]] = None,
587+
to_thread: bool = False,
585588
) -> Event[Any]: ...
586589

587590
@overload
@@ -593,6 +596,7 @@ async def get(
593596
adapter_type: type[Adapter[EventT, Any]],
594597
max_try_times: Optional[int] = None,
595598
timeout: Optional[Union[int, float]] = None,
599+
to_thread: bool = False,
596600
) -> EventT: ...
597601

598602
@overload
@@ -604,6 +608,7 @@ async def get(
604608
adapter_type: Optional[type[Adapter[Any, Any]]] = None,
605609
max_try_times: Optional[int] = None,
606610
timeout: Optional[Union[int, float]] = None,
611+
to_thread: bool = False,
607612
) -> EventT: ...
608613

609614
async def get(
@@ -614,6 +619,7 @@ async def get(
614619
adapter_type: Optional[type[Adapter[Any, Any]]] = None,
615620
max_try_times: Optional[int] = None,
616621
timeout: Optional[Union[int, float]] = None,
622+
to_thread: bool = False,
617623
) -> Event[Any]:
618624
"""获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。
619625
@@ -625,44 +631,25 @@ async def get(
625631
adapter_type: 当指定时,只接受指定适配器产生的事件,先于 func 条件生效。默认为 `None`。
626632
max_try_times: 最大事件数。
627633
timeout: 超时时间。
634+
to_thread: 是否在独立的线程中运行同步函数。仅当 func 为同步函数时生效。默认为 `False`。
628635
629636
Returns:
630637
返回满足 `func` 条件的事件。
631638
632639
Raises:
633640
GetEventTimeout: 超过最大事件数或超时。
634641
"""
635-
_func = wrap_get_func(func, event_type=event_type, adapter_type=adapter_type)
636-
637-
try_times = 0
638-
start_time = time.time()
639-
while not self._should_exit.is_set():
640-
if max_try_times is not None and try_times > max_try_times:
641-
break
642-
if timeout is not None and time.time() - start_time > timeout:
643-
break
644-
645-
async with self._condition:
646-
if timeout is None:
647-
await self._condition.wait()
648-
else:
649-
try:
650-
with anyio.fail_after(start_time + timeout - time.time()):
651-
await self._condition.wait()
652-
except TimeoutError:
653-
break
654-
655-
if (
656-
self._current_event is not None
657-
and not self._current_event.__handled__
658-
and await _func(self._current_event)
659-
):
660-
self._current_event.__handled__ = True
661-
return self._current_event
662-
663-
try_times += 1
664-
665-
raise GetEventTimeout
642+
event_matcher = EventMatcher(
643+
func,
644+
bot=self,
645+
event_type=event_type,
646+
adapter_type=adapter_type,
647+
max_try_times=max_try_times,
648+
timeout=timeout,
649+
to_thread=to_thread,
650+
)
651+
self._event_matchers.append(event_matcher)
652+
return await event_matcher.wait()
666653

667654
def _load_plugin_class(
668655
self,

alicebot/event.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ class Event(ABC, BaseModel, Generic[AdapterT]):
2020
Attributes:
2121
adapter: 产生当前事件的适配器对象。
2222
type: 事件类型。
23-
__handled__: 表示事件是否被处理过了,用于适配器处理。警告:请勿手动更改此属性的值。
2423
"""
2524

2625
model_config = ConfigDict(extra="allow")
@@ -30,7 +29,6 @@ class Event(ABC, BaseModel, Generic[AdapterT]):
3029
else:
3130
adapter: Any
3231
type: Optional[str]
33-
__handled__: bool = False
3432

3533
@override
3634
def __str__(self) -> str:

0 commit comments

Comments
 (0)