88import signal
99import sys
1010import threading
11- import time
1211from collections import defaultdict
1312from collections .abc import Awaitable
1413from contextlib import AsyncExitStack
1817
1918import anyio
2019import structlog
21- from anyio .abc import TaskStatus
2220from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
2321from pydantic import ValidationError , create_model
2422
2523from alicebot .adapter import Adapter
2624from alicebot .config import AdapterConfig , ConfigModel , MainConfig , PluginConfig
2725from alicebot .dependencies import solve_dependencies
2826from alicebot .event import Event , EventHandleOption
29- from alicebot .exceptions import (
30- GetEventTimeout ,
31- LoadModuleError ,
32- SkipException ,
33- StopException ,
34- )
27+ from alicebot .exceptions import LoadModuleError , SkipException , StopException
28+ from alicebot .matcher import EventMatcher
3529from alicebot .plugin import Plugin , PluginLoadType
3630from alicebot .typing import AdapterHook , AdapterT , BotHook , EventHook , EventT
3731from alicebot .utils import (
3832 ModulePathFinder ,
33+ async_map ,
3934 get_classes_from_module_name ,
4035 is_config_class ,
4136 samefile ,
42- wrap_get_func ,
4337)
4438
4539if sys .version_info >= (3 , 11 ): # pragma: no cover
@@ -79,8 +73,7 @@ class Bot:
7973
8074 _event_send_stream : MemoryObjectSendStream [EventHandleOption ] # pyright: ignore[reportUninitializedInstanceVariable]
8175 _event_receive_stream : MemoryObjectReceiveStream [EventHandleOption ] # pyright: ignore[reportUninitializedInstanceVariable]
82- _condition : anyio .Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable]
83- _current_event : Optional [Event [Any ]] # 当前待处理的 Event
76+ _event_matchers : list [EventMatcher ] # pyright: ignore[reportUninitializedInstanceVariable]
8477
8578 _should_exit : anyio .Event # 机器人是否应该进入准备退出状态 # pyright: ignore[reportUninitializedInstanceVariable]
8679 _restart_flag : bool # 重新启动标志
@@ -135,7 +128,6 @@ def __init__(
135128 self .global_state = {}
136129
137130 self .adapters = []
138- self ._current_event = None
139131 self ._restart_flag = False
140132 self ._module_path_finder = ModulePathFinder ()
141133 self ._raw_config_dict = {}
@@ -200,7 +192,7 @@ async def run_async(self) -> None:
200192 async def _init (self ) -> None :
201193 """初始化 AliceBot。"""
202194 self ._should_exit = anyio .Event ()
203- self ._condition = anyio . Condition ()
195+ self ._event_matchers = []
204196 self ._event_send_stream , self ._event_receive_stream = (
205197 anyio .create_memory_object_stream (
206198 max_buffer_size = self .config .bot .event_queue_size
@@ -504,74 +496,84 @@ async def handle_event(
504496 async def _handle_event_receive (self ) -> None :
505497 async with anyio .create_task_group () as tg , self ._event_receive_stream :
506498 async for current_event , handle_get in self ._event_receive_stream :
507- if handle_get :
508- await tg .start (self ._handle_event_wait_condition )
509- async with self ._condition :
510- self ._current_event = current_event
511- self ._condition .notify_all ()
512- else :
513- tg .start_soon (self ._handle_event , current_event )
499+ tg .start_soon (self ._handle_event , current_event , handle_get )
500+
501+ async def _handle_event (self , current_event : Event [Any ], handle_get : bool ) -> None :
502+ async with anyio .create_task_group () as tg :
503+ if handle_get :
504+ event_handled = False
505+ new_event_matchers : list [EventMatcher ] = []
506+ async for event_matcher , result in async_map (
507+ tg ,
508+ lambda x : x .run (current_event ),
509+ self ._event_matchers .copy (),
510+ ):
511+ if result is None :
512+ # 当前 event_matcher 已经失效,什么都不做
513+ pass
514+ elif result is True :
515+ # 当前 event_matcher 成功匹配事件,设置 event_handled 为 True
516+ event_handled = True
517+ elif result is False :
518+ # 当前 event_matcher 未成功匹配事件,将其放回队列中,等待下次处理
519+ new_event_matchers .append (event_matcher )
520+ self ._event_matchers = new_event_matchers
521+ if event_handled :
522+ return
523+
524+ for event_preprocessor_hook_func in self ._event_preprocessor_hooks :
525+ await event_preprocessor_hook_func (current_event )
526+
527+ for plugin_priority in sorted (self .plugins_priority_dict .keys ()):
528+ logger .debug ("Checking for matching plugins" , priority = plugin_priority )
529+ stop = False
530+ async for _plugin_class , should_stop in async_map (
531+ tg ,
532+ lambda x : self ._run_plugin (x , current_event ),
533+ self .plugins_priority_dict [plugin_priority ],
534+ ):
535+ stop = stop or should_stop
536+ if stop :
537+ break
514538
515- async def _handle_event_wait_condition (
516- self , * , task_status : TaskStatus [None ] = anyio .TASK_STATUS_IGNORED
517- ) -> None :
518- async with self ._condition :
519- task_status .started ()
520- await self ._condition .wait ()
521- assert self ._current_event is not None
522- current_event = self ._current_event
523- await self ._handle_event (current_event )
524-
525- async def _handle_event (self , current_event : Event [Any ]) -> None :
526- if current_event .__handled__ :
527- return
539+ for event_postprocessor_hook_func in self ._event_postprocessor_hooks :
540+ await event_postprocessor_hook_func (current_event )
528541
529- for _hook_func in self ._event_preprocessor_hooks :
530- await _hook_func (current_event )
542+ logger .info ("Event Finished" )
531543
532- for plugin_priority in sorted (self .plugins_priority_dict .keys ()):
533- logger .debug ("Checking for matching plugins" , priority = plugin_priority )
534- stop = False
535- for plugin in self .plugins_priority_dict [plugin_priority ]:
544+ async def _run_plugin (
545+ self , plugin_class : type [Plugin [Any , Any , Any ]], event : Event [Any ]
546+ ) -> bool :
547+ try :
548+ async with AsyncExitStack () as stack :
549+ plugin_instance = await solve_dependencies (
550+ plugin_class ,
551+ use_cache = True ,
552+ stack = stack ,
553+ dependency_cache = {Bot : self , Event : event },
554+ )
555+ if plugin_instance .name not in self .plugin_state :
556+ plugin_state = plugin_instance .__init_state__ ()
557+ if plugin_state is not None :
558+ self .plugin_state [plugin_instance .name ] = plugin_state
536559 try :
537- async with AsyncExitStack () as stack :
538- _plugin = await solve_dependencies (
539- plugin ,
540- use_cache = True ,
541- stack = stack ,
542- dependency_cache = {
543- Bot : self ,
544- Event : current_event ,
545- },
560+ if await plugin_instance .rule ():
561+ logger .info (
562+ "Event will be handled by plugin" , plugin = plugin_instance
546563 )
547- if _plugin .name not in self .plugin_state :
548- plugin_state = _plugin .__init_state__ ()
549- if plugin_state is not None :
550- self .plugin_state [_plugin .name ] = plugin_state
551- if await _plugin .rule ():
552- logger .info (
553- "Event will be handled by plugin" , plugin = _plugin
554- )
555- try :
556- await _plugin .handle ()
557- finally :
558- if _plugin .block :
559- stop = True
560- except SkipException :
561- # 插件要求跳过自身继续当前事件传播
562- continue
563- except StopException :
564- # 插件要求停止当前事件传播
565- stop = True
566- except Exception :
567- logger .exception ("Exception in plugin" , plugin = plugin )
568- if stop :
569- break
570-
571- for _hook_func in self ._event_postprocessor_hooks :
572- await _hook_func (current_event )
573-
574- logger .info ("Event Finished" )
564+ await plugin_instance .handle ()
565+ finally :
566+ if plugin_instance .block :
567+ raise StopException
568+ except SkipException :
569+ # 插件要求跳过自身继续当前事件传播
570+ pass
571+ except StopException :
572+ # 插件要求停止当前事件传播
573+ return True
574+ except Exception :
575+ logger .exception ("Exception in plugin" , plugin = plugin_class )
576+ return False
575577
576578 @overload
577579 async def get (
@@ -582,6 +584,7 @@ async def get(
582584 adapter_type : None = None ,
583585 max_try_times : Optional [int ] = None ,
584586 timeout : Optional [Union [int , float ]] = None ,
587+ to_thread : bool = False ,
585588 ) -> Event [Any ]: ...
586589
587590 @overload
@@ -593,6 +596,7 @@ async def get(
593596 adapter_type : type [Adapter [EventT , Any ]],
594597 max_try_times : Optional [int ] = None ,
595598 timeout : Optional [Union [int , float ]] = None ,
599+ to_thread : bool = False ,
596600 ) -> EventT : ...
597601
598602 @overload
@@ -604,6 +608,7 @@ async def get(
604608 adapter_type : Optional [type [Adapter [Any , Any ]]] = None ,
605609 max_try_times : Optional [int ] = None ,
606610 timeout : Optional [Union [int , float ]] = None ,
611+ to_thread : bool = False ,
607612 ) -> EventT : ...
608613
609614 async def get (
@@ -614,6 +619,7 @@ async def get(
614619 adapter_type : Optional [type [Adapter [Any , Any ]]] = None ,
615620 max_try_times : Optional [int ] = None ,
616621 timeout : Optional [Union [int , float ]] = None ,
622+ to_thread : bool = False ,
617623 ) -> Event [Any ]:
618624 """获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。
619625
@@ -625,44 +631,25 @@ async def get(
625631 adapter_type: 当指定时,只接受指定适配器产生的事件,先于 func 条件生效。默认为 `None`。
626632 max_try_times: 最大事件数。
627633 timeout: 超时时间。
634+ to_thread: 是否在独立的线程中运行同步函数。仅当 func 为同步函数时生效。默认为 `False`。
628635
629636 Returns:
630637 返回满足 `func` 条件的事件。
631638
632639 Raises:
633640 GetEventTimeout: 超过最大事件数或超时。
634641 """
635- _func = wrap_get_func (func , event_type = event_type , adapter_type = adapter_type )
636-
637- try_times = 0
638- start_time = time .time ()
639- while not self ._should_exit .is_set ():
640- if max_try_times is not None and try_times > max_try_times :
641- break
642- if timeout is not None and time .time () - start_time > timeout :
643- break
644-
645- async with self ._condition :
646- if timeout is None :
647- await self ._condition .wait ()
648- else :
649- try :
650- with anyio .fail_after (start_time + timeout - time .time ()):
651- await self ._condition .wait ()
652- except TimeoutError :
653- break
654-
655- if (
656- self ._current_event is not None
657- and not self ._current_event .__handled__
658- and await _func (self ._current_event )
659- ):
660- self ._current_event .__handled__ = True
661- return self ._current_event
662-
663- try_times += 1
664-
665- raise GetEventTimeout
642+ event_matcher = EventMatcher (
643+ func ,
644+ bot = self ,
645+ event_type = event_type ,
646+ adapter_type = adapter_type ,
647+ max_try_times = max_try_times ,
648+ timeout = timeout ,
649+ to_thread = to_thread ,
650+ )
651+ self ._event_matchers .append (event_matcher )
652+ return await event_matcher .wait ()
666653
667654 def _load_plugin_class (
668655 self ,
0 commit comments