diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index ade72f1107..96c7c5568c 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,6 +1,7 @@ import asyncio import os import re +from collections.abc import Callable from typing import Any, cast import telegramify_markdown @@ -21,6 +22,7 @@ Video, ) from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata +from astrbot.core.utils.metrics import Metric class TelegramPlatformEvent(AstrMessageEvent): @@ -34,6 +36,20 @@ class TelegramPlatformEvent(AstrMessageEvent): "word": re.compile(r"\s"), } + # sendMessageDraft 的 draft_id 类级递增计数器 + _TELEGRAM_DRAFT_ID_MAX = 2_147_483_647 + _next_draft_id: int = 0 + + @classmethod + def _allocate_draft_id(cls) -> int: + """分配一个递增的 draft_id,溢出时归 1。""" + cls._next_draft_id = ( + 1 + if cls._next_draft_id >= cls._TELEGRAM_DRAFT_ID_MAX + else cls._next_draft_id + 1 + ) + return cls._next_draft_id + # 消息类型到 chat action 的映射,用于优先级判断 ACTION_BY_TYPE: dict[type, str] = { Record: ChatAction.UPLOAD_VOICE, @@ -339,6 +355,118 @@ async def react(self, emoji: str | None, big: bool = False) -> None: except Exception as e: logger.error(f"[Telegram] 添加反应失败: {e}") + async def _send_message_draft( + self, + chat_id: str, + draft_id: int, + text: str, + message_thread_id: str | None = None, + parse_mode: str | None = None, + ) -> None: + """通过 Bot.send_message_draft 发送草稿消息(流式推送部分消息)。 + + 该 API 仅支持私聊。 + + Args: + chat_id: 目标私聊的 chat_id + draft_id: 草稿唯一标识,非零整数;相同 draft_id 的变更会以动画展示 + text: 消息文本,1-4096 字符 + message_thread_id: 可选,目标消息线程 ID + parse_mode: 可选,消息文本的解析模式 + """ + kwargs: dict[str, Any] = {} + if message_thread_id: + kwargs["message_thread_id"] = int(message_thread_id) + if parse_mode: + kwargs["parse_mode"] = parse_mode + + try: + logger.debug( + f"[Telegram] sendMessageDraft: chat_id={chat_id}, draft_id={draft_id}, text_len={len(text)}" + ) + await self.client.send_message_draft( + chat_id=int(chat_id), + draft_id=draft_id, + text=text, + **kwargs, + ) + except Exception as e: + logger.warning(f"[Telegram] sendMessageDraft 失败: {e!s}") + + async def _process_chain_items( + self, + chain: MessageChain, + payload: dict[str, Any], + user_name: str, + message_thread_id: str | None, + on_text: Callable[[str], None], + ) -> None: + """处理 MessageChain 中的各类组件,文本通过 on_text 回调追加,媒体直接发送。""" + for i in chain.chain: + if isinstance(i, Plain): + on_text(i.text) + elif isinstance(i, Image): + image_path = await i.convert_to_file_path() + await self._send_media_with_action( + self.client, + ChatAction.UPLOAD_PHOTO, + self.client.send_photo, + user_name=user_name, + photo=image_path, + **cast(Any, payload), + ) + elif isinstance(i, File): + path = await i.get_file() + name = i.name or os.path.basename(path) + await self._send_media_with_action( + self.client, + ChatAction.UPLOAD_DOCUMENT, + self.client.send_document, + user_name=user_name, + document=path, + filename=name, + **cast(Any, payload), + ) + elif isinstance(i, Record): + path = await i.convert_to_file_path() + await self._send_voice_with_fallback( + self.client, + path, + payload, + caption=i.text or None, + user_name=user_name, + message_thread_id=message_thread_id, + use_media_action=True, + ) + elif isinstance(i, Video): + path = await i.convert_to_file_path() + await self._send_media_with_action( + self.client, + ChatAction.UPLOAD_VIDEO, + self.client.send_video, + user_name=user_name, + video=path, + **cast(Any, payload), + ) + else: + logger.warning(f"不支持的消息类型: {type(i)}") + + async def _send_final_segment(self, delta: str, payload: dict[str, Any]) -> None: + """将累积文本作为 MarkdownV2 真实消息发送,失败时回退到纯文本。""" + try: + markdown_text = telegramify_markdown.markdownify( + delta, + normalize_whitespace=False, + ) + await self.client.send_message( + text=markdown_text, + parse_mode="MarkdownV2", + **cast(Any, payload), + ) + except Exception as e: + logger.warning(f"Markdown转换失败,使用普通文本: {e!s}") + await self.client.send_message(text=delta, **cast(Any, payload)) + async def send_streaming(self, generator, use_fallback: bool = False): message_thread_id = None @@ -356,6 +484,138 @@ async def send_streaming(self, generator, use_fallback: bool = False): if message_thread_id: payload["message_thread_id"] = message_thread_id + # sendMessageDraft 仅支持私聊(显式检查 FRIEND_MESSAGE) + is_private = self.get_message_type() == MessageType.FRIEND_MESSAGE + + if is_private: + logger.info("[Telegram] 流式输出: 使用 sendMessageDraft (私聊)") + await self._send_streaming_draft( + user_name, message_thread_id, payload, generator + ) + else: + logger.info("[Telegram] 流式输出: 使用 edit_message_text fallback (群聊)") + await self._send_streaming_edit( + user_name, message_thread_id, payload, generator + ) + + # 内联父类 send_streaming 的副作用(避免传入已消费的 generator) + asyncio.create_task( + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name), + ) + self._has_send_oper = True + + async def _send_streaming_draft( + self, + user_name: str, + message_thread_id: str | None, + payload: dict[str, Any], + generator, + ) -> None: + """使用 sendMessageDraft API 进行流式推送(私聊专用)。 + + 流式过程中使用 sendMessageDraft 推送草稿动画, + 流式结束后发送一条真实消息保留最终内容(draft 是临时的,会消失)。 + 使用信号驱动的发送循环:每次有新 token 到达时唤醒发送, + 发送频率由网络 RTT 自然限制(最多一个请求 in-flight)。 + """ + draft_id = self._allocate_draft_id() + delta = "" + last_sent_text = "" + done = False # 信号:生成器已结束 + text_changed = asyncio.Event() # 有新 token 到达时触发 + + async def _draft_sender_loop() -> None: + """信号驱动的草稿发送循环,有新内容就发,RTT 自然限流。""" + nonlocal last_sent_text + while not done: + await text_changed.wait() + text_changed.clear() + # 发送最新的缓冲区内容(MarkdownV2 渲染,与真实消息一致) + if delta and delta != last_sent_text: + draft_text = delta[: self.MAX_MESSAGE_LENGTH] + if draft_text != last_sent_text: + try: + md = telegramify_markdown.markdownify( + draft_text, + normalize_whitespace=False, + ) + await self._send_message_draft( + user_name, + draft_id, + md, + message_thread_id, + parse_mode="MarkdownV2", + ) + last_sent_text = draft_text + except Exception: + # markdownify 对未闭合语法可能失败,回退纯文本 + try: + await self._send_message_draft( + user_name, + draft_id, + draft_text, + message_thread_id, + ) + last_sent_text = draft_text + except Exception as e2: + logger.debug( + f"[Telegram] sendMessageDraft failed (ignored): {e2!s}" + ) + + sender_task = asyncio.create_task(_draft_sender_loop()) + + def _append_text(t: str) -> None: + nonlocal delta + delta += t + text_changed.set() # 唤醒发送循环 + + try: + async for chain in generator: + if not isinstance(chain, MessageChain): + continue + + if chain.type == "break": + # 分割符:发送真实消息保留内容,重置缓冲区 + if delta: + # 用 emoji 清空 draft 显示,避免 draft 和真实消息同时可见 + await self._send_message_draft( + user_name, + draft_id, + "\u23f3", + message_thread_id, + ) + await self._send_final_segment(delta, payload) + delta = "" + last_sent_text = "" + draft_id = self._allocate_draft_id() + continue + + await self._process_chain_items( + chain, payload, user_name, message_thread_id, _append_text + ) + finally: + done = True + text_changed.set() # 唤醒循环使其退出 + await sender_task + + # 流式结束:用 emoji 清空 draft,然后发真实消息持久化 + if delta: + await self._send_message_draft( + user_name, + draft_id, + "\u23f3", + message_thread_id, + ) + await self._send_final_segment(delta, payload) + + async def _send_streaming_edit( + self, + user_name: str, + message_thread_id: str | None, + payload: dict[str, Any], + generator, + ) -> None: + """使用 send_message + edit_message_text 进行流式推送(群聊 fallback)。""" delta = "" current_content = "" message_id = None @@ -368,121 +628,67 @@ async def send_streaming(self, generator, use_fallback: bool = False): await self._ensure_typing(user_name, message_thread_id) last_chat_action_time = asyncio.get_event_loop().time() + def _append_text(t: str) -> None: + nonlocal delta + delta += t + async for chain in generator: - if isinstance(chain, MessageChain): - if chain.type == "break": - # 分割符 - if message_id: - try: - await self.client.edit_message_text( - text=delta, - chat_id=payload["chat_id"], - message_id=message_id, - ) - except Exception as e: - logger.warning(f"编辑消息失败(streaming-break): {e!s}") - message_id = None # 重置消息 ID - delta = "" # 重置 delta - continue + if not isinstance(chain, MessageChain): + continue - # 处理消息链中的每个组件 - for i in chain.chain: - if isinstance(i, Plain): - delta += i.text - elif isinstance(i, Image): - image_path = await i.convert_to_file_path() - await self._send_media_with_action( - self.client, - ChatAction.UPLOAD_PHOTO, - self.client.send_photo, - user_name=user_name, - photo=image_path, - **cast(Any, payload), - ) - continue - elif isinstance(i, File): - path = await i.get_file() - name = i.name or os.path.basename(path) - await self._send_media_with_action( - self.client, - ChatAction.UPLOAD_DOCUMENT, - self.client.send_document, - user_name=user_name, - document=path, - filename=name, - **cast(Any, payload), - ) - continue - elif isinstance(i, Record): - path = await i.convert_to_file_path() - await self._send_voice_with_fallback( - self.client, - path, - payload, - caption=i.text or delta or None, - user_name=user_name, - message_thread_id=message_thread_id, - use_media_action=True, - ) - continue - elif isinstance(i, Video): - path = await i.convert_to_file_path() - await self._send_media_with_action( - self.client, - ChatAction.UPLOAD_VIDEO, - self.client.send_video, - user_name=user_name, - video=path, - **cast(Any, payload), + if chain.type == "break": + # 分割符 + if message_id: + try: + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, ) - continue - else: - logger.warning(f"不支持的消息类型: {type(i)}") - continue + except Exception as e: + logger.warning(f"编辑消息失败(streaming-break): {e!s}") + message_id = None + delta = "" + continue - # Plain - if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH: - current_time = asyncio.get_event_loop().time() - time_since_last_edit = current_time - last_edit_time - - # 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间 - if time_since_last_edit >= throttle_interval: - # 发送 typing 状态(带节流) - current_time = asyncio.get_event_loop().time() - if current_time - last_chat_action_time >= chat_action_interval: - await self._ensure_typing(user_name, message_thread_id) - last_chat_action_time = current_time - # 编辑消息 - try: - await self.client.edit_message_text( - text=delta, - chat_id=payload["chat_id"], - message_id=message_id, - ) - current_content = delta - except Exception as e: - logger.warning(f"编辑消息失败(streaming): {e!s}") - last_edit_time = ( - asyncio.get_event_loop().time() - ) # 更新上次编辑的时间 - else: - # delta 长度一般不会大于 4096,因此这里直接发送 - # 发送 typing 状态(带节流) + await self._process_chain_items( + chain, payload, user_name, message_thread_id, _append_text + ) + + # 编辑或发送消息 + if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH: + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + if time_since_last_edit >= throttle_interval: current_time = asyncio.get_event_loop().time() if current_time - last_chat_action_time >= chat_action_interval: await self._ensure_typing(user_name, message_thread_id) last_chat_action_time = current_time try: - msg = await self.client.send_message( - text=delta, **cast(Any, payload) + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, ) current_content = delta except Exception as e: - logger.warning(f"发送消息失败(streaming): {e!s}") - message_id = msg.message_id - last_edit_time = ( - asyncio.get_event_loop().time() - ) # 记录初始消息发送时间 + logger.warning(f"编辑消息失败(streaming): {e!s}") + last_edit_time = asyncio.get_event_loop().time() + else: + current_time = asyncio.get_event_loop().time() + if current_time - last_chat_action_time >= chat_action_interval: + await self._ensure_typing(user_name, message_thread_id) + last_chat_action_time = current_time + try: + msg = await self.client.send_message( + text=delta, **cast(Any, payload) + ) + current_content = delta + except Exception as e: + logger.warning(f"发送消息失败(streaming): {e!s}") + message_id = msg.message_id + last_edit_time = asyncio.get_event_loop().time() try: if delta and current_content != delta: @@ -506,5 +712,3 @@ async def send_streaming(self, generator, use_fallback: bool = False): ) except Exception as e: logger.warning(f"编辑消息失败(streaming): {e!s}") - - return await super().send_streaming(generator, use_fallback) diff --git a/pyproject.toml b/pyproject.toml index e57a0216c3..d981c24708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies = [ "pydantic>=2.12.5", "pydub>=0.25.1", "pyjwt>=2.10.1", - "python-telegram-bot>=22.0", + "python-telegram-bot>=22.6", "qq-botpy>=1.2.1", "quart>=0.20.0", "readability-lxml>=0.8.4.1", diff --git a/requirements.txt b/requirements.txt index c06c1d0f29..d76a11ddeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ py-cord>=2.6.1 pydantic>=2.12.5 pydub>=0.25.1 pyjwt>=2.10.1 -python-telegram-bot>=22.0 +python-telegram-bot>=22.6 qq-botpy>=1.2.1 quart>=0.20.0 readability-lxml>=0.8.4.1 diff --git a/tests/fixtures/mocks/telegram.py b/tests/fixtures/mocks/telegram.py index fbe4d04364..904ec4d093 100644 --- a/tests/fixtures/mocks/telegram.py +++ b/tests/fixtures/mocks/telegram.py @@ -110,6 +110,7 @@ def create_bot(): bot.set_my_commands = AsyncMock() bot.set_message_reaction = AsyncMock() bot.edit_message_text = AsyncMock() + bot.send_message_draft = AsyncMock() return bot @staticmethod