Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion astrbot/api/event/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
MessageEventResult,
ResultContentType,
)
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform import AstrMessageEvent, RawPlatformEvent

__all__ = [
"AstrMessageEvent",
"CommandResult",
"EventResultType",
"MessageChain",
"MessageEventResult",
"RawPlatformEvent",
"ResultContentType",
]
4 changes: 4 additions & 0 deletions astrbot/api/event/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from astrbot.core.star.register import register_on_plugin_error as on_plugin_error
from astrbot.core.star.register import register_on_plugin_loaded as on_plugin_loaded
from astrbot.core.star.register import register_on_plugin_unloaded as on_plugin_unloaded
from astrbot.core.star.register import (
register_on_raw_platform_event as on_raw_platform_event,
)
from astrbot.core.star.register import register_on_using_llm_tool as on_using_llm_tool
from astrbot.core.star.register import (
register_on_waiting_llm_request as on_waiting_llm_request,
Expand Down Expand Up @@ -59,6 +62,7 @@
"on_plugin_loaded",
"on_plugin_unloaded",
"on_platform_loaded",
"on_raw_platform_event",
"on_waiting_llm_request",
"permission_type",
"platform_adapter_type",
Expand Down
2 changes: 2 additions & 0 deletions astrbot/api/platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
MessageType,
Platform,
PlatformMetadata,
RawPlatformEvent,
)
from astrbot.core.platform.register import register_platform_adapter

Expand All @@ -18,5 +19,6 @@
"MessageType",
"Platform",
"PlatformMetadata",
"RawPlatformEvent",
"register_platform_adapter",
]
41 changes: 41 additions & 0 deletions astrbot/core/pipeline/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astrbot import logger
from astrbot.core.message.message_event_result import CommandResult, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.raw_platform_event import RawPlatformEvent
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry

Expand Down Expand Up @@ -106,3 +107,43 @@ async def call_event_hook(
return True

return event.is_stopped()


async def call_raw_platform_event_hook(
event: RawPlatformEvent,
hook_type: EventType = EventType.OnRawPlatformEvent,
) -> bool:
"""调用原始平台事件钩子函数。"""
handlers = star_handlers_registry.get_handlers_by_event_type(
hook_type,
plugins_name=event.plugins_name,
)
for handler in handlers:
raw_platform_name = handler.extras_configs.get("raw_platform_name")
if raw_platform_name and raw_platform_name != event.platform_name:
continue

raw_platform_id = handler.extras_configs.get("raw_platform_id")
if raw_platform_id and raw_platform_id != event.platform_id:
continue

raw_event_type = handler.extras_configs.get("raw_event_type")
if raw_event_type and raw_event_type != event.event_type:
continue

try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
)
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())

if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了原始平台事件传播。",
)
return True

return event.is_stopped()
2 changes: 2 additions & 0 deletions astrbot/core/platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .astrbot_message import AstrBotMessage, Group, MessageMember, MessageType
from .platform import Platform
from .platform_metadata import PlatformMetadata
from .raw_platform_event import RawPlatformEvent

__all__ = [
"AstrBotMessage",
Expand All @@ -11,4 +12,5 @@
"MessageType",
"Platform",
"PlatformMetadata",
"RawPlatformEvent",
]
2 changes: 2 additions & 0 deletions astrbot/core/platform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def initialize(self) -> None:

# 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
webchat_inst._astrbot_config = self.astrbot_config
self.platform_insts.append(webchat_inst)
self._start_platform_task("webchat", webchat_inst)

Expand Down Expand Up @@ -198,6 +199,7 @@ async def load_platform(self, platform_config: dict) -> None:
return
cls_type = platform_cls_map[platform_config["type"]]
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
inst._astrbot_config = self.astrbot_config
self._inst_map[platform_config["id"]] = {
"inst": inst,
"client_id": inst.client_self_id,
Expand Down
27 changes: 27 additions & 0 deletions astrbot/core/platform/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .astr_message_event import AstrMessageEvent
from .message_session import MessageSesion
from .platform_metadata import PlatformMetadata
from .raw_platform_event import RawPlatformEvent


class PlatformStatus(Enum):
Expand Down Expand Up @@ -42,6 +43,9 @@ def __init__(self, config: dict, event_queue: Queue) -> None:
self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex

# 全局配置引用,由 PlatformManager 注入
self._astrbot_config: dict | None = None

# 平台运行状态
self._status: PlatformStatus = PlatformStatus.PENDING
self._errors: list[PlatformError] = []
Expand Down Expand Up @@ -163,3 +167,26 @@ async def webhook_callback(self, request: Any) -> Any:
NotImplementedError: 平台未实现统一 Webhook 模式
"""
raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式")

async def emit_raw_platform_event(
self,
payload: Any,
*,
meta: dict[str, Any] | None = None,
plugins_name: list[str] | None = None,
) -> bool:
"""发射平台原始事件到框架级 hook。"""
from astrbot.core.pipeline.context_utils import call_raw_platform_event_hook

if plugins_name is None and self._astrbot_config is not None:
plugin_set = self._astrbot_config.get("plugin_set", ["*"])
if plugin_set != ["*"]:
plugins_name = plugin_set

event = RawPlatformEvent(
payload=payload,
platform_meta=self.meta(),
meta=meta,
plugins_name=plugins_name,
)
return await call_raw_platform_event_hook(event)
72 changes: 72 additions & 0 deletions astrbot/core/platform/raw_platform_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from time import time
from typing import Any

from .platform_metadata import PlatformMetadata


class RawPlatformEvent:
def __init__(
self,
payload: Any,
platform_meta: PlatformMetadata,
meta: dict[str, Any] | None = None,
plugins_name: list[str] | None = None,
) -> None:
self.payload = payload
self.platform_meta = platform_meta
self.meta = meta or {}
self.created_at = time()
self.plugins_name = plugins_name

self._extras: dict[str, Any] = {}
self._stopped = False

# back compatibility with existing event access patterns
self.platform = platform_meta

@property
def platform_name(self) -> str:
return self.platform_meta.name

@property
def platform_id(self) -> str:
return self.platform_meta.id

@property
def adapter_display_name(self) -> str:
return self.platform_meta.adapter_display_name or self.platform_meta.name

@property
def event_type(self) -> str | None:
event_type = self.meta.get("event_type")
if event_type is None:
return None
return str(event_type)

def get_platform_name(self) -> str:
return self.platform_name

def get_platform_id(self) -> str:
return self.platform_id

def stop_event(self) -> None:
self._stopped = True

def continue_event(self) -> None:
self._stopped = False

def is_stopped(self) -> bool:
return self._stopped

def set_extra(self, key: str, value: Any) -> None:
self._extras[key] = value

def get_extra(self, key: str | None = None, default=None) -> Any:
if key is None:
return self._extras
return self._extras.get(key, default)

def clear_extra(self) -> None:
self._extras.clear()
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ async def run(self) -> None:
self.config,
self._event_queue,
self.client,
self,
)
await self.webhook_helper.initialize()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from cryptography.hazmat.primitives.asymmetric import ed25519

from astrbot.api import logger
from astrbot.core.platform.platform import Platform

# remove logger handler
for handler in logging.root.handlers[:]:
Expand All @@ -16,7 +17,11 @@

class QQOfficialWebhook:
def __init__(
self, config: dict, event_queue: asyncio.Queue, botpy_client: Client
self,
config: dict,
event_queue: asyncio.Queue,
botpy_client: Client,
platform: Platform,
) -> None:
self.appid = config["appid"]
self.secret = config["secret"]
Expand All @@ -39,6 +44,7 @@ def __init__(
)
self.client = botpy_client
self.event_queue = event_queue
self.platform = platform
self.shutdown_event = asyncio.Event()
# Deduplication cache for webhook retry callbacks.
self._seen_event_ids: dict[str, float] = {}
Expand Down Expand Up @@ -104,10 +110,18 @@ async def handle_callback(self, request) -> dict:
opcode = msg.get("op")
data = msg.get("d")

context = {
"opcode": opcode,
"event_type": event,
"is_validation": opcode == 13,
"request_path": getattr(request, "path", ""),
"request_method": getattr(request, "method", ""),
}
stopped = await self.platform.emit_raw_platform_event(msg, meta=context)

if opcode == 13:
# validation
signed = await self.webhook_validation(cast(dict, data))
print(signed)
return signed

event_id = msg.get("id")
Expand All @@ -126,7 +140,7 @@ async def handle_callback(self, request) -> dict:
return {"opcode": 12}
self._seen_event_ids[event_id] = now

if event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
if not stopped and event and opcode == BotWebSocket.WS_DISPATCH_EVENT:
event = msg["t"].lower()
try:
func = self._connection.parser[event]
Expand Down
2 changes: 2 additions & 0 deletions astrbot/core/star/register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
register_on_plugin_error,
register_on_plugin_loaded,
register_on_plugin_unloaded,
register_on_raw_platform_event,
register_on_using_llm_tool,
register_on_waiting_llm_request,
register_permission_type,
Expand All @@ -39,6 +40,7 @@
"register_on_plugin_loaded",
"register_on_plugin_unloaded",
"register_on_platform_loaded",
"register_on_raw_platform_event",
"register_on_waiting_llm_request",
"register_permission_type",
"register_platform_adapter_type",
Expand Down
30 changes: 30 additions & 0 deletions astrbot/core/star/register/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,36 @@ def decorator(awaitable):
return decorator


def register_on_raw_platform_event(
platform_name: str | None = None,
platform_id: str | None = None,
event_type: str | None = None,
**kwargs,
):
"""当平台接收到原始事件时。

Hook 参数:
event

说明:
该 hook 不经过消息 pipeline,直接接收平台原始 payload。
首版建议通过 platform_name/platform_id/event_type 做精确匹配。
"""

if platform_name is not None:
kwargs["raw_platform_name"] = platform_name
if platform_id is not None:
kwargs["raw_platform_id"] = platform_id
if event_type is not None:
kwargs["raw_event_type"] = event_type

def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnRawPlatformEvent, **kwargs)
return awaitable

return decorator


def register_on_plugin_error(**kwargs):
"""当插件处理消息异常时触发。

Expand Down
9 changes: 9 additions & 0 deletions astrbot/core/star/star_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ def get_handlers_by_event_type(
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...

@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnRawPlatformEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...

@overload
def get_handlers_by_event_type(
self,
Expand Down Expand Up @@ -221,6 +229,7 @@ class EventType(enum.Enum):
OnPluginErrorEvent = enum.auto() # 插件处理消息异常时
OnPluginLoadedEvent = enum.auto() # 插件加载完成
OnPluginUnloadedEvent = enum.auto() # 插件卸载完成
OnRawPlatformEvent = enum.auto() # 收到平台原始事件


H = TypeVar("H", bound=Callable[..., Any])
Expand Down
Loading
Loading