diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index d0ef33b815..9f217b20a9 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -190,53 +190,100 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]): name: str = "send_message_to_user" description: str = ( "Send message to the user. " - "Supports various message types including `plain`, `image`, `record`, `video`, `file`, and `mention_user`. " - "Use this tool to send media files (`image`, `record`, `video`, `file`), " - "or when you need to proactively message the user(such as cron job). For normal text replies, you can output directly." + "Use flat parameters only. " + "Exactly one primary payload is allowed per call: `text` OR `path` OR `url`. " + "`mention_user_id` can be combined with one primary payload, or used alone. " + "For `path`/`url`, media type is inferred automatically from file extension." ) parameters: dict = Field( default_factory=lambda: { "type": "object", "properties": { - "messages": { - "type": "array", - "description": "An ordered list of message components to send. `mention_user` type can be used to mention the user.", - "items": { - "type": "object", - "properties": { - "type": { - "type": "string", - "description": ( - "Component type. One of: " - "plain, image, record, video, file, mention_user. Record is voice message." - ), - }, - "text": { - "type": "string", - "description": "Text content for `plain` type.", - }, - "path": { - "type": "string", - "description": "File path for `image`, `record`, or `file` types. Both local path and sandbox path are supported.", - }, - "url": { - "type": "string", - "description": "URL for `image`, `record`, or `file` types.", - }, - "mention_user_id": { - "type": "string", - "description": "User ID to mention for `mention_user` type.", - }, - }, - "required": ["type"], - }, + "text": { + "type": "string", + "description": "Plain text content. Whitespace-only text is treated as empty.", + }, + "path": { + "type": "string", + "description": "Local path or sandbox path to a media/file.", + }, + "url": { + "type": "string", + "description": "Remote media/file URL.", + }, + "name": { + "type": "string", + "description": "Optional file name when inferred type is `file`.", + }, + "mention_user_id": { + "type": "string", + "description": "Optional user ID to mention. Can be combined with one primary payload.", + }, + "session": { + "type": "string", + "description": "Optional target session. Defaults to current session.", }, }, - "required": ["messages"], } ) + _IMAGE_EXTS = { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", + ".bmp", + ".tif", + ".tiff", + ".ico", + ".heic", + ".heif", + ".avif", + } + _AUDIO_EXTS = { + ".mp3", + ".wav", + ".m4a", + ".aac", + ".flac", + ".ogg", + ".opus", + ".amr", + ".wma", + } + _VIDEO_EXTS = { + ".mp4", + ".mkv", + ".mov", + ".avi", + ".webm", + ".m4v", + ".flv", + ".wmv", + ".3gp", + ".mpeg", + ".mpg", + } + + def _normalize_ref_path(self, ref: str) -> str: + return str(ref).split("?", 1)[0].split("#", 1)[0] + + def _basename_from_ref(self, ref: str) -> str: + return os.path.basename(self._normalize_ref_path(ref)) + + def _infer_component_type_from_ref(self, ref: str) -> str: + clean_ref = self._normalize_ref_path(ref) + ext = os.path.splitext(clean_ref)[1].lower() + if ext in self._IMAGE_EXTS: + return "image" + if ext in self._AUDIO_EXTS: + return "record" + if ext in self._VIDEO_EXTS: + return "video" + return "file" + async def _resolve_path_from_sandbox( self, context: ContextWrapper[AstrAgentContext], path: str ) -> tuple[str, bool]: @@ -276,102 +323,58 @@ async def call( self, context: ContextWrapper[AstrAgentContext], **kwargs ) -> ToolExecResult: session = kwargs.get("session") or context.context.event.unified_msg_origin - messages = kwargs.get("messages") - - if not isinstance(messages, list) or not messages: - return "error: messages parameter is empty or invalid." - components: list[Comp.BaseMessageComponent] = [] - - for idx, msg in enumerate(messages): - if not isinstance(msg, dict): - return f"error: messages[{idx}] should be an object." - - msg_type = str(msg.get("type", "")).lower() - if not msg_type: - return f"error: messages[{idx}].type is required." - - file_from_sandbox = False - - try: - if msg_type == "plain": - text = str(msg.get("text", "")).strip() - if not text: - return f"error: messages[{idx}].text is required for plain component." - components.append(Comp.Plain(text=text)) - elif msg_type == "image": - path = msg.get("path") - url = msg.get("url") - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.Image.fromFileSystem(path=local_path)) - elif url: - components.append(Comp.Image.fromURL(url=url)) - else: - return f"error: messages[{idx}] must include path or url for image component." - elif msg_type == "record": - path = msg.get("path") - url = msg.get("url") - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.Record.fromFileSystem(path=local_path)) - elif url: - components.append(Comp.Record.fromURL(url=url)) - else: - return f"error: messages[{idx}] must include path or url for record component." - elif msg_type == "video": - path = msg.get("path") - url = msg.get("url") - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.Video.fromFileSystem(path=local_path)) - elif url: - components.append(Comp.Video.fromURL(url=url)) - else: - return f"error: messages[{idx}] must include path or url for video component." - elif msg_type == "file": - path = msg.get("path") - url = msg.get("url") - name = ( - msg.get("text") - or (os.path.basename(path) if path else "") - or (os.path.basename(url) if url else "") + text = str(kwargs.get("text", "")).strip() + path = kwargs.get("path") + url = kwargs.get("url") + name = kwargs.get("name") + mention_user_id = kwargs.get("mention_user_id") + + primary_count = sum(map(bool, (text, path, url))) + + if primary_count > 1: + return "error: only one primary payload is allowed per call (`text` OR `path` OR `url`)." + if primary_count == 0 and not mention_user_id: + return "error: missing payload. Provide one of `text`, `path`, `url`, or provide `mention_user_id`." + + if mention_user_id: + components.append(Comp.At(qq=mention_user_id)) + + component_map = { + "image": Comp.Image, + "record": Comp.Record, + "video": Comp.Video, + } + try: + if text: + components.append(Comp.Plain(text=text)) + elif path: + component_type = self._infer_component_type_from_ref(path) + local_path, _ = await self._resolve_path_from_sandbox(context, path) + component_cls = component_map.get(component_type) + if component_cls: + components.append(component_cls.fromFileSystem(path=local_path)) + else: + file_name = ( + (str(name).strip() if name is not None else "") + or os.path.basename(local_path) or "file" ) - if path: - ( - local_path, - file_from_sandbox, - ) = await self._resolve_path_from_sandbox(context, path) - components.append(Comp.File(name=name, file=local_path)) - elif url: - components.append(Comp.File(name=name, url=url)) - else: - return f"error: messages[{idx}] must include path or url for file component." - elif msg_type == "mention_user": - mention_user_id = msg.get("mention_user_id") - if not mention_user_id: - return f"error: messages[{idx}].mention_user_id is required for mention_user component." - components.append( - Comp.At( - qq=mention_user_id, - ), - ) + components.append(Comp.File(name=file_name, file=local_path)) + elif url: + component_type = self._infer_component_type_from_ref(url) + component_cls = component_map.get(component_type) + if component_cls: + components.append(component_cls.fromURL(url=url)) else: - return ( - f"error: unsupported message type '{msg_type}' at index {idx}." + file_name = ( + (str(name).strip() if name is not None else "") + or self._basename_from_ref(url) + or "file" ) - except Exception as exc: # 捕获组件构造异常,避免直接抛出 - return f"error: failed to build messages[{idx}] component: {exc}" + components.append(Comp.File(name=file_name, url=url)) + except Exception as exc: + return f"error: failed to build message component: {exc}" try: target_session = ( diff --git a/tests/unit/test_astr_main_agent_resources.py b/tests/unit/test_astr_main_agent_resources.py new file mode 100644 index 0000000000..d70ed44617 --- /dev/null +++ b/tests/unit/test_astr_main_agent_resources.py @@ -0,0 +1,164 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.astr_main_agent_resources import SendMessageToUserTool +from astrbot.core.message.components import At, File, Image, Plain, Record +from astrbot.core.message.message_event_result import MessageChain + + +def _build_run_context( + unified_msg_origin: str = "test:FriendMessage:session-1", +) -> tuple[ContextWrapper, AsyncMock]: + send_message = AsyncMock() + inner_ctx = SimpleNamespace(send_message=send_message) + event = SimpleNamespace(unified_msg_origin=unified_msg_origin) + wrapped = ContextWrapper(context=SimpleNamespace(context=inner_ctx, event=event)) + return wrapped, send_message + + +@pytest.mark.asyncio +async def test_send_message_text_only_success(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call(run_context, text="hello") + + assert result.startswith("Message sent to session") + send_message.assert_awaited_once() + target_session, chain = send_message.await_args.args + assert str(target_session) == "test:FriendMessage:session-1" + assert isinstance(chain, MessageChain) + assert len(chain.chain) == 1 + assert isinstance(chain.chain[0], Plain) + assert chain.chain[0].text == "hello" + + +@pytest.mark.asyncio +async def test_send_message_rejects_multiple_primary_payloads(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call(run_context, text="hello", url="https://example.com/a.png") + + assert "only one primary payload is allowed" in result + send_message.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_message_supports_mention_with_text(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call(run_context, mention_user_id="12345", text="ping") + + assert result.startswith("Message sent to session") + send_message.assert_awaited_once() + _, chain = send_message.await_args.args + assert isinstance(chain, MessageChain) + assert len(chain.chain) == 2 + assert isinstance(chain.chain[0], At) + assert isinstance(chain.chain[1], Plain) + assert chain.chain[1].text == "ping" + + +@pytest.mark.asyncio +async def test_send_message_supports_mention_only(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call(run_context, mention_user_id="12345") + + assert result.startswith("Message sent to session") + send_message.assert_awaited_once() + _, chain = send_message.await_args.args + assert isinstance(chain, MessageChain) + assert len(chain.chain) == 1 + assert isinstance(chain.chain[0], At) + + +@pytest.mark.asyncio +async def test_send_message_url_infers_image_component(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call(run_context, url="https://example.com/photo.png") + + assert result.startswith("Message sent to session") + send_message.assert_awaited_once() + _, chain = send_message.await_args.args + assert isinstance(chain, MessageChain) + assert len(chain.chain) == 1 + assert isinstance(chain.chain[0], Image) + + +@pytest.mark.asyncio +async def test_send_message_path_infers_record_component(monkeypatch: pytest.MonkeyPatch): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + monkeypatch.setattr( + tool, + "_resolve_path_from_sandbox", + AsyncMock(return_value=("/tmp/voice.mp3", False)), + ) + + result = await tool.call(run_context, path="/sandbox/voice.mp3") + + assert result.startswith("Message sent to session") + send_message.assert_awaited_once() + _, chain = send_message.await_args.args + assert isinstance(chain, MessageChain) + assert len(chain.chain) == 1 + assert isinstance(chain.chain[0], Record) + + +@pytest.mark.asyncio +async def test_send_message_path_component_construction_error( + monkeypatch: pytest.MonkeyPatch, +): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + monkeypatch.setattr( + tool, + "_resolve_path_from_sandbox", + AsyncMock(side_effect=RuntimeError("boom")), + ) + + result = await tool.call(run_context, path="/sandbox/voice.mp3") + + assert result.startswith("error: failed to build message component") + send_message.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_message_whitespace_text_treated_as_empty(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call(run_context, text=" ") + + assert "error: missing payload" in result + send_message.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_message_url_unknown_extension_falls_back_to_file(): + tool = SendMessageToUserTool() + run_context, send_message = _build_run_context() + + result = await tool.call( + run_context, + url="https://example.com/report.unknown", + name="report.txt", + ) + + assert result.startswith("Message sent to session") + send_message.assert_awaited_once() + _, chain = send_message.await_args.args + assert isinstance(chain, MessageChain) + assert len(chain.chain) == 1 + assert isinstance(chain.chain[0], File) + assert chain.chain[0].name == "report.txt"