diff --git a/astrbot/builtin_stars/builtin_commands/commands/__init__.py b/astrbot/builtin_stars/builtin_commands/commands/__init__.py index 46d255965a..d96e52da6e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/__init__.py +++ b/astrbot/builtin_stars/builtin_commands/commands/__init__.py @@ -2,6 +2,8 @@ from .admin import AdminCommands from .alter_cmd import AlterCmdCommands +from .context_compaction import ContextCompactionCommands +from .context_memory import ContextMemoryCommands from .conversation import ConversationCommands from .help import HelpCommand from .llm import LLMCommands @@ -17,6 +19,8 @@ "AdminCommands", "AlterCmdCommands", "ConversationCommands", + "ContextCompactionCommands", + "ContextMemoryCommands", "HelpCommand", "LLMCommands", "PersonaCommands", diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py new file mode 100644 index 0000000000..d1c6c23d55 --- /dev/null +++ b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py @@ -0,0 +1,110 @@ +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.core import logger +from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler + + +class ContextCompactionCommands: + def __init__(self, context: star.Context) -> None: + self.context = context + + def _get_scheduler(self) -> PeriodicContextCompactionScheduler | None: + scheduler = getattr(self.context, "context_compaction_scheduler", None) + if isinstance(scheduler, PeriodicContextCompactionScheduler): + return scheduler + return None + + async def status(self, event: AstrMessageEvent) -> None: + scheduler = self._get_scheduler() + if not scheduler: + await event.send( + MessageChain().message("定时上下文压缩调度器不可用。"), + ) + return + + status = scheduler.get_status() + cfg = status.get("config", {}) + last = status.get("last_report") or {} + trigger_tokens = cfg.get("trigger_tokens", "?") + trigger_ratio = cfg.get("trigger_min_context_ratio", "?") + if isinstance(trigger_tokens, int) and trigger_tokens <= 0: + if isinstance(trigger_ratio, (int, float)): + trigger_text = f"自动({trigger_ratio}x模型上下文或目标长度估算)" + else: + trigger_text = "自动(基于目标长度估算)" + else: + trigger_text = str(trigger_tokens) + + lines = ["定时上下文压缩状态:"] + lines.append( + f"启用={self._yes_no(bool(cfg.get('enabled', False)))}" + f" | 运行中={self._yes_no(bool(status.get('running', False)))}" + f" | 停止请求={self._yes_no(bool(status.get('stop_requested', False)))}" + ) + lines.append( + f"间隔={cfg.get('interval_minutes', '?')}分钟" + f" | 每轮最多压缩={cfg.get('max_conversations_per_run', '?')}" + f" | 每轮最多扫描={cfg.get('max_scan_per_run', '?')}" + ) + lines.append( + f"触发Token={trigger_text}" + f" | 目标Token={cfg.get('target_tokens', '?')}" + f" | 最大轮次={cfg.get('max_rounds', '?')}" + ) + + if last: + lines.append( + f"最近任务[{last.get('reason', 'unknown')}]" + f" scanned={last.get('scanned', 0)}" + f" compacted={last.get('compacted', 0)}" + f" skipped={last.get('skipped', 0)}" + f" failed={last.get('failed', 0)}" + f" elapsed={last.get('elapsed_sec', 0.0):.2f}s" + ) + else: + lines.append("最近任务:暂无") + + if status.get("last_started_at"): + lines.append(f"最近开始:{status.get('last_started_at')}") + if status.get("last_finished_at"): + lines.append(f"最近结束:{status.get('last_finished_at')}") + if status.get("last_error"): + lines.append(f"最近错误:{status.get('last_error')}") + + await event.send(MessageChain().message("\n".join(lines))) + + async 
def run(self, event: AstrMessageEvent, limit: int | None = None) -> None: + scheduler = self._get_scheduler() + if not scheduler: + await event.send( + MessageChain().message("定时上下文压缩调度器不可用。"), + ) + return + + if limit is not None and limit < 1: + await event.send(MessageChain().message("limit 必须 >= 1。")) + return + + try: + report = await scheduler.run_once( + reason="manual_command", + max_conversations_override=limit, + ) + except Exception as exc: + logger.error("ctxcompact run failed: %s", exc, exc_info=True) + await event.send(MessageChain().message("触发压缩失败,请查看服务端日志。")) + return + + msg = ( + "手动触发完成:" + f"scanned={report.get('scanned', 0)} " + f"compacted={report.get('compacted', 0)} " + f"skipped={report.get('skipped', 0)} " + f"failed={report.get('failed', 0)} " + f"elapsed={report.get('elapsed_sec', 0.0):.2f}s" + ) + await event.send(MessageChain().message(msg)) + + @staticmethod + def _yes_no(value: bool) -> str: + return "是" if value else "否" diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_memory.py b/astrbot/builtin_stars/builtin_commands/commands/context_memory.py new file mode 100644 index 0000000000..8bb923fdc2 --- /dev/null +++ b/astrbot/builtin_stars/builtin_commands/commands/context_memory.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from typing import Any + +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.core.context_memory import ensure_context_memory_settings + +PINNED_PREVIEW_MAX_CHARS = 180 + + +class ContextMemoryCommands: + def __init__(self, context: star.Context) -> None: + self.context = context + + def _get_provider_settings(self, event: AstrMessageEvent) -> tuple[Any, dict[str, Any]]: + cfg = self.context.get_config(umo=event.unified_msg_origin) + provider_settings = cfg.get("provider_settings", {}) + if not isinstance(provider_settings, dict): + provider_settings = {} + cfg["provider_settings"] = provider_settings + return cfg, provider_settings + + @staticmethod + def _save_config(cfg: Any) -> None: + save_func = getattr(cfg, "save_config", None) + if callable(save_func): + save_func() + + @staticmethod + def _parse_switch(value: str) -> bool | None: + normalized = value.strip().lower() + if normalized in {"1", "true", "on", "yes", "enable", "enabled"}: + return True + if normalized in {"0", "false", "off", "no", "disable", "disabled"}: + return False + return None + + async def status(self, event: AstrMessageEvent) -> None: + _, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list): + pinned = [] + + lines = ["上下文记忆状态:"] + lines.append( + "启用=" + + ("是" if bool(cm_cfg.get("enabled", False)) else "否") + + " | 注入顶层记忆=" + + ("是" if bool(cm_cfg.get("inject_pinned_memory", True)) else "否") + ) + lines.append( + f"顶层记忆条数={len(pinned)}" + f" | 最大条数={cm_cfg.get('pinned_max_items', '?')}" + f" | 单条最大字符={cm_cfg.get('pinned_max_chars_per_item', '?')}" + ) + lines.append( + "检索增强(开发中)=" + + ("是" if bool(cm_cfg.get("retrieval_enabled", False)) else "否") + + f" | backend={cm_cfg.get('retrieval_backend', '') or '-'}" + + f" | top_k={cm_cfg.get('retrieval_top_k', '?')}" + ) + await event.send(MessageChain().message("\n".join(lines))) + + async def ls(self, event: AstrMessageEvent) -> None: + _, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = 
cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list) or not pinned: + await event.send(MessageChain().message("当前没有手动顶层记忆。")) + return + + configured_max_chars = cm_cfg.get("pinned_max_chars_per_item", 400) + try: + configured_max_chars = int(configured_max_chars) + except Exception: + configured_max_chars = 400 + preview_max_chars = min( + max(1, configured_max_chars), + PINNED_PREVIEW_MAX_CHARS, + ) + + lines = ["手动顶层记忆列表:"] + for idx, text in enumerate(pinned, start=1): + text_str = str(text) + if len(text_str) > preview_max_chars: + text_str = text_str[:preview_max_chars] + "..." + lines.append(f"{idx}. {text_str}") + await event.send(MessageChain().message("\n".join(lines))) + + async def add(self, event: AstrMessageEvent, text: str) -> None: + content = str(text or "").strip() + if not content: + await event.send(MessageChain().message("用法: /ctxmem add <记忆内容>")) + return + + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list): + pinned = [] + cm_cfg["pinned_memories"] = pinned + + max_items = int(cm_cfg.get("pinned_max_items", 8) or 8) + if len(pinned) >= max_items: + await event.send( + MessageChain().message( + f"已达到顶层记忆最大条数({max_items}),请先使用 /ctxmem rm <序号> 或 /ctxmem clear。", + ) + ) + return + + max_chars = int(cm_cfg.get("pinned_max_chars_per_item", 400) or 400) + truncated = False + if len(content) > max_chars: + content = content[:max_chars] + truncated = True + + pinned.append(content) + self._save_config(cfg) + + msg = f"已添加顶层记忆 #{len(pinned)}。" + if truncated: + msg += f" 内容超过上限,已截断到 {max_chars} 字符。" + await event.send(MessageChain().message(msg)) + + async def rm(self, event: AstrMessageEvent, index: int) -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list) or not pinned: + await event.send(MessageChain().message("当前没有可删除的顶层记忆。")) + return + + if index < 1 or index > len(pinned): + await event.send( + MessageChain().message(f"序号超出范围。请输入 1~{len(pinned)}。") + ) + return + + removed = str(pinned.pop(index - 1)) + self._save_config(cfg) + preview = removed if len(removed) <= 80 else removed[:80] + "..." 
+ await event.send(MessageChain().message(f"已删除顶层记忆 #{index}: {preview}")) + + async def clear(self, event: AstrMessageEvent) -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + count = len(pinned) if isinstance(pinned, list) else 0 + cm_cfg["pinned_memories"] = [] + self._save_config(cfg) + await event.send(MessageChain().message(f"已清空顶层记忆,共 {count} 条。")) + + async def enable(self, event: AstrMessageEvent, value: str = "") -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + enabled = bool(cm_cfg.get("enabled", False)) + + value = str(value or "").strip() + if value: + parsed = self._parse_switch(value) + if parsed is None: + await event.send( + MessageChain().message("参数错误。用法: /ctxmem enable [on|off]") + ) + return + enabled = parsed + else: + enabled = not enabled + + cm_cfg["enabled"] = enabled + self._save_config(cfg) + await event.send( + MessageChain().message( + "上下文记忆注入已" + ("开启。" if enabled else "关闭。") + ) + ) + + async def retrieval(self, event: AstrMessageEvent, value: str = "") -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + enabled = bool(cm_cfg.get("retrieval_enabled", False)) + + value = str(value or "").strip() + if value: + parsed = self._parse_switch(value) + if parsed is None: + await event.send( + MessageChain().message("参数错误。用法: /ctxmem retrieval [on|off]") + ) + return + enabled = parsed + else: + enabled = not enabled + + cm_cfg["retrieval_enabled"] = enabled + self._save_config(cfg) + await event.send( + MessageChain().message( + "检索增强开关(开发中)已" + ("开启。" if enabled else "关闭。") + ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index fb4a834035..14455a2bb7 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -1,9 +1,12 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, filter +from astrbot.core.star.filter.command import GreedyStr from .commands import ( AdminCommands, AlterCmdCommands, + ContextCompactionCommands, + ContextMemoryCommands, ConversationCommands, HelpCommand, LLMCommands, @@ -26,6 +29,8 @@ def __init__(self, context: star.Context) -> None: self.plugin_c = PluginCommands(self.context) self.admin_c = AdminCommands(self.context) self.conversation_c = ConversationCommands(self.context) + self.ctxcompact_c = ContextCompactionCommands(self.context) + self.ctxmem_c = ContextMemoryCommands(self.context) self.provider_c = ProviderCommands(self.context) self.persona_c = PersonaCommands(self.context) self.alter_cmd_c = AlterCmdCommands(self.context) @@ -127,6 +132,74 @@ async def provider( """查看或者切换 LLM Provider""" await self.provider_c.provider(event, idx, idx2) + @filter.command_group("ctxcompact") + @filter.permission_type(filter.PermissionType.ADMIN) + def ctxcompact(self) -> None: + """上下文定时压缩管理""" + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxcompact.command("status") + async def ctxcompact_status(self, event: AstrMessageEvent) -> None: + """查看定时上下文压缩状态""" + await self.ctxcompact_c.status(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxcompact.command("run") + async def ctxcompact_run( + self, + event: AstrMessageEvent, + limit: int | None = None, + ) -> None: + 
"""手动触发一次上下文压缩(可选 limit 覆盖本次压缩会话数)""" + await self.ctxcompact_c.run(event, limit) + + @filter.command_group("ctxmem") + @filter.permission_type(filter.PermissionType.ADMIN) + def ctxmem(self) -> None: + """上下文记忆管理(手动顶层记忆)""" + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("status") + async def ctxmem_status(self, event: AstrMessageEvent) -> None: + """查看上下文记忆状态""" + await self.ctxmem_c.status(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("ls") + async def ctxmem_ls(self, event: AstrMessageEvent) -> None: + """查看手动顶层记忆列表""" + await self.ctxmem_c.ls(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("add") + async def ctxmem_add(self, event: AstrMessageEvent, text: GreedyStr) -> None: + """添加一条手动顶层记忆。ctxmem add """ + await self.ctxmem_c.add(event, text) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("rm") + async def ctxmem_rm(self, event: AstrMessageEvent, index: int) -> None: + """删除一条手动顶层记忆。ctxmem rm """ + await self.ctxmem_c.rm(event, index) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("clear") + async def ctxmem_clear(self, event: AstrMessageEvent) -> None: + """清空手动顶层记忆""" + await self.ctxmem_c.clear(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("enable") + async def ctxmem_enable(self, event: AstrMessageEvent, value: str = "") -> None: + """开关上下文记忆注入。ctxmem enable [on|off]""" + await self.ctxmem_c.enable(event, value) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("retrieval") + async def ctxmem_retrieval(self, event: AstrMessageEvent, value: str = "") -> None: + """开关检索增强预留开关。ctxmem retrieval [on|off]""" + await self.ctxmem_c.retrieval(event, value) + @filter.command("reset") async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" diff --git a/astrbot/core/agent/context/config.py b/astrbot/core/agent/context/config.py index b8fd8eb968..7d43aaed64 100644 --- a/astrbot/core/agent/context/config.py +++ b/astrbot/core/agent/context/config.py @@ -29,6 +29,10 @@ class ContextConfig: """Number of recent messages to keep during LLM-based compression.""" llm_compress_provider: "Provider | None" = None """LLM provider used for compression tasks. If None, truncation strategy is used.""" + token_counter_mode: str = "estimate" + """Token counting mode: estimate, tokenizer, auto.""" + token_counter_model: str | None = None + """Optional model name for tokenizer-based token counting.""" custom_token_counter: TokenCounter | None = None """Custom token counting method. 
If None, the default method is used.""" custom_compressor: ContextCompressor | None = None diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 216a3e7e15..72d6ddfbbe 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -3,7 +3,7 @@ from ..message import Message from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor from .config import ContextConfig -from .token_counter import EstimateTokenCounter +from .token_counter import create_token_counter from .truncator import ContextTruncator @@ -25,7 +25,10 @@ def __init__( """ self.config = config - self.token_counter = config.custom_token_counter or EstimateTokenCounter() + self.token_counter = config.custom_token_counter or create_token_counter( + config.token_counter_mode, + model=config.token_counter_model, + ) self.truncator = ContextTruncator() if config.custom_compressor: @@ -42,12 +45,18 @@ def __init__( ) async def process( - self, messages: list[Message], trusted_token_usage: int = 0 + self, + messages: list[Message], + trusted_token_usage: int = 0, + force_compaction: bool = False, ) -> list[Message]: """Process the messages. Args: messages: The original message list. + trusted_token_usage: Optional trusted token usage hint. + force_compaction: Force one compaction pass when token-based compaction + is enabled, regardless of compressor threshold. Returns: The processed message list. @@ -69,7 +78,7 @@ async def process( result, trusted_token_usage ) - if self.compressor.should_compress( + if force_compaction or self.compressor.should_compress( result, total_tokens, self.config.max_context_tokens ): result = await self._run_compression(result, total_tokens) diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index 7c60cb23ec..ca9f886a85 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -1,6 +1,9 @@ import json +from collections.abc import Callable from typing import Protocol, runtime_checkable +from astrbot import logger + from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart @@ -76,3 +79,118 @@ def _estimate_tokens(self, text: str) -> int: chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"]) other_count = len(text) - chinese_count return int(chinese_count * 0.6 + other_count * 0.3) + + def estimate_text_tokens(self, text: str) -> int: + return self._estimate_tokens(text) + + +class TokenizerTokenCounter: + """Tokenizer-based token counter. + + Uses `tiktoken` when available and falls back to estimate mode if encoding + is unavailable. 
+ """ + + def __init__(self, model: str | None = None) -> None: + self._estimate = EstimateTokenCounter() + self._encode: Callable[[str], int] | None = None + self._available = False + self._init_encoder(model) + + @property + def available(self) -> bool: + return self._available + + def _init_encoder(self, model: str | None) -> None: + try: + import tiktoken # type: ignore + except Exception: + self._available = False + self._encode = None + return + + try: + if model: + encoding = tiktoken.encoding_for_model(model) + else: + encoding = tiktoken.get_encoding("cl100k_base") + except Exception: + try: + encoding = tiktoken.get_encoding("cl100k_base") + except Exception: + self._available = False + self._encode = None + return + + self._available = True + self._encode = lambda text: len(encoding.encode(text)) + + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: + if trusted_token_usage > 0: + return trusted_token_usage + if not self._available: + return self._estimate.count_tokens(messages) + + total = 0 + for msg in messages: + content = msg.content + if isinstance(content, str): + total += self._encode_len(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, TextPart): + total += self._encode_len(part.text) + elif isinstance(part, ThinkPart): + total += self._encode_len(part.think) + elif isinstance(part, ImageURLPart): + total += IMAGE_TOKEN_ESTIMATE + elif isinstance(part, AudioURLPart): + total += AUDIO_TOKEN_ESTIMATE + + if msg.tool_calls: + for tc in msg.tool_calls: + tc_str = json.dumps( + tc if isinstance(tc, dict) else tc.model_dump(), + ensure_ascii=False, + default=str, + ) + total += self._encode_len(tc_str) + + return total + + def _encode_len(self, text: str) -> int: + if not self._encode: + return self._estimate.estimate_text_tokens(text) + try: + return self._encode(text) + except Exception: + return self._estimate.estimate_text_tokens(text) + + +def create_token_counter( + mode: str | None = None, + *, + model: str | None = None, +) -> TokenCounter: + normalized = str(mode or "estimate").strip().lower() + + if normalized == "estimate": + return EstimateTokenCounter() + + if normalized in {"tokenizer", "auto"}: + tokenizer_counter = TokenizerTokenCounter(model=model) + if tokenizer_counter.available: + return tokenizer_counter + if normalized == "tokenizer": + logger.warning( + "context_token_counter_mode=tokenizer but `tiktoken` is unavailable; fallback to estimate." 
+ ) + return EstimateTokenCounter() + + logger.warning( + "Unknown context_token_counter_mode=%s, fallback to estimate.", + normalized, + ) + return EstimateTokenCounter() diff --git a/astrbot/core/agent/message_history_parser.py b/astrbot/core/agent/message_history_parser.py new file mode 100644 index 0000000000..e7650db50c --- /dev/null +++ b/astrbot/core/agent/message_history_parser.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import json +from collections.abc import Iterable +from typing import Any + +from astrbot.core.agent.message import Message + + +class MessageHistoryParser: + def parse(self, history: Iterable[Any]) -> list[Message]: + parsed: list[Message] = [] + for item in history: + if not isinstance(item, dict): + continue + + msg = self._try_validate(item) + if msg is not None: + parsed.append(msg) + continue + + fallback = self.sanitize_message_dict(item) + if not fallback: + continue + msg = self._try_validate(fallback) + if msg is not None: + parsed.append(msg) + + return parsed + + @staticmethod + def _try_validate(data: dict[str, Any]) -> Message | None: + try: + return Message.model_validate(data) + except Exception: + return None + + def sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: + role = str(item.get("role", "")).strip().lower() + if role not in {"system", "user", "assistant", "tool"}: + return None + + result: dict[str, Any] = {"role": role} + + if role == "assistant" and isinstance(item.get("tool_calls"), list): + result["tool_calls"] = item["tool_calls"] + + if role == "tool" and item.get("tool_call_id"): + result["tool_call_id"] = str(item.get("tool_call_id")) + + content = item.get("content") + if content is None and role == "assistant" and result.get("tool_calls"): + result["content"] = None + return result + + result["content"] = self.sanitize_content(content, role) + + if result["content"] is None and not ( + role == "assistant" and result.get("tool_calls") + ): + return None + + return result + + def sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: + if isinstance(content, str): + return content + + if isinstance(content, list): + return self.sanitize_list_content(content) + + if content is None: + if role == "assistant": + return None + return "" + + dumped = self.safe_json(content) + return dumped if dumped is not None else str(content) + + def sanitize_list_content(self, content: list[Any]) -> str | list[dict]: + parts: list[dict[str, Any]] = [] + fallback_texts: list[str] = [] + + for part in content: + if isinstance(part, str): + if part.strip(): + fallback_texts.append(part) + continue + if not isinstance(part, dict): + txt = self.safe_json(part) + if txt: + fallback_texts.append(txt) + continue + self.sanitize_content_part(part, parts, fallback_texts) + + if fallback_texts: + parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) + + if parts: + return parts + return "" + + def sanitize_content_part( + self, + part: dict[str, Any], + parts: list[dict[str, Any]], + fallback_texts: list[str], + ) -> None: + part_type = str(part.get("type", "")).strip() + if part_type == "text": + text_val = part.get("text") + if text_val is not None: + parts.append({"type": "text", "text": str(text_val)}) + return + + if part_type == "image_url": + image_obj = part.get("image_url") + if isinstance(image_obj, dict) and image_obj.get("url"): + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": {"url": str(image_obj.get("url"))}, + } + if image_obj.get("id"): + 
image_part["image_url"]["id"] = str(image_obj.get("id")) + parts.append(image_part) + return + + if part_type == "audio_url": + audio_obj = part.get("audio_url") + if isinstance(audio_obj, dict) and audio_obj.get("url"): + audio_part: dict[str, Any] = { + "type": "audio_url", + "audio_url": {"url": str(audio_obj.get("url"))}, + } + if audio_obj.get("id"): + audio_part["audio_url"]["id"] = str(audio_obj.get("id")) + parts.append(audio_part) + return + + if part_type == "think": + think = part.get("think") + if think: + fallback_texts.append(str(think)) + return + + raw_text = part.get("text") or part.get("content") + if raw_text: + fallback_texts.append(str(raw_text)) + else: + dumped = self.safe_json(part) + if dumped: + fallback_texts.append(dumped) + + @staticmethod + def safe_json(value: Any) -> str | None: + try: + return json.dumps(value, ensure_ascii=False, default=str) + except Exception: + return None diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 743b280070..2b4700f31c 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -32,6 +32,7 @@ ToolCallsResult, ) from astrbot.core.provider.provider import Provider +from astrbot.core.utils.config_normalization import to_non_negative_int, to_ratio from ..context.compressor import ContextCompressor from ..context.config import ContextConfig @@ -80,6 +81,85 @@ class FollowUpTicket: resolved: asyncio.Event = field(default_factory=asyncio.Event) +@dataclass(slots=True, frozen=True) +class PostToolCompactionConfig: + enabled: bool = False + soft_ratio: float = 0.3 + hard_ratio: float = 0.7 + min_delta_tokens: int = 0 + min_delta_turns: int = 0 + debounce_seconds: int = 0 + + +class PostToolCompactionController: + def __init__(self, config: PostToolCompactionConfig) -> None: + self.config = config + self._baseline_tokens = 0 + self._baseline_messages = 0 + self._last_check_at = 0.0 + + def refresh_baseline( + self, + *, + messages: list[Message], + token_counter: TokenCounter, + trusted_token_usage: int = 0, + ) -> None: + try: + self._baseline_tokens = token_counter.count_tokens( + messages, + trusted_token_usage, + ) + except Exception: + self._baseline_tokens = 0 + self._baseline_messages = len(messages) + + def should_compact( + self, + *, + messages: list[Message], + token_counter: TokenCounter, + max_context_tokens: int, + ) -> bool: + if not self.config.enabled: + return False + + now = time.monotonic() + if ( + self.config.debounce_seconds > 0 + and self._last_check_at > 0 + and (now - self._last_check_at) < self.config.debounce_seconds + ): + return False + self._last_check_at = now + + if max_context_tokens <= 0: + # No explicit token budget configured: preserve legacy behavior. 
+ return True + + try: + current_tokens = token_counter.count_tokens(messages) + except Exception: + return False + + current_messages = len(messages) + current_ratio = current_tokens / max(1, max_context_tokens) + + if current_ratio >= self.config.hard_ratio: + return True + if current_ratio < self.config.soft_ratio: + return False + + delta_tokens = max(0, current_tokens - self._baseline_tokens) + delta_messages = max(0, current_messages - self._baseline_messages) + if ( + delta_tokens < self.config.min_delta_tokens + and delta_messages < self.config.min_delta_turns + ): + return False + return True + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): def _get_persona_custom_error_message(self) -> str | None: """Read persona-level custom error message from event extras when available.""" @@ -104,6 +184,16 @@ async def reset( llm_compress_provider: Provider | None = None, # truncate by turns compressor truncate_turns: int = 1, + # context token counting mode + token_counter_mode: str = "estimate", + # run context compression immediately after tool execution + compact_context_after_tool_call: bool = False, + # post-tool-call compaction policy + compact_context_soft_ratio: float = 0.3, + compact_context_hard_ratio: float = 0.7, + compact_context_min_delta_tokens: int = 0, + compact_context_min_delta_turns: int = 0, + compact_context_debounce_seconds: int = 0, # customize custom_token_counter: TokenCounter | None = None, custom_compressor: ContextCompressor | None = None, @@ -118,11 +208,26 @@ async def reset( self.llm_compress_keep_recent = llm_compress_keep_recent self.llm_compress_provider = llm_compress_provider self.truncate_turns = truncate_turns + self.token_counter_mode = token_counter_mode + post_tool_soft_ratio = to_ratio(compact_context_soft_ratio, 0.3) + self.post_tool_compaction = PostToolCompactionConfig( + enabled=bool(compact_context_after_tool_call), + soft_ratio=post_tool_soft_ratio, + hard_ratio=max(post_tool_soft_ratio, to_ratio(compact_context_hard_ratio, 0.7)), + min_delta_tokens=to_non_negative_int(compact_context_min_delta_tokens), + min_delta_turns=to_non_negative_int(compact_context_min_delta_turns), + debounce_seconds=to_non_negative_int( + compact_context_debounce_seconds + ), + ) + self.post_tool_compaction_controller = PostToolCompactionController( + self.post_tool_compaction + ) self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor # we will do compress when: # 1. before requesting LLM - # TODO: 2. after LLM output a tool call + # 2. 
optionally after tool execution, controlled by config self.context_config = ContextConfig( # <=0 will never do compress max_context_tokens=provider.provider_config.get("max_context_tokens", 0), @@ -132,6 +237,8 @@ async def reset( llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider=self.llm_compress_provider, + token_counter_mode=self.token_counter_mode, + token_counter_model=provider.get_model(), custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) @@ -195,10 +302,29 @@ async def reset( Message(role="system", content=request.system_prompt), ) self.run_context.messages = messages + self._refresh_tool_compaction_baseline( + trusted_token_usage=request.conversation.token_usage if request.conversation else 0 + ) self.stats = AgentStats() self.stats.start_time = time.time() + def _refresh_tool_compaction_baseline(self, *, trusted_token_usage: int = 0) -> None: + self.post_tool_compaction_controller.refresh_baseline( + messages=self.run_context.messages, + token_counter=self.context_manager.token_counter, + trusted_token_usage=trusted_token_usage, + ) + + def _should_run_post_tool_compaction(self) -> bool: + if not hasattr(self, "post_tool_compaction_controller"): + return False + return self.post_tool_compaction_controller.should_compact( + messages=self.run_context.messages, + token_counter=self.context_manager.token_counter, + max_context_tokens=int(self.context_config.max_context_tokens or 0), + ) + async def _iter_llm_responses( self, *, include_model: bool = True ) -> T.AsyncGenerator[LLMResponse, None]: @@ -369,6 +495,7 @@ async def step(self): self.run_context.messages = await self.context_manager.process( self.run_context.messages, trusted_token_usage=token_usage ) + self._refresh_tool_compaction_baseline(trusted_token_usage=token_usage) self._simple_print_message_role("[AftCompact]") async for llm_response in self._iter_llm_responses_with_fallback(): @@ -618,6 +745,13 @@ async def step(self): self.req.append_tool_calls_result(tool_calls_result) + if self._should_run_post_tool_compaction(): + self.run_context.messages = await self.context_manager.process( + self.run_context.messages, + force_compaction=True, + ) + self._refresh_tool_compaction_baseline() + async def step_until_done( self, max_step: int ) -> T.AsyncGenerator[AgentResponse, None]: diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 10b67253fe..199f2a0328 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -50,6 +50,10 @@ TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, retrieve_knowledge_base, ) +from astrbot.core.context_memory import ( + build_pinned_memory_system_block, + load_context_memory_config, +) from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Reply from astrbot.core.persona_error_reply import ( @@ -57,6 +61,7 @@ set_persona_custom_error_message_on_event, ) from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.prompt_assembly_router import assemble_system_prompt from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt @@ -119,6 +124,20 @@ class MainAgentBuildConfig: """The number of most recent turns to keep during llm_compress strategy.""" llm_compress_provider_id: str = "" """The provider ID for the LLM used in context 
compression.""" + context_token_counter_mode: str = "estimate" + """Token counting mode for context compaction: estimate, tokenizer, auto.""" + compact_context_after_tool_call: bool = False + """Whether to run context compaction check immediately after tool execution.""" + compact_context_soft_ratio: float = 0.3 + """Soft trigger threshold for post-tool-call context compaction.""" + compact_context_hard_ratio: float = 0.7 + """Hard trigger threshold for post-tool-call context compaction.""" + compact_context_min_delta_tokens: int = 0 + """Minimum token growth required before post-tool-call compaction runs in soft zone.""" + compact_context_min_delta_turns: int = 0 + """Minimum message growth required before post-tool-call compaction runs in soft zone.""" + compact_context_debounce_seconds: int = 0 + """Debounce window for post-tool-call compaction checks.""" max_context_length: int = -1 """The maximum number of turns to keep in context. -1 means no limit. This enforce max turns before compression""" @@ -617,6 +636,34 @@ def _append_system_reminders( req.extra_user_content_parts.append(TextPart(text=system_content)) +def _inject_context_memory( + event: AstrMessageEvent, + req: ProviderRequest, + cfg: dict, +) -> None: + """Inject manually pinned top-level memories into system prompt. + + Vector retrieval enhancement is intentionally deferred to a follow-up PR. + This function only handles manually configured pinned memories. + """ + if not isinstance(cfg, dict): + return + cm_cfg = load_context_memory_config(cfg) + memory_block = build_pinned_memory_system_block(cm_cfg) + retrieved_facts = event.get_extra("retrieved_long_term_facts") + summarized_history = event.get_extra("compacted_history_summary") + req.system_prompt = assemble_system_prompt( + base_system_prompt=req.system_prompt or "", + retrieved_long_term_facts=retrieved_facts + if isinstance(retrieved_facts, list) + else None, + summarized_history=summarized_history + if isinstance(summarized_history, str) + else "", + pinned_memory_block=memory_block, + ) + + async def _decorate_llm_request( event: AstrMessageEvent, req: ProviderRequest, @@ -655,6 +702,7 @@ async def _decorate_llm_request( if tz is None: tz = plugin_context.get_config().get("timezone") _append_system_reminders(event, req, cfg, tz) + _inject_context_memory(event, req, cfg) def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: @@ -1203,6 +1251,13 @@ async def build_main_agent( llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, llm_compress_provider=_get_compress_provider(config, plugin_context), + token_counter_mode=config.context_token_counter_mode, + compact_context_after_tool_call=config.compact_context_after_tool_call, + compact_context_soft_ratio=config.compact_context_soft_ratio, + compact_context_hard_ratio=config.compact_context_hard_ratio, + compact_context_min_delta_tokens=config.compact_context_min_delta_tokens, + compact_context_min_delta_turns=config.compact_context_min_delta_turns, + compact_context_debounce_seconds=config.compact_context_debounce_seconds, truncate_turns=config.dequeue_context_length, enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 58c1726814..5df43e3967 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -18,6 +18,177 @@ "line", ] +PERIODIC_CONTEXT_COMPACTION_DEFAULTS = { + "enabled": False, + 
"interval_minutes": 30, + "startup_delay_seconds": 120, + "max_conversations_per_run": 8, + "max_scan_per_run": 120, + "scan_page_size": 40, + "min_idle_minutes": 15, + "min_messages": 14, + "target_tokens": 4096, + "trigger_tokens": 0, + "trigger_min_context_ratio": 0.3, + "max_rounds": 3, + "truncate_turns": 1, + "keep_recent": 6, + "provider_id": "", + "instruction": "", + "dry_run": False, +} + +PERIODIC_CONTEXT_COMPACTION_FIELD_META: dict[str, dict[str, Any]] = { + "enabled": { + "schema_type": "bool", + "ui_type": "bool", + "description": "启用定时历史压缩", + "hint": "后台定时扫描会话历史,使用 LLM 摘要旧消息并回写对话历史,实现多轮 compact context。", + }, + "interval_minutes": { + "schema_type": "int", + "ui_type": "int", + "description": "定时间隔(分钟)", + "hint": "每隔多少分钟执行一次压缩扫描。", + }, + "startup_delay_seconds": { + "schema_type": "int", + "ui_type": "int", + "description": "启动延迟(秒)", + "hint": "AstrBot 启动后,等待指定秒数再执行首次压缩任务。", + }, + "max_conversations_per_run": { + "schema_type": "int", + "ui_type": "int", + "description": "单次最多压缩会话数", + "hint": "每次任务最多实际压缩多少个会话。", + }, + "max_scan_per_run": { + "schema_type": "int", + "ui_type": "int", + "description": "单次最多扫描会话数", + "hint": "每次任务最多扫描多少会话(包括被跳过的会话)。", + }, + "scan_page_size": { + "schema_type": "int", + "ui_type": "int", + "description": "分页扫描大小", + "hint": "扫描 conversations 表时每页读取条数。", + }, + "min_idle_minutes": { + "schema_type": "int", + "ui_type": "int", + "description": "最小静默时长(分钟)", + "hint": "会话最近更新时间小于该值时跳过,避免压缩活跃会话。", + }, + "min_messages": { + "schema_type": "int", + "ui_type": "int", + "description": "最小消息条数", + "hint": "少于该消息条数的会话不参与压缩。", + }, + "target_tokens": { + "schema_type": "int", + "ui_type": "int", + "description": "目标 Token 阈值", + "hint": "压缩目标上下文大小(token 估算值)。", + }, + "trigger_tokens": { + "schema_type": "int", + "ui_type": "int", + "description": "触发 Token 阈值", + "hint": "会话估算 token 超过此值才触发压缩。<=0 表示自动按模型最大上下文比例计算。", + }, + "trigger_min_context_ratio": { + "schema_type": "float", + "ui_type": "float", + "description": "自动触发比例", + "hint": "当触发 Token 阈值 <= 0 时生效。默认 0.3(即模型最大上下文的 30%)。支持填写 0~1 或 0~100(百分比)。", + }, + "max_rounds": { + "schema_type": "int", + "ui_type": "int", + "description": "每会话最大压缩轮数", + "hint": "单个会话一次任务内最多执行几轮摘要压缩(实现 multiple compact context)。", + }, + "truncate_turns": { + "schema_type": "int", + "ui_type": "int", + "description": "截断轮数(后备)", + "hint": "LLM 压缩后仍超限时,按轮截断的每次丢弃轮数。", + }, + "keep_recent": { + "schema_type": "int", + "ui_type": "int", + "description": "保留最近轮数", + "hint": "压缩时始终保留最近 N 轮消息。", + }, + "provider_id": { + "schema_type": "string", + "ui_type": "string", + "description": "压缩模型提供商 ID", + "hint": "可自定义指定任意可用对话模型;留空时按会话当前模型执行压缩。建议优先选择成本较低、响应较快的模型。", + "_special": "select_provider", + }, + "instruction": { + "schema_type": "string", + "ui_type": "text", + "description": "定时压缩提示词", + "hint": "留空时复用 provider_settings.llm_compress_instruction。", + }, + "dry_run": { + "schema_type": "bool", + "ui_type": "bool", + "description": "演练模式(不回写)", + "hint": "开启后只记录日志,不实际写回数据库。", + }, +} + + +def _build_periodic_context_compaction_schema_properties() -> dict[str, dict[str, str]]: + return { + key: {"type": str(meta["schema_type"])} + for key, meta in PERIODIC_CONTEXT_COMPACTION_FIELD_META.items() + } + + +def _build_periodic_context_compaction_dashboard_items() -> dict[str, dict[str, Any]]: + items: dict[str, dict[str, Any]] = {} + base_enabled_condition = { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + } + for key, meta in 
PERIODIC_CONTEXT_COMPACTION_FIELD_META.items(): + condition = ( + {"provider_settings.agent_runner_type": "local"} + if key == "enabled" + else dict(base_enabled_condition) + ) + field: dict[str, Any] = { + "description": meta["description"], + "type": meta["ui_type"], + "hint": meta["hint"], + "condition": condition, + } + if "_special" in meta: + field["_special"] = meta["_special"] + items[f"provider_settings.periodic_context_compaction.{key}"] = field + return items + + +CONTEXT_MEMORY_DEFAULTS = { + "enabled": False, + "inject_pinned_memory": True, + "pinned_memories": [], + "pinned_max_items": 8, + "pinned_max_chars_per_item": 400, + # Reserved switches for follow-up PRs (manual opt-in only). + "retrieval_enabled": False, + "retrieval_backend": "", + "retrieval_provider_id": "", + "retrieval_top_k": 5, +} + # 默认配置 DEFAULT_CONFIG = { "config_version": 2, @@ -96,6 +267,15 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", + "context_token_counter_mode": "estimate", + "compact_context_after_tool_call": False, + "compact_context_soft_ratio": 0.3, + "compact_context_hard_ratio": 0.7, + "compact_context_min_delta_tokens": 0, + "compact_context_min_delta_turns": 0, + "compact_context_debounce_seconds": 0, + "periodic_context_compaction": dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS), + "context_memory": dict(CONTEXT_MEMORY_DEFAULTS), "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, @@ -2509,6 +2689,64 @@ class ChatProviderTemplate(TypedDict): "prompt_prefix": { "type": "string", }, + "context_token_counter_mode": { + "type": "string", + }, + "compact_context_after_tool_call": { + "type": "bool", + }, + "compact_context_soft_ratio": { + "type": "float", + }, + "compact_context_hard_ratio": { + "type": "float", + }, + "compact_context_min_delta_tokens": { + "type": "int", + }, + "compact_context_min_delta_turns": { + "type": "int", + }, + "compact_context_debounce_seconds": { + "type": "int", + }, + "periodic_context_compaction": { + "type": "object", + "properties": _build_periodic_context_compaction_schema_properties(), + }, + "context_memory": { + "type": "object", + "properties": { + "enabled": { + "type": "bool", + }, + "inject_pinned_memory": { + "type": "bool", + }, + "pinned_memories": { + "type": "list", + "items": {"type": "string"}, + }, + "pinned_max_items": { + "type": "int", + }, + "pinned_max_chars_per_item": { + "type": "int", + }, + "retrieval_enabled": { + "type": "bool", + }, + "retrieval_backend": { + "type": "string", + }, + "retrieval_provider_id": { + "type": "string", + }, + "retrieval_top_k": { + "type": "int", + }, + }, + }, "max_context_length": { "type": "int", }, @@ -3196,6 +3434,142 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.context_token_counter_mode": { + "description": "Token 计数模式", + "type": "string", + "options": ["estimate", "tokenizer", "auto"], + "labels": ["估算", "Tokenizer", "自动(优先 Tokenizer)"], + "hint": "用于上下文压缩触发判断。tokenizer 模式会优先使用 tiktoken,不可用时回退估算。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_after_tool_call": { + "description": "工具调用后立即检查压缩", + "type": "bool", + "hint": "开启后,每次工具执行回写上下文后都会立刻触发一次上下文压缩检查。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_soft_ratio": { + "description": "工具后压缩软阈值", + "type": "float", + "hint": "当上下文占比达到该阈值时,按“最小增长量”规则决定是否压缩。支持填写 0~1 或 0~100(百分比)。", + 
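+        # [Editor's note] Ratio-typed fields are normalized via to_ratio() from
+        # astrbot.core.utils.config_normalization; per the hint above it is assumed
+        # to treat values above 1 as percentages (e.g. 30 -> 0.3) and clamp the
+        # result into [0, 1], so "0.3" and "30" configure the same threshold.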
"condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_hard_ratio": { + "description": "工具后压缩硬阈值", + "type": "float", + "hint": "当上下文占比达到该阈值时,强制执行一次压缩。支持填写 0~1 或 0~100(百分比)。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_min_delta_tokens": { + "description": "工具后最小 Token 增长", + "type": "int", + "hint": "在软阈值区间内,Token 增长低于该值时不触发压缩。0 表示不限制。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_min_delta_turns": { + "description": "工具后最小消息增长", + "type": "int", + "hint": "在软阈值区间内,消息增长低于该值时不触发压缩。0 表示不限制。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_debounce_seconds": { + "description": "工具后压缩防抖(秒)", + "type": "int", + "hint": "两次工具后压缩检查的最小间隔秒数。0 表示关闭防抖。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + **_build_periodic_context_compaction_dashboard_items(), + "provider_settings.context_memory.enabled": { + "description": "启用上下文记忆注入", + "type": "bool", + "hint": "启用后可将手动维护的顶层记忆注入到 system prompt,并预留向量记忆检索接口。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.inject_pinned_memory": { + "description": "注入手动顶层记忆", + "type": "bool", + "hint": "将 `pinned_memories` 作为高优先级记忆注入系统提示词。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.pinned_max_items": { + "description": "顶层记忆最大条数", + "type": "int", + "hint": "通过管理命令添加手动顶层记忆时允许保留的最大条目数。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.pinned_max_chars_per_item": { + "description": "单条顶层记忆最大字符数", + "type": "int", + "hint": "超出长度的条目会被截断,避免 system prompt 膨胀。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_enabled": { + "description": "启用检索增强(开发中)", + "type": "bool", + "hint": "预留开关,默认关闭;向量检索增强建议在后续 PR 中实现。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_backend": { + "description": "检索后端标识(预留)", + "type": "string", + "hint": "例如 zep/mem0/custom,当前版本仅用于配置预留。", + "condition": { + "provider_settings.context_memory.retrieval_enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_provider_id": { + "description": "检索重排模型提供商 ID(预留)", + "type": "string", + "_special": "select_provider", + "hint": "当前版本仅保留配置,不会触发额外检索调用。", + "condition": { + "provider_settings.context_memory.retrieval_enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_top_k": { + "description": "检索 Top-K(预留)", + "type": "int", + "hint": "后续检索增强功能默认使用的召回条数。", + "condition": { + "provider_settings.context_memory.retrieval_enabled": True, 
+ "provider_settings.agent_runner_type": "local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py new file mode 100644 index 0000000000..e2045e8f0b --- /dev/null +++ b/astrbot/core/context_compaction_scheduler.py @@ -0,0 +1,685 @@ +from __future__ import annotations + +import asyncio +import time +from collections.abc import AsyncIterator +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from astrbot import logger +from astrbot.core.agent.context.config import ContextConfig +from astrbot.core.agent.context.manager import ContextManager +from astrbot.core.agent.context.token_counter import ( + EstimateTokenCounter, + TokenCounter, + create_token_counter, +) +from astrbot.core.agent.message import Message +from astrbot.core.agent.message_history_parser import MessageHistoryParser +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.config.default import PERIODIC_CONTEXT_COMPACTION_DEFAULTS +from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.db.po import ConversationV2 +from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.provider import Provider +from astrbot.core.utils.config_normalization import to_bool, to_int, to_ratio +from astrbot.core.utils.llm_metadata import LLM_METADATAS + +if TYPE_CHECKING: + from astrbot.core.provider.manager import ProviderManager + + +@dataclass +class _CompactionStats: + scanned: int = 0 + compacted: int = 0 + skipped: int = 0 + failed: int = 0 + + +@dataclass +class _RoundResult: + messages: list[Message] + changed: bool + rounds: int + + +EligibilityInfo = tuple[list[Message], int] + + +@dataclass +class _RunStatus: + started_at: str | None = None + finished_at: str | None = None + error: str | None = None + report: dict[str, Any] | None = None + + +@dataclass(frozen=True) +class CompactionConfig: + enabled: bool + interval_minutes: int + startup_delay_seconds: int + max_conversations_per_run: int + max_scan_per_run: int + scan_page_size: int + min_idle_minutes: int + min_messages: int + target_tokens: int + trigger_tokens: int + trigger_min_context_ratio: float + max_rounds: int + truncate_turns: int + keep_recent: int + provider_id: str + instruction: str + dry_run: bool + + @classmethod + def from_default_conf( + cls, + default_conf: dict[str, Any], + ) -> CompactionConfig: + defaults = PERIODIC_CONTEXT_COMPACTION_DEFAULTS + provider_settings = default_conf.get("provider_settings", {}) or {} + raw_cfg = provider_settings.get("periodic_context_compaction", {}) or {} + if not isinstance(raw_cfg, dict): + raw_cfg = {} + + cfg = dict(defaults) + cfg.update(raw_cfg) + + target_tokens = to_int(cfg.get("target_tokens"), 4096, 512) + trigger_tokens = to_int(cfg.get("trigger_tokens"), 0, 0) + trigger_min_context_ratio = to_ratio( + cfg.get("trigger_min_context_ratio"), + 0.3, + ) + + return cls( + enabled=to_bool(cfg.get("enabled"), False), + interval_minutes=to_int(cfg.get("interval_minutes"), 30, 1), + startup_delay_seconds=to_int(cfg.get("startup_delay_seconds"), 120, 0), + max_conversations_per_run=to_int( + cfg.get("max_conversations_per_run"), + 8, + 1, + ), + max_scan_per_run=to_int(cfg.get("max_scan_per_run"), 120, 1), + scan_page_size=to_int(cfg.get("scan_page_size"), 40, 10), + min_idle_minutes=to_int(cfg.get("min_idle_minutes"), 15, 0), + 
min_messages=to_int(cfg.get("min_messages"), 14, 2), + target_tokens=target_tokens, + trigger_tokens=trigger_tokens, + trigger_min_context_ratio=trigger_min_context_ratio, + max_rounds=to_int(cfg.get("max_rounds"), 3, 1), + truncate_turns=to_int(cfg.get("truncate_turns"), 1, 1), + keep_recent=to_int(cfg.get("keep_recent"), 6, 0), + provider_id=str(cfg.get("provider_id", "") or "").strip(), + instruction=str(cfg.get("instruction", "") or "").strip(), + dry_run=to_bool(cfg.get("dry_run"), False), + ) + + +@dataclass(frozen=True) +class CompactionPolicy: + cfg: CompactionConfig + token_counter: TokenCounter + + def check_eligibility( + self, + conv: ConversationV2, + history_parser: MessageHistoryParser, + ) -> EligibilityInfo | None: + history = conv.content + if not isinstance(history, list) or len(history) < self.cfg.min_messages: + return None + + if not self.is_idle_enough(conv.updated_at, self.cfg.min_idle_minutes): + return None + + messages = history_parser.parse(history) + if len(messages) < self.cfg.min_messages: + return None + + trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 + before_tokens = self.token_counter.count_tokens(messages, trusted_usage) + return messages, before_tokens + + def resolve_trigger_tokens(self, provider: Provider) -> int: + if self.cfg.trigger_tokens > 0: + return self.cfg.trigger_tokens + + max_context_tokens = self.resolve_provider_max_context(provider) + if max_context_tokens > 0: + return max(1, int(max_context_tokens * self.cfg.trigger_min_context_ratio)) + + return max(int(self.cfg.target_tokens * 1.5), self.cfg.target_tokens + 1) + + @staticmethod + def resolve_provider_max_context(provider: Provider) -> int: + configured = provider.provider_config.get("max_context_tokens", 0) + try: + configured_tokens = int(configured) + except Exception: + configured_tokens = 0 + if configured_tokens > 0: + return configured_tokens + + model = provider.get_model() + model_info = LLM_METADATAS.get(model) + if not isinstance(model_info, dict): + return 0 + limit = model_info.get("limit") + if not isinstance(limit, dict): + return 0 + context = limit.get("context") + try: + context_tokens = int(context) + except Exception: + context_tokens = 0 + return max(context_tokens, 0) + + @staticmethod + def is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: + if min_idle_minutes <= 0: + return True + if updated_at is None: + return True + now = datetime.now(timezone.utc) + at = updated_at + if at.tzinfo is None: + at = at.replace(tzinfo=timezone.utc) + return (now - at).total_seconds() >= (min_idle_minutes * 60) + + +class PeriodicContextCompactionScheduler: + """Periodically compact conversation history and persist summarized history back to DB. + + This upgrades existing "compress-on-overflow" behavior into proactive, scheduled + conversation-body compaction to keep long sessions lightweight. + """ + + def __init__( + self, + config_manager: AstrBotConfigManager, + conversation_manager: ConversationManager, + provider_manager: ProviderManager, + ) -> None: + self.config_manager = config_manager + self.conversation_manager = conversation_manager + self.provider_manager = provider_manager + self._stop_event = asyncio.Event() + self._running_lock = asyncio.Lock() + # Default fallback counter. Actual counter is resolved by provider_settings + # (context_token_counter_mode) and provider model when available. 
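+        # [Editor's sketch] _resolve_token_counter is referenced in
+        # _compact_one_conversation below but falls outside this excerpt; given the
+        # (mode, model) cache key and create_token_counter, it presumably looks
+        # something like:
+        #
+        #     def _resolve_token_counter(self, provider: Provider) -> TokenCounter:
+        #         settings = self.config_manager.default_conf.get("provider_settings", {}) or {}
+        #         mode = str(settings.get("context_token_counter_mode", "estimate"))
+        #         key = (mode, provider.get_model() or "")
+        #         if key not in self._token_counter_cache:
+        #             self._token_counter_cache[key] = create_token_counter(mode, model=key[1] or None)
+        #         return self._token_counter_cache[key]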
+ self._token_counter = EstimateTokenCounter() + self._token_counter_cache: dict[tuple[str, str], TokenCounter] = {} + self._history_parser = MessageHistoryParser() + self._bootstrapped = False + self._last_status = _RunStatus() + + def get_status(self) -> dict[str, Any]: + cfg = self._load_config() + return { + "running": self._running_lock.locked(), + "bootstrapped": self._bootstrapped, + "stop_requested": self._stop_event.is_set(), + "config": asdict(cfg), + "last_started_at": self._last_status.started_at, + "last_finished_at": self._last_status.finished_at, + "last_error": self._last_status.error, + "last_report": self._last_status.report, + "last_status": asdict(self._last_status), + } + + async def run(self) -> None: + logger.info("[ContextCompact] scheduler started") + while not self._stop_event.is_set(): + cfg = self._load_config() + wait_seconds = max(1, int(cfg.interval_minutes)) * 60 + + if not cfg.enabled: + await self._sleep_or_stop(wait_seconds) + continue + + if not self._bootstrapped: + self._bootstrapped = True + startup_delay = max(0, int(cfg.startup_delay_seconds)) + if startup_delay > 0: + logger.info( + "[ContextCompact] startup delay: %ss before first run", + startup_delay, + ) + await self._sleep_or_stop(startup_delay) + if self._stop_event.is_set(): + break + + try: + report = await self.run_once(reason="scheduled", cfg=cfg) + logger.info( + "[ContextCompact] run done(%s): scanned=%s compacted=%s skipped=%s failed=%s elapsed=%.2fs", + report.get("reason", "unknown"), + report.get("scanned", 0), + report.get("compacted", 0), + report.get("skipped", 0), + report.get("failed", 0), + report.get("elapsed_sec", 0.0), + ) + except Exception as exc: + finished = self._now_iso() + self._update_last_status( + finished_at=finished, + error=str(exc), + ) + if self._last_status.started_at is None: + self._last_status.started_at = finished + if self._last_status.report is None: + self._last_status.report = {} + logger.error( + "[ContextCompact] scheduler run error: %s", + exc, + exc_info=True, + ) + + await self._sleep_or_stop(wait_seconds) + + logger.info("[ContextCompact] scheduler stopped") + + async def stop(self) -> None: + self._stop_event.set() + + async def run_once( + self, + reason: str = "manual", + max_conversations_override: int | None = None, + cfg: CompactionConfig | None = None, + ) -> dict[str, Any]: + """Run one compaction sweep. + + Exposed so future admin command/cron endpoints can trigger ad-hoc compaction. 
+ """ + async with self._running_lock: + started_at = self._now_iso() + self._last_status.started_at = started_at + self._last_status.finished_at = None + if cfg is None: + cfg = self._load_config() + started = time.monotonic() + stats = _CompactionStats() + + if not cfg.enabled and reason == "scheduled": + report = { + "reason": reason, + "scanned": 0, + "compacted": 0, + "skipped": 0, + "failed": 0, + "elapsed_sec": 0.0, + "message": "disabled", + } + self._update_last_status( + started_at=started_at, + finished_at=self._now_iso(), + report=report, + error=None, + ) + return report + + max_to_compact, max_to_scan, scan_page_size = self._resolve_run_limits( + cfg, + max_conversations_override, + ) + + async for conv in self._iter_candidate_conversations( + scan_page_size=scan_page_size, + cfg=cfg, + ): + if ( + self._stop_event.is_set() + or stats.scanned >= max_to_scan + or stats.compacted >= max_to_compact + ): + break + + stats.scanned += 1 + outcome = await self._compact_one_conversation(conv, cfg) + if outcome == "compacted": + stats.compacted += 1 + elif outcome == "skipped": + stats.skipped += 1 + else: + stats.failed += 1 + + elapsed = time.monotonic() - started + report = { + "reason": reason, + "scanned": stats.scanned, + "compacted": stats.compacted, + "skipped": stats.skipped, + "failed": stats.failed, + "elapsed_sec": elapsed, + } + self._update_last_status( + started_at=started_at, + finished_at=self._now_iso(), + report=report, + error=None, + ) + return report + + @staticmethod + def _resolve_run_limits( + cfg: CompactionConfig, + max_conversations_override: int | None, + ) -> tuple[int, int, int]: + max_to_scan = max(1, int(cfg.max_scan_per_run)) + max_to_compact = max(1, int(cfg.max_conversations_per_run)) + if max_conversations_override is not None: + max_to_compact = max(1, int(max_conversations_override)) + max_to_compact = min(max_to_compact, max_to_scan) + scan_page_size = max(10, int(cfg.scan_page_size)) + return max_to_compact, max_to_scan, scan_page_size + + async def _iter_candidate_conversations( + self, + scan_page_size: int, + cfg: CompactionConfig, + ) -> AsyncIterator[ConversationV2]: + page = 1 + while not self._stop_event.is_set(): + conversations, total = await self.conversation_manager.db.get_filtered_conversations( + page=page, + page_size=scan_page_size, + updated_before=None, + min_messages=cfg.min_messages, + ) + if not conversations: + break + + for conv in conversations: + if self._stop_event.is_set(): + return + yield conv + + if page * scan_page_size >= total: + break + page += 1 + + def _update_last_status( + self, + *, + started_at: str | None = None, + finished_at: str | None = None, + error: str | None = None, + report: dict[str, Any] | None = None, + ) -> None: + if started_at is not None: + self._last_status.started_at = started_at + if finished_at is not None: + self._last_status.finished_at = finished_at + self._last_status.error = error + if report is not None: + self._last_status.report = report + + async def _sleep_or_stop(self, seconds: int) -> None: + try: + await asyncio.wait_for(self._stop_event.wait(), timeout=seconds) + except asyncio.TimeoutError: + return + + def _load_config(self) -> CompactionConfig: + return CompactionConfig.from_default_conf( + default_conf=self.config_manager.default_conf, + ) + + async def _compact_one_conversation( + self, + conv: ConversationV2, + cfg: CompactionConfig, + ) -> str: + provider = await self._resolve_provider(cfg, conv.user_id) + if not provider: + return "failed" + + token_counter = 
self._resolve_token_counter(provider) + policy = CompactionPolicy(cfg=cfg, token_counter=token_counter) + eligibility = policy.check_eligibility(conv, self._history_parser) + if eligibility is None: + return "skipped" + messages, before_tokens = eligibility + + trigger_tokens = policy.resolve_trigger_tokens(provider) + if before_tokens < trigger_tokens: + return "skipped" + + round_result = await self._run_compaction_rounds( + messages=messages, + provider=provider, + cfg=cfg, + token_counter=token_counter, + ) + if not round_result.changed: + return "skipped" + + after_tokens = token_counter.count_tokens(round_result.messages) + if after_tokens >= before_tokens: + return "skipped" + + if cfg.dry_run: + self._log_dry_run(conv, before_tokens, after_tokens, round_result) + return "skipped" + + persisted = await self._persist_compacted_history( + conv=conv, + compressed=round_result.messages, + after_tokens=after_tokens, + ) + if not persisted: + return "failed" + + self._log_compacted( + conv, + before_tokens, + after_tokens, + round_result, + ) + return "compacted" + + async def _run_compaction_rounds( + self, + messages: list[Message], + provider: Provider, + cfg: CompactionConfig, + token_counter: TokenCounter, + ) -> _RoundResult: + compressed = messages + changed = False + rounds = 0 + instruction = self._resolve_instruction(cfg) + manager = self._build_context_manager(cfg, provider, instruction, token_counter) + + for _ in range(cfg.max_rounds): + current_tokens = token_counter.count_tokens(compressed) + if current_tokens <= cfg.target_tokens: + break + + rounds += 1 + next_messages = await manager.process(compressed) + if self._messages_equal(compressed, next_messages): + break + + compressed = next_messages + changed = True + + return _RoundResult(messages=compressed, changed=changed, rounds=rounds) + + @staticmethod + def _build_context_manager( + cfg: CompactionConfig, + provider: Provider, + instruction: str, + token_counter: TokenCounter, + ) -> ContextManager: + return ContextManager( + ContextConfig( + max_context_tokens=cfg.target_tokens, + enforce_max_turns=-1, + truncate_turns=cfg.truncate_turns, + llm_compress_keep_recent=cfg.keep_recent, + llm_compress_instruction=instruction, + llm_compress_provider=provider, + custom_token_counter=token_counter, + ) + ) + + async def _persist_compacted_history( + self, + conv: ConversationV2, + compressed: list[Message], + after_tokens: int, + ) -> bool: + try: + await self.conversation_manager.update_conversation( + unified_msg_origin=conv.user_id, + conversation_id=conv.conversation_id, + history=[msg.model_dump(exclude_none=True) for msg in compressed], + token_usage=after_tokens, + ) + except Exception as exc: + logger.error( + "[ContextCompact] update failed: cid=%s user=%s err=%s", + conv.conversation_id, + conv.user_id, + exc, + exc_info=True, + ) + return False + return True + + @staticmethod + def _log_dry_run( + conv: ConversationV2, + before_tokens: int, + after_tokens: int, + round_result: _RoundResult, + ) -> None: + logger.info( + "[ContextCompact] dry-run: cid=%s user=%s tokens=%s->%s rounds=%s", + conv.conversation_id, + conv.user_id, + before_tokens, + after_tokens, + round_result.rounds, + ) + + @staticmethod + def _log_compacted( + conv: ConversationV2, + before_tokens: int, + after_tokens: int, + round_result: _RoundResult, + ) -> None: + logger.info( + "[ContextCompact] compacted cid=%s user=%s tokens=%s->%s rounds=%s", + conv.conversation_id, + conv.user_id, + before_tokens, + after_tokens, + round_result.rounds, + 
) + + async def _resolve_provider( + self, + cfg: CompactionConfig, + umo: str, + ) -> Provider | None: + provider = None + + if cfg.provider_id: + provider = await self.provider_manager.get_provider_by_id(cfg.provider_id) + else: + provider = self.provider_manager.get_using_provider( + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + if provider is None: + provider = self.provider_manager.get_using_provider( + provider_type=ProviderType.CHAT_COMPLETION, + umo=None, + ) + + if not isinstance(provider, Provider): + logger.warning( + "[ContextCompact] provider unavailable for umo=%s provider_id=%s", + umo, + cfg.provider_id, + ) + return None + return provider + + def _resolve_instruction(self, cfg: CompactionConfig) -> str: + if cfg.instruction: + return cfg.instruction + + provider_settings = self.config_manager.default_conf.get("provider_settings", {}) + base_instruction = provider_settings.get("llm_compress_instruction", "") + if isinstance(base_instruction, str) and base_instruction.strip(): + return base_instruction.strip() + return "" + + def _resolve_token_counter(self, provider: Provider | None) -> TokenCounter: + mode = self._resolve_token_counter_mode(provider) + + model = "" + if provider is not None: + try: + model = str(provider.get_model() or "") + except Exception: + model = "" + cache_key = (mode, model) + cached = self._token_counter_cache.get(cache_key) + if cached is not None: + return cached + + try: + resolved = create_token_counter(mode=mode, model=model or None) + except Exception as exc: + logger.warning( + "[ContextCompact] failed to create token counter(mode=%s, model=%s), fallback to estimate: %s", + mode, + model or "-", + exc, + ) + resolved = self._token_counter + + self._token_counter_cache[cache_key] = resolved + return resolved + + def _resolve_token_counter_mode(self, provider: Provider | None) -> str: + if provider is not None: + provider_settings = getattr(provider, "provider_settings", None) + if isinstance(provider_settings, dict): + mode = str(provider_settings.get("context_token_counter_mode", "") or "") + normalized = mode.strip().lower() + if normalized: + return normalized + + provider_settings = self.config_manager.default_conf.get("provider_settings", {}) + mode = "estimate" + if isinstance(provider_settings, dict): + mode = str(provider_settings.get("context_token_counter_mode", "estimate")) + return mode.strip().lower() or "estimate" + + @staticmethod + def _messages_equal(a: list[Message], b: list[Message]) -> bool: + if len(a) != len(b): + return False + return [m.model_dump(exclude_none=True) for m in a] == [ + m.model_dump(exclude_none=True) for m in b + ] + + @staticmethod + def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py new file mode 100644 index 0000000000..008a8c1c6c --- /dev/null +++ b/astrbot/core/context_memory.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS +from astrbot.core.utils.config_normalization import to_bool, to_int + +__all__ = [ + "ContextMemoryConfig", + "normalize_context_memory_settings", + "load_context_memory_config", + "ensure_context_memory_settings", + "build_pinned_memory_system_block", +] + + +@dataclass(frozen=True) +class ContextMemoryConfig: + enabled: bool = False + inject_pinned_memory: bool = True + pinned_memories: list[str] = 
field(default_factory=list) + pinned_max_items: int = 8 + pinned_max_chars_per_item: int = 400 + retrieval_enabled: bool = False + retrieval_backend: str = "" + retrieval_provider_id: str = "" + retrieval_top_k: int = 5 + + @classmethod + def from_settings( + cls, + provider_settings: dict[str, Any] | None, + ) -> ContextMemoryConfig: + raw = None + if isinstance(provider_settings, dict): + raw = provider_settings.get("context_memory") + return cls.from_raw(raw if isinstance(raw, dict) else None) + + @classmethod + def from_raw(cls, raw: dict[str, Any] | None) -> ContextMemoryConfig: + defaults = CONTEXT_MEMORY_DEFAULTS + data = raw if isinstance(raw, dict) else {} + + enabled = to_bool(data.get("enabled"), bool(defaults["enabled"])) + inject_pinned_memory = to_bool( + data.get("inject_pinned_memory"), + bool(defaults["inject_pinned_memory"]), + ) + pinned_max_items = to_int( + data.get("pinned_max_items"), + int(defaults["pinned_max_items"]), + 1, + ) + pinned_max_chars_per_item = to_int( + data.get("pinned_max_chars_per_item"), + int(defaults["pinned_max_chars_per_item"]), + 1, + ) + retrieval_enabled = to_bool( + data.get("retrieval_enabled"), + bool(defaults["retrieval_enabled"]), + ) + retrieval_backend = str(data.get("retrieval_backend", "") or "").strip() + retrieval_provider_id = str(data.get("retrieval_provider_id", "") or "").strip() + retrieval_top_k = to_int( + data.get("retrieval_top_k"), + int(defaults["retrieval_top_k"]), + 1, + ) + + pinned_raw = data.get("pinned_memories", []) + pinned_memories: list[str] = [] + if isinstance(pinned_raw, list): + for item in pinned_raw: + text = str(item or "").strip() + if not text: + continue + if len(text) > pinned_max_chars_per_item: + text = text[:pinned_max_chars_per_item] + pinned_memories.append(text) + if len(pinned_memories) >= pinned_max_items: + break + + return cls( + enabled=enabled, + inject_pinned_memory=inject_pinned_memory, + pinned_memories=pinned_memories, + pinned_max_items=pinned_max_items, + pinned_max_chars_per_item=pinned_max_chars_per_item, + retrieval_enabled=retrieval_enabled, + retrieval_backend=retrieval_backend, + retrieval_provider_id=retrieval_provider_id, + retrieval_top_k=retrieval_top_k, + ) + + def to_settings_dict(self) -> dict[str, Any]: + return { + "enabled": self.enabled, + "inject_pinned_memory": self.inject_pinned_memory, + "pinned_memories": list(self.pinned_memories), + "pinned_max_items": self.pinned_max_items, + "pinned_max_chars_per_item": self.pinned_max_chars_per_item, + "retrieval_enabled": self.retrieval_enabled, + "retrieval_backend": self.retrieval_backend, + "retrieval_provider_id": self.retrieval_provider_id, + "retrieval_top_k": self.retrieval_top_k, + } + + +def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, Any]: + return ContextMemoryConfig.from_raw(raw if isinstance(raw, dict) else None).to_settings_dict() + + +def load_context_memory_config(provider_settings: dict[str, Any] | None) -> ContextMemoryConfig: + return ContextMemoryConfig.from_settings(provider_settings) + + +def ensure_context_memory_settings(provider_settings: dict[str, Any]) -> dict[str, Any]: + """Normalize and persist context_memory subtree in provider_settings.""" + normalized = ContextMemoryConfig.from_settings(provider_settings).to_settings_dict() + provider_settings["context_memory"] = normalized + return normalized + + +def build_pinned_memory_system_block(config: ContextMemoryConfig) -> str: + """Build system-prompt block for manually pinned top-level memories.""" + if not 
config.enabled or not config.inject_pinned_memory:
+        return ""
+    if not config.pinned_memories:
+        return ""
+
+    lines = [
+        "<top_level_memory>",
+        "The following high-priority memory is manually configured and should be respected when relevant:",
+    ]
+    for idx, memory in enumerate(config.pinned_memories, start=1):
+        lines.append(f"{idx}. {memory}")
+    lines.append("</top_level_memory>")
+    return "\n".join(lines)
diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py
new file mode 100644
index 0000000000..a508a3dca3
--- /dev/null
+++ b/astrbot/core/context_memory_backends.py
@@ -0,0 +1,11 @@
+"""Compatibility re-exports for experimental context-memory backend hooks.
+
+Experimental protocol definitions live in
+`context_memory_experimental_backends.py` to keep extension points isolated from
+stable context-memory config logic.
+"""
+
+from astrbot.core import context_memory_experimental_backends as _exp
+
+__all__ = list(_exp.__all__)
+globals().update({name: getattr(_exp, name) for name in __all__})
diff --git a/astrbot/core/context_memory_experimental_backends.py b/astrbot/core/context_memory_experimental_backends.py
new file mode 100644
index 0000000000..cf8ec7689c
--- /dev/null
+++ b/astrbot/core/context_memory_experimental_backends.py
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class ContextMemoryBackend(Protocol):
+    """Experimental unified protocol for context-memory evolution + migration."""
+
+    async def evolve(
+        self,
+        *,
+        unified_msg_origin: str,
+        turns: list[str],
+        metadata: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """Evolve short-term conversation turns into durable memory artifacts."""
+        ...
+
+    async def retrieve(
+        self,
+        *,
+        unified_msg_origin: str,
+        query: str,
+        top_k: int,
+    ) -> list[str]:
+        """Retrieve evolved memory snippets for prompt assembly."""
+        ...
+
+    async def export_session(
+        self,
+        *,
+        unified_msg_origin: str,
+    ) -> dict[str, Any]:
+        """Export memory payload for migration or backup."""
+        ...
+
+    async def import_session(
+        self,
+        *,
+        unified_msg_origin: str,
+        payload: dict[str, Any],
+    ) -> None:
+        """Import migrated memory payload into target backend."""
+        ...
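+
+# Minimal illustrative backend (a hypothetical sketch for documentation only,
+# not part of this module's API). It satisfies the runtime_checkable protocol
+# above because isinstance() only checks that the methods exist:
+#
+#     class InMemoryContextMemoryBackend:
+#         def __init__(self) -> None:
+#             self._sessions: dict[str, list[str]] = {}
+#
+#         async def evolve(self, *, unified_msg_origin, turns, metadata=None):
+#             self._sessions.setdefault(unified_msg_origin, []).extend(turns)
+#             return {"stored": len(turns)}
+#
+#         async def retrieve(self, *, unified_msg_origin, query, top_k):
+#             snippets = self._sessions.get(unified_msg_origin, [])
+#             return [s for s in snippets if query in s][:top_k]
+#
+#         async def export_session(self, *, unified_msg_origin):
+#             return {"turns": list(self._sessions.get(unified_msg_origin, []))}
+#
+#         async def import_session(self, *, unified_msg_origin, payload):
+#             self._sessions[unified_msg_origin] = list(payload.get("turns", []))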
+ + +__all__ = [ + "ContextMemoryBackend", +] diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe6b1c351d..ab1a3eb25d 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -20,6 +20,9 @@ from astrbot.core import LogBroker, LogManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION +from astrbot.core.context_compaction_scheduler import ( + PeriodicContextCompactionScheduler, +) from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron import CronJobManager from astrbot.core.db import BaseDatabase @@ -59,6 +62,9 @@ def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: self.subagent_orchestrator: SubAgentOrchestrator | None = None self.cron_manager: CronJobManager | None = None self.temp_dir_cleaner: TempDirCleaner | None = None + self.context_compaction_scheduler: ( + PeriodicContextCompactionScheduler | None + ) = None # 设置代理 proxy_config = self.astrbot_config.get("http_proxy", "") @@ -166,6 +172,13 @@ async def initialize(self) -> None: # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) + # 初始化定时历史压缩调度器(基于 llm_compress) + self.context_compaction_scheduler = PeriodicContextCompactionScheduler( + config_manager=self.astrbot_config_mgr, + conversation_manager=self.conversation_manager, + provider_manager=self.provider_manager, + ) + # 初始化平台消息历史管理器 self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) @@ -193,6 +206,7 @@ async def initialize(self) -> None: self.cron_manager, self.subagent_orchestrator, ) + self.star_context.context_compaction_scheduler = self.context_compaction_scheduler # 初始化插件管理器 self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) @@ -252,6 +266,12 @@ def _load(self) -> None: self.temp_dir_cleaner.run(), name="temp_dir_cleaner", ) + context_compaction_task = None + if self.context_compaction_scheduler: + context_compaction_task = asyncio.create_task( + self.context_compaction_scheduler.run(), + name="context_compaction_scheduler", + ) # 把插件中注册的所有协程函数注册到事件总线中并执行 extra_tasks = [] @@ -263,6 +283,8 @@ def _load(self) -> None: tasks_.append(cron_task) if temp_dir_cleaner_task: tasks_.append(temp_dir_cleaner_task) + if context_compaction_task: + tasks_.append(context_compaction_task) for task in tasks_: self.curr_tasks.append( asyncio.create_task(self._task_wrapper(task), name=task.get_name()), @@ -317,6 +339,9 @@ async def stop(self) -> None: if self.temp_dir_cleaner: await self.temp_dir_cleaner.stop() + if self.context_compaction_scheduler: + await self.context_compaction_scheduler.stop() + # 请求停止所有正在运行的异步任务 for task in self.curr_tasks: task.cancel() diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index c8e50909d5..b19f4d47f3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -243,6 +243,19 @@ async def get_filtered_conversations( base_query = base_query.where( col(ConversationV2.platform_id).in_(kwargs["platforms"]), ) + if "updated_before" in kwargs and kwargs["updated_before"] is not None: + updated_before = kwargs["updated_before"] + base_query = base_query.where( + or_( + col(ConversationV2.updated_at).is_(None), + col(ConversationV2.updated_at) <= updated_before, + ), + ) + if "min_messages" in kwargs and kwargs["min_messages"]: + min_messages = max(1, int(kwargs["min_messages"])) + base_query = base_query.where( + func.json_array_length(col(ConversationV2.content)) >= min_messages, + ) # Get total 
count matching the filters
         count_query = select(func.count()).select_from(base_query.subquery())
diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
index 523d758a0a..7b1d41eef6 100644
--- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
+++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py
@@ -29,6 +29,7 @@
     ProviderRequest,
 )
 from astrbot.core.star.star_handler import EventType
+from astrbot.core.utils.config_normalization import to_non_negative_int, to_ratio
 from astrbot.core.utils.metrics import Metric
 from astrbot.core.utils.session_lock import session_lock_manager
@@ -91,6 +92,33 @@ async def initialize(self, ctx: PipelineContext) -> None:
         self.llm_compress_provider_id: str = settings.get(
             "llm_compress_provider_id", ""
         )
+        self.context_token_counter_mode: str = str(
+            settings.get("context_token_counter_mode", "estimate")
+        )
+        self.compact_context_after_tool_call: bool = settings.get(
+            "compact_context_after_tool_call",
+            False,
+        )
+        self.compact_context_soft_ratio: float = to_ratio(
+            settings.get("compact_context_soft_ratio", 0.3),
+            0.3,
+        )
+        self.compact_context_hard_ratio: float = to_ratio(
+            settings.get("compact_context_hard_ratio", 0.7),
+            0.7,
+        )
+        self.compact_context_min_delta_tokens: int = to_non_negative_int(
+            settings.get("compact_context_min_delta_tokens", 0),
+            0,
+        )
+        self.compact_context_min_delta_turns: int = to_non_negative_int(
+            settings.get("compact_context_min_delta_turns", 0),
+            0,
+        )
+        self.compact_context_debounce_seconds: int = to_non_negative_int(
+            settings.get("compact_context_debounce_seconds", 0),
+            0,
+        )
         self.max_context_length = settings["max_context_length"]  # int
         self.dequeue_context_length: int = min(
             max(1, settings["dequeue_context_length"]),
@@ -125,6 +153,13 @@ async def initialize(self, ctx: PipelineContext) -> None:
             llm_compress_instruction=self.llm_compress_instruction,
             llm_compress_keep_recent=self.llm_compress_keep_recent,
             llm_compress_provider_id=self.llm_compress_provider_id,
+            context_token_counter_mode=self.context_token_counter_mode,
+            compact_context_after_tool_call=self.compact_context_after_tool_call,
+            compact_context_soft_ratio=self.compact_context_soft_ratio,
+            compact_context_hard_ratio=self.compact_context_hard_ratio,
+            compact_context_min_delta_tokens=self.compact_context_min_delta_tokens,
+            compact_context_min_delta_turns=self.compact_context_min_delta_turns,
+            compact_context_debounce_seconds=self.compact_context_debounce_seconds,
             max_context_length=self.max_context_length,
             dequeue_context_length=self.dequeue_context_length,
             llm_safety_mode=self.llm_safety_mode,
diff --git a/astrbot/core/prompt_assembly_router.py b/astrbot/core/prompt_assembly_router.py
new file mode 100644
index 0000000000..5936996fb0
--- /dev/null
+++ b/astrbot/core/prompt_assembly_router.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+
+def _normalize_text_items(items: Iterable[object] | None) -> list[str]:
+    normalized: list[str] = []
+    if not items:
+        return normalized
+    for item in items:
+        text = str(item or "").strip()
+        if text:
+            normalized.append(text)
+    return normalized
+
+
+def _render_long_term_facts_block(facts: list[str]) -> str:
+    if not facts:
+        return ""
+    lines = [
+        "<long_term_facts>",
+        "Use these retrieved long-term facts when relevant:",
+    ]
+    for idx, fact in enumerate(facts, start=1):
+        lines.append(f"{idx}. {fact}")
+    lines.append("</long_term_facts>")
+    return "\n".join(lines)
+
+
+def _render_summarized_history_block(summary: str) -> str:
+    text = str(summary or "").strip()
+    if not text:
+        return ""
+    return "\n".join(
+        [
+            "<summarized_history>",
+            text,
+            "</summarized_history>",
+        ]
+    )
+
+
+def assemble_system_prompt(
+    *,
+    base_system_prompt: str,
+    retrieved_long_term_facts: Iterable[object] | None = None,
+    summarized_history: str = "",
+    pinned_memory_block: str = "",
+) -> str:
+    """Assemble final system prompt with stable section ordering.
+
+    Section order:
+    1) Base system prompt
+    2) Retrieved long-term facts
+    3) Summarized history
+    4) Pinned top-level memory
+    """
+    sections: list[str] = []
+    base = str(base_system_prompt or "").strip()
+    if base:
+        sections.append(base)
+
+    facts_block = _render_long_term_facts_block(
+        _normalize_text_items(retrieved_long_term_facts)
+    )
+    if facts_block:
+        sections.append(facts_block)
+
+    summary_block = _render_summarized_history_block(summarized_history)
+    if summary_block:
+        sections.append(summary_block)
+
+    pinned = str(pinned_memory_block or "").strip()
+    if pinned:
+        sections.append(pinned)
+
+    return "\n\n".join(sections).strip()
diff --git a/astrbot/core/utils/config_normalization.py b/astrbot/core/utils/config_normalization.py
new file mode 100644
index 0000000000..84faa19719
--- /dev/null
+++ b/astrbot/core/utils/config_normalization.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from typing import Any
+
+
+def to_bool(value: Any, default: bool) -> bool:
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, (int, float)):
+        return bool(value)
+    if isinstance(value, str):
+        lowered = value.strip().lower()
+        if lowered in {"1", "true", "yes", "on"}:
+            return True
+        if lowered in {"0", "false", "no", "off"}:
+            return False
+    return default
+
+
+def to_int(value: Any, default: int, min_value: int | None = None) -> int:
+    try:
+        parsed = int(value)
+    except Exception:
+        parsed = default
+    if min_value is not None:
+        parsed = max(parsed, min_value)
+    return parsed
+
+
+def to_non_negative_int(value: Any, default: int = 0) -> int:
+    return max(0, to_int(value, default))
+
+
+def to_ratio(value: Any, default: float) -> float:
+    try:
+        parsed = float(value)
+    except Exception:
+        parsed = default
+    if parsed > 1.0 and parsed <= 100.0:
+        parsed = parsed / 100.0
+    return min(max(parsed, 0.0), 1.0)
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index bcd7e075c7..76ffc341e0 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -160,7 +160,10 @@ def validate(data: dict, metadata: dict = schema, path="") -> None:
             for item in value:
                 validate(item, meta["items"], path=f"{path}{key}.")
         elif meta["type"] == "object" and isinstance(value, dict):
-            validate(value, meta["items"], path=f"{path}{key}.")
+            object_schema = meta.get("items")
+            if not isinstance(object_schema, dict):
+                object_schema = meta.get("properties", {})
+            validate(value, object_schema, path=f"{path}{key}.")
 
         if meta["type"] == "int" and not isinstance(value, int):
             casted = try_cast(value, "int")
diff --git a/astrbot/dashboard/routes/util.py b/astrbot/dashboard/routes/util.py
index 1056198158..c3858c7924 100644
--- a/astrbot/dashboard/routes/util.py
+++ b/astrbot/dashboard/routes/util.py
@@ -15,7 +15,7 @@ def get_schema_item(schema: dict | None, key_path: str) -> dict | None:
 
     同时支持:
     - 扁平 schema(直接 key 命中)
-    - 嵌套 object schema({type: "object", items: {...}})
+    - 嵌套 object schema({type: "object", items/properties: {...}})
     """
     if not
isinstance(schema, dict) or not key_path: @@ -33,7 +33,9 @@ def get_schema_item(schema: dict | None, key_path: str) -> dict | None: return meta if not isinstance(meta, dict) or meta.get("type") != "object": return None - current = meta.get("items", {}) + current = meta.get("items") + if not isinstance(current, dict): + current = meta.get("properties", {}) return None diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index 0b955ff401..5f78028111 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -94,6 +94,18 @@ def test_init_with_truncate_compressor(self): assert isinstance(manager.compressor, TruncateByTurnsCompressor) + @patch("astrbot.core.agent.context.manager.create_token_counter") + def test_init_uses_token_counter_mode(self, mock_create_token_counter): + """Test token counter mode wiring into ContextManager.""" + fake_counter = MagicMock() + mock_create_token_counter.return_value = fake_counter + config = ContextConfig(token_counter_mode="auto", token_counter_model="gpt-4") + + manager = ContextManager(config) + + mock_create_token_counter.assert_called_once_with("auto", model="gpt-4") + assert manager.token_counter is fake_counter + # ==================== Empty and Edge Cases ==================== @pytest.mark.asyncio @@ -210,6 +222,29 @@ async def test_token_compression_not_triggered_below_threshold(self): mock_compress.assert_not_called() assert result == messages + @pytest.mark.asyncio + async def test_force_compaction_bypasses_threshold_gate(self): + """Test that force_compaction bypasses compressor threshold gate.""" + config = ContextConfig(max_context_tokens=1000) + manager = ContextManager(config) + + messages = [self.create_message("user", "Hello")] + + with patch.object( + manager.compressor, "should_compress", return_value=False + ) as mock_should_compress: + with patch.object( + manager, + "_run_compression", + new_callable=AsyncMock, + return_value=messages, + ) as mock_run_compression: + result = await manager.process(messages, force_compaction=True) + + mock_should_compress.assert_not_called() + mock_run_compression.assert_called_once() + assert result == messages + @pytest.mark.asyncio async def test_token_compression_triggered_above_threshold(self): """Test that compression is triggered above threshold.""" diff --git a/tests/agent/test_token_counter.py b/tests/agent/test_token_counter.py index c68b056e66..fad8dfda1f 100644 --- a/tests/agent/test_token_counter.py +++ b/tests/agent/test_token_counter.py @@ -1,9 +1,14 @@ """Tests for EstimateTokenCounter multimodal support.""" +import sys +from types import SimpleNamespace + from astrbot.core.agent.context.token_counter import ( AUDIO_TOKEN_ESTIMATE, IMAGE_TOKEN_ESTIMATE, EstimateTokenCounter, + TokenizerTokenCounter, + create_token_counter, ) from astrbot.core.agent.message import ( AudioURLPart, @@ -13,7 +18,6 @@ ThinkPart, ) - counter = EstimateTokenCounter() @@ -101,3 +105,125 @@ def test_tool_calls_counted(self): # 文本 + tool call JSON 都应被计算 text_only = counter.count_tokens([_msg("assistant", "calling tool")]) assert tokens > text_only + + +class TestCounterFactory: + def test_create_estimate_mode(self): + created = create_token_counter("estimate") + assert isinstance(created, EstimateTokenCounter) + + def test_create_unknown_mode_fallback(self): + created = create_token_counter("unknown-mode") + assert isinstance(created, EstimateTokenCounter) + + def test_create_tokenizer_mode_returns_valid_counter_type(self): + created = 
create_token_counter("tokenizer", model="gpt-4") + assert isinstance(created, (TokenizerTokenCounter, EstimateTokenCounter)) + + def test_tokenizer_counter_gracefully_handles_broken_fallback_encoder( + self, monkeypatch + ): + fake_tiktoken = SimpleNamespace( + encoding_for_model=lambda _model: (_ for _ in ()).throw(RuntimeError("boom")), + get_encoding=lambda _name: (_ for _ in ()).throw(RuntimeError("boom")), + ) + monkeypatch.setitem(sys.modules, "tiktoken", fake_tiktoken) + + counter = TokenizerTokenCounter(model="gpt-4") + assert counter.available is False + + created = create_token_counter("tokenizer", model="gpt-4") + assert isinstance(created, EstimateTokenCounter) + + +class TestTokenizerTokenCounterBehavior: + def test_count_tokens_uses_trusted_usage_and_skips_encode(self, monkeypatch): + counter = TokenizerTokenCounter(model="gpt-4") + + def _encode_should_not_be_called(_text): # pragma: no cover - safety guard + raise AssertionError( + "_encode should not be called when trusted_token_usage is provided" + ) + + monkeypatch.setattr(counter, "_encode", _encode_should_not_be_called) + + messages = [ + _msg( + "user", + [ + TextPart(text="hello"), + ThinkPart(think="internal thoughts"), + ImageURLPart( + image_url=ImageURLPart.ImageURL( + url="data:image/png;base64,abc" + ) + ), + AudioURLPart( + audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3") + ), + ], + ) + ] + + trusted_usage = 123 + result = counter.count_tokens(messages, trusted_token_usage=trusted_usage) + assert result == trusted_usage + + def test_encode_error_falls_back_to_estimate_text_tokens(self, monkeypatch): + counter = TokenizerTokenCounter(model="gpt-4") + + def broken_encode(_text): + raise RuntimeError("tiktoken failure") + + captured: dict[str, str] = {} + + def fake_estimate(text: str) -> int: + captured["text"] = text + return 42 + + monkeypatch.setattr(counter, "_encode", broken_encode) + monkeypatch.setattr(counter._estimate, "estimate_text_tokens", fake_estimate) + + result = counter._encode_len("fallback text") + assert result == 42 + assert captured["text"] == "fallback text" + + def test_tokenizer_mode_mixed_modalities_use_fixed_estimates(self, monkeypatch): + counter = TokenizerTokenCounter(model="gpt-4") + counter._available = True + + def fake_encode_len(text: str) -> int: + return len(text.split()) + + monkeypatch.setattr(counter, "_encode_len", fake_encode_len) + + messages = [ + _msg( + "user", + [ + TextPart(text="hello world"), + ImageURLPart( + image_url=ImageURLPart.ImageURL( + url="data:image/png;base64,image" + ) + ), + ], + ), + _msg( + "assistant", + [ + ThinkPart(think="thinking hard"), + AudioURLPart( + audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3") + ), + ], + ), + ] + + expected_text_tokens = fake_encode_len("hello world") + fake_encode_len( + "thinking hard" + ) + expected_non_text_tokens = IMAGE_TOKEN_ESTIMATE + AUDIO_TOKEN_ESTIMATE + + tokens = counter.count_tokens(messages) + assert tokens == expected_text_tokens + expected_non_text_tokens diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 38c601cee5..e124941a94 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -1,5 +1,6 @@ import os import sys +from types import SimpleNamespace from unittest.mock import AsyncMock import pytest @@ -9,7 +10,11 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.runners.tool_loop_agent_runner import 
ToolLoopAgentRunner +from astrbot.core.agent.runners.tool_loop_agent_runner import ( + PostToolCompactionConfig, + PostToolCompactionController, + ToolLoopAgentRunner, +) from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage from astrbot.core.provider.provider import Provider @@ -536,6 +541,210 @@ async def test_follow_up_ticket_not_consumed_when_no_next_tool_call( assert ticket.consumed is False +@pytest.mark.asyncio +async def test_compact_context_after_tool_call_enabled( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_provider.should_call_tools = True + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + compact_context_after_tool_call=True, + ) + + runner.context_manager.process = AsyncMock( # type: ignore[method-assign] + side_effect=lambda messages, trusted_token_usage=0: messages, + ) + + async for _ in runner.step(): + pass + + assert runner.context_manager.process.await_count == 2 + + +@pytest.mark.asyncio +async def test_compact_context_after_tool_call_disabled_by_default( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_provider.should_call_tools = True + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + runner.context_manager.process = AsyncMock( # type: ignore[method-assign] + side_effect=lambda messages, trusted_token_usage=0: messages, + ) + + async for _ in runner.step(): + pass + + assert runner.context_manager.process.await_count == 1 + + +@pytest.mark.asyncio +async def test_compact_context_after_tool_call_honors_debounce( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_provider.should_call_tools = True + mock_provider.provider_config["max_context_tokens"] = 100 + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + compact_context_after_tool_call=True, + compact_context_soft_ratio=0.3, + compact_context_hard_ratio=0.4, + compact_context_debounce_seconds=3600, + ) + + runner.context_manager.token_counter = SimpleNamespace( + count_tokens=lambda *_args, **_kwargs: 90 + ) + runner.context_manager.process = AsyncMock( # type: ignore[method-assign] + side_effect=lambda messages, trusted_token_usage=0: messages, + ) + + # step 1: pre-LLM compact + post-tool compact + async for _ in runner.step(): + pass + # step 2: pre-LLM compact + post-tool compact skipped by debounce + async for _ in runner.step(): + pass + + assert runner.context_manager.process.await_count == 3 + + +def test_post_tool_compaction_soft_zone_respects_min_delta(runner): + runner.post_tool_compaction = PostToolCompactionConfig( + enabled=True, + soft_ratio=0.3, + hard_ratio=0.9, + min_delta_tokens=10, + min_delta_turns=10, + debounce_seconds=0, + ) + runner.post_tool_compaction_controller = PostToolCompactionController( + runner.post_tool_compaction + ) + runner.context_config = SimpleNamespace(max_context_tokens=100) + runner.run_context = SimpleNamespace(messages=[object(), object()]) + runner.context_manager = SimpleNamespace( + 
token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 35) + ) + runner.post_tool_compaction_controller.refresh_baseline( + messages=runner.run_context.messages, + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 30), + ) + + # ratio=0.35 in soft zone, token delta=5 and message delta=0 -> should skip + assert runner._should_run_post_tool_compaction() is False + + runner.context_manager = SimpleNamespace( + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 95) + ) + # ratio=0.95 in hard zone -> force compaction + assert runner._should_run_post_tool_compaction() is True + + +def test_post_tool_compaction_handles_token_counter_errors(runner): + runner.post_tool_compaction = PostToolCompactionConfig( + enabled=True, + soft_ratio=0.3, + hard_ratio=0.9, + min_delta_tokens=10, + min_delta_turns=10, + debounce_seconds=0, + ) + runner.post_tool_compaction_controller = PostToolCompactionController( + runner.post_tool_compaction + ) + runner.context_config = SimpleNamespace(max_context_tokens=100) + runner.run_context = SimpleNamespace(messages=[object(), object()]) + + def _raise(*_args, **_kwargs): + raise RuntimeError("counter broken") + + runner.context_manager = SimpleNamespace( + token_counter=SimpleNamespace(count_tokens=_raise) + ) + + assert runner._should_run_post_tool_compaction() is False + + +def test_post_tool_compaction_debounce_is_not_extended(monkeypatch): + config = PostToolCompactionConfig( + enabled=True, + soft_ratio=0.3, + hard_ratio=0.9, + min_delta_tokens=0, + min_delta_turns=0, + debounce_seconds=100, + ) + controller = PostToolCompactionController(config) + messages = [object()] + token_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 95) + + # refresh baseline before checks + controller.refresh_baseline( + messages=messages, + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 30), + ) + + ts = iter([1.0, 10.0, 20.0, 105.0]) + monkeypatch.setattr( + "astrbot.core.agent.runners.tool_loop_agent_runner.time.monotonic", + lambda: next(ts), + ) + + # first check performs decision and sets baseline timestamp + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is True + ) + # next two checks are debounced + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is False + ) + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is False + ) + # should become eligible at t=105 if debounce anchor remains at first real check (t=1) + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is True + ) + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 9a42abd733..40c393ad96 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -663,6 +663,100 @@ async def test_decorate_llm_request_no_conversation(self, mock_event, mock_conte assert req.prompt == "Hello" + @pytest.mark.asyncio + async def test_decorate_llm_request_injects_pinned_context_memory( + self, mock_event, mock_context + ): + module = ama + req = ProviderRequest(prompt="Hello", system_prompt="System") + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={ + "context_memory": { + "enabled": True, + 
"inject_pinned_memory": True, + "pinned_memories": [ + "用户喜欢先结论后细节。", + "优先使用中文回复。", + ], + "pinned_max_items": 8, + "pinned_max_chars_per_item": 200, + } + }, + ) + + with patch.object(mock_context, "get_config") as mock_get_config: + mock_get_config.return_value = {} + await module._decorate_llm_request(mock_event, req, mock_context, config) + + assert "" in req.system_prompt + assert "1. 用户喜欢先结论后细节。" in req.system_prompt + assert "2. 优先使用中文回复。" in req.system_prompt + + @pytest.mark.asyncio + async def test_decorate_llm_request_skips_pinned_memory_when_disabled( + self, mock_event, mock_context + ): + module = ama + req = ProviderRequest(prompt="Hello", system_prompt="System") + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={ + "context_memory": { + "enabled": True, + "inject_pinned_memory": False, + "pinned_memories": ["这条不应被注入。"], + } + }, + ) + + with patch.object(mock_context, "get_config") as mock_get_config: + mock_get_config.return_value = {} + await module._decorate_llm_request(mock_event, req, mock_context, config) + + assert "" not in req.system_prompt + + @pytest.mark.asyncio + async def test_decorate_llm_request_assembles_prompt_layers_in_order( + self, mock_event, mock_context + ): + module = ama + req = ProviderRequest(prompt="Hello", system_prompt="System") + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={ + "context_memory": { + "enabled": True, + "inject_pinned_memory": True, + "pinned_memories": ["固定偏好A"], + } + }, + ) + + extras = { + "retrieved_long_term_facts": ["事实1", "事实2"], + "compacted_history_summary": "历史摘要S", + } + mock_event.get_extra.side_effect = lambda key: extras.get(key) + + with patch.object(mock_context, "get_config") as mock_get_config: + mock_get_config.return_value = {} + await module._decorate_llm_request(mock_event, req, mock_context, config) + + system_prompt = req.system_prompt + assert "System" in system_prompt + assert "" in system_prompt + assert "" in system_prompt + assert "" in system_prompt + # ordering: base -> long_term_facts -> summarized_history -> top_level_memory + assert system_prompt.index("System") < system_prompt.index("") + assert system_prompt.index("") < system_prompt.index( + "" + ) + assert system_prompt.index("") < system_prompt.index( + "" + ) + class TestModalitiesFix: """Tests for _modalities_fix function.""" diff --git a/tests/unit/test_context_compaction_command.py b/tests/unit/test_context_compaction_command.py new file mode 100644 index 0000000000..f4d3e39e93 --- /dev/null +++ b/tests/unit/test_context_compaction_command.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.builtin_stars.builtin_commands.commands.context_compaction import ( + ContextCompactionCommands, +) +from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler + + +class DummyConfigManager: + def __init__(self, default_conf: dict): + self.default_conf = default_conf + + +def _build_scheduler() -> PeriodicContextCompactionScheduler: + cfg_mgr = DummyConfigManager( + { + "provider_settings": { + "periodic_context_compaction": { + "enabled": True, + "interval_minutes": 3, + "max_conversations_per_run": 2, + "max_scan_per_run": 10, + "target_tokens": 1024, + "trigger_tokens": 2048, + "max_rounds": 2, + } + } + } + ) + return PeriodicContextCompactionScheduler( + config_manager=cfg_mgr, + conversation_manager=SimpleNamespace(), + 
provider_manager=SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_status_when_scheduler_unavailable() -> None: + command = ContextCompactionCommands(context=SimpleNamespace()) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + assert "不可用" in chain.get_plain_text(with_other_comps_mark=True) + + +@pytest.mark.asyncio +async def test_status_with_runtime_report() -> None: + scheduler = _build_scheduler() + scheduler._last_status.report = { + "reason": "manual_command", + "scanned": 8, + "compacted": 2, + "skipped": 6, + "failed": 0, + "elapsed_sec": 1.2, + } + scheduler._last_status.started_at = "2026-03-19T12:00:00+00:00" + scheduler._last_status.finished_at = "2026-03-19T12:00:01+00:00" + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "定时上下文压缩状态" in text + assert "最近任务[manual_command]" in text + assert "compacted=2" in text + + +@pytest.mark.asyncio +async def test_status_with_no_report() -> None: + scheduler = _build_scheduler() + scheduler._last_status.report = None + scheduler._last_status.error = None + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "定时上下文压缩状态" in text + assert "启用=是" in text + assert "最近任务:暂无" in text + + +@pytest.mark.asyncio +async def test_status_includes_last_error_line() -> None: + scheduler = _build_scheduler() + scheduler._last_status.report = { + "reason": "manual_command", + "scanned": 1, + "compacted": 0, + "skipped": 1, + "failed": 0, + "elapsed_sec": 0.3, + } + scheduler._last_status.error = "mock error" + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "最近任务[manual_command]" in text + assert "最近错误:mock error" in text + + +@pytest.mark.asyncio +async def test_run_with_invalid_limit() -> None: + scheduler = _build_scheduler() + scheduler.run_once = AsyncMock() + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.run(event, 0) + + scheduler.run_once.assert_not_awaited() + chain = event.send.await_args.args[0] + assert "limit 必须 >= 1" in chain.get_plain_text(with_other_comps_mark=True) + + +@pytest.mark.asyncio +async def test_run_triggers_scheduler_once() -> None: + scheduler = _build_scheduler() + scheduler.run_once = AsyncMock( + return_value={ + "scanned": 12, + "compacted": 3, + "skipped": 8, + "failed": 1, + "elapsed_sec": 2.5, + } + ) + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.run(event, 5) + + scheduler.run_once.assert_awaited_once_with( + 
reason="manual_command", + max_conversations_override=5, + ) + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "手动触发完成" in text + assert "compacted=3" in text + + +@pytest.mark.asyncio +async def test_run_reports_error_when_scheduler_raises() -> None: + scheduler = _build_scheduler() + scheduler.run_once = AsyncMock(side_effect=RuntimeError("mock boom")) + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.run(event, 2) + + scheduler.run_once.assert_awaited_once_with( + reason="manual_command", + max_conversations_override=2, + ) + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert text == "触发压缩失败,请查看服务端日志。" diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py new file mode 100644 index 0000000000..8d3f8ef6d2 --- /dev/null +++ b/tests/unit/test_context_compaction_scheduler.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +from dataclasses import replace +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.core.agent.message import Message +from astrbot.core.config.default import PERIODIC_CONTEXT_COMPACTION_DEFAULTS +from astrbot.core.context_compaction_scheduler import ( + CompactionConfig, + CompactionPolicy, + PeriodicContextCompactionScheduler, +) + + +class DummyConfigManager: + def __init__(self, default_conf: dict): + self.default_conf = default_conf + + +def _build_scheduler(cfg: dict) -> PeriodicContextCompactionScheduler: + manager = DummyConfigManager({"provider_settings": cfg}) + return PeriodicContextCompactionScheduler( + config_manager=manager, + conversation_manager=SimpleNamespace(), + provider_manager=SimpleNamespace(), + ) + + +def test_load_config_normalizes_values() -> None: + scheduler = _build_scheduler( + { + "periodic_context_compaction": { + "enabled": "true", + "interval_minutes": "0", + "target_tokens": 1024, + "trigger_tokens": 1000, + "max_rounds": "2", + } + } + ) + + cfg = scheduler._load_config() + + assert cfg.enabled is True + assert cfg.interval_minutes == 1 + assert cfg.target_tokens == 1024 + assert cfg.trigger_tokens == 1000 + assert cfg.trigger_min_context_ratio == pytest.approx(0.3) + assert cfg.max_rounds == 2 + + +@pytest.mark.parametrize( + ("raw_enabled", "expected"), + [ + ("true", True), + ("false", False), + ("1", True), + ("0", False), + ("yes", True), + ("no", False), + ("unknown", False), + ], +) +def test_load_config_enabled_bool_parsing(raw_enabled: str, expected: bool) -> None: + scheduler = _build_scheduler( + { + "periodic_context_compaction": { + "enabled": raw_enabled, + } + } + ) + + cfg = scheduler._load_config() + assert cfg.enabled is expected + + +@pytest.mark.parametrize( + ("raw_cfg", "expected_interval", "expected_scan_page_size", "expected_min_messages"), + [ + ({"interval_minutes": 0, "scan_page_size": 1, "min_messages": 0}, 1, 10, 2), + ( + {"interval_minutes": -3, "scan_page_size": -5, "min_messages": -1}, + 1, + 10, + 2, + ), + ( + {"interval_minutes": "0", "scan_page_size": "1", "min_messages": "0"}, + 1, + 10, + 2, + ), + ], +) +def test_load_config_clamps_numeric_minimums( + raw_cfg: dict, + expected_interval: int, + expected_scan_page_size: int, + expected_min_messages: int, +) -> None: + scheduler = 
_build_scheduler({"periodic_context_compaction": raw_cfg}) + cfg = scheduler._load_config() + + assert cfg.interval_minutes == expected_interval + assert cfg.scan_page_size == expected_scan_page_size + assert cfg.min_messages == expected_min_messages + + +@pytest.mark.parametrize( + ("raw_cfg", "expected_target", "expected_trigger"), + [ + ({"target_tokens": 1024}, 1024, 0), + ({"target_tokens": 1024, "trigger_tokens": None}, 1024, 0), + ({"target_tokens": 1024, "trigger_tokens": 512}, 1024, 512), + ({"target_tokens": 1024, "trigger_tokens": "512"}, 1024, 512), + ({"target_tokens": 1024, "trigger_tokens": 2048}, 1024, 2048), + ({"target_tokens": 10}, 512, 0), + ], +) +def test_load_config_token_threshold_normalization( + raw_cfg: dict, + expected_target: int, + expected_trigger: int, +) -> None: + scheduler = _build_scheduler({"periodic_context_compaction": raw_cfg}) + cfg = scheduler._load_config() + + assert cfg.target_tokens == expected_target + assert cfg.trigger_tokens == expected_trigger + + +@pytest.mark.parametrize( + ("raw_ratio", "expected"), + [ + (0.3, 0.3), + ("30", 0.3), + ("0.25", 0.25), + (-1, 0.0), + (500, 1.0), + ], +) +def test_load_config_trigger_ratio_normalization(raw_ratio, expected: float) -> None: + scheduler = _build_scheduler( + {"periodic_context_compaction": {"trigger_min_context_ratio": raw_ratio}} + ) + cfg = scheduler._load_config() + assert cfg.trigger_min_context_ratio == pytest.approx(expected) + + +@pytest.mark.parametrize("raw_value", [None, 1, "not-a-dict", []]) +def test_load_config_falls_back_for_non_dict(raw_value) -> None: + scheduler = _build_scheduler({"periodic_context_compaction": raw_value}) + cfg = scheduler._load_config() + + expected = CompactionConfig(**PERIODIC_CONTEXT_COMPACTION_DEFAULTS) + assert cfg == expected + + +def test_get_status_returns_runtime_snapshot() -> None: + scheduler = _build_scheduler( + {"periodic_context_compaction": {"enabled": True, "interval_minutes": 3}} + ) + status = scheduler.get_status() + + assert status["running"] is False + assert status["config"]["enabled"] is True + assert status["config"]["interval_minutes"] == 3 + assert status["last_report"] is None + + +def test_sanitize_message_dict_keeps_supported_parts() -> None: + scheduler = _build_scheduler({}) + + raw = { + "role": "assistant", + "content": [ + {"type": "think", "think": "internal reasoning"}, + {"type": "text", "text": "visible answer"}, + {"type": "image_url", "image_url": {"url": "https://x.test/a.png"}}, + {"type": "unknown", "foo": "bar"}, + ], + } + + sanitized = scheduler._history_parser.sanitize_message_dict(raw) + + assert sanitized is not None + assert sanitized["role"] == "assistant" + assert isinstance(sanitized["content"], list) + content = sanitized["content"] + assert content[0]["type"] == "text" + assert "internal reasoning" in content[0]["text"] + assert any(part.get("type") == "image_url" for part in content) + + +def test_is_idle_enough_respects_threshold() -> None: + now = datetime.now(timezone.utc) + old = now - timedelta(minutes=30) + recent = now - timedelta(minutes=2) + + assert CompactionPolicy.is_idle_enough(old, 10) is True + assert CompactionPolicy.is_idle_enough(recent, 10) is False + assert CompactionPolicy.is_idle_enough(None, 10) is True + + +def test_resolve_run_limits_treats_max_scan_as_upper_bound() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + max_conversations_per_run=8, + max_scan_per_run=3, + scan_page_size=5, + ) + + max_to_compact, max_to_scan, page_size = 
scheduler._resolve_run_limits(cfg, None) + assert max_to_compact == 3 + assert max_to_scan == 3 + assert page_size == 10 + + max_to_compact, max_to_scan, _ = scheduler._resolve_run_limits(cfg, 20) + assert max_to_compact == 3 + assert max_to_scan == 3 + + +def test_resolve_trigger_tokens_prefers_manual_value() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + target_tokens=1024, + trigger_tokens=1500, + trigger_min_context_ratio=0.3, + ) + policy = CompactionPolicy(cfg=cfg, token_counter=SimpleNamespace()) + provider = SimpleNamespace( + provider_config={"max_context_tokens": 32768}, + get_model=lambda: "unknown-model", + ) + + resolved = policy.resolve_trigger_tokens(provider) + assert resolved == 1500 + + +def test_resolve_trigger_tokens_uses_ratio_when_auto_mode() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + target_tokens=1024, + trigger_tokens=0, + trigger_min_context_ratio=0.3, + ) + policy = CompactionPolicy(cfg=cfg, token_counter=SimpleNamespace()) + provider = SimpleNamespace( + provider_config={"max_context_tokens": 32768}, + get_model=lambda: "unknown-model", + ) + + resolved = policy.resolve_trigger_tokens(provider) + assert resolved == 9830 + + +def test_resolve_trigger_tokens_falls_back_when_provider_context_unknown() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + target_tokens=1024, + trigger_tokens=0, + trigger_min_context_ratio=0.3, + ) + policy = CompactionPolicy(cfg=cfg, token_counter=SimpleNamespace()) + provider = SimpleNamespace( + provider_config={"max_context_tokens": 0}, + get_model=lambda: "unknown-model", + ) + + resolved = policy.resolve_trigger_tokens(provider) + assert resolved == 1536 + + +def test_resolve_token_counter_uses_configured_mode_and_provider_model(monkeypatch) -> None: + scheduler = _build_scheduler({"context_token_counter_mode": "auto"}) + provider = SimpleNamespace(get_model=lambda: "gpt-4o") + called: dict[str, str | None] = {} + + fake_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 0) + + def _fake_create(mode: str | None = None, *, model: str | None = None): + called["mode"] = mode + called["model"] = model + return fake_counter + + monkeypatch.setattr( + "astrbot.core.context_compaction_scheduler.create_token_counter", + _fake_create, + ) + + resolved = scheduler._resolve_token_counter(provider) + assert resolved is fake_counter + assert called["mode"] == "auto" + assert called["model"] == "gpt-4o" + + +def test_resolve_token_counter_prefers_provider_level_mode(monkeypatch) -> None: + scheduler = _build_scheduler({"context_token_counter_mode": "estimate"}) + provider = SimpleNamespace( + get_model=lambda: "gpt-4o", + provider_settings={"context_token_counter_mode": "tokenizer"}, + ) + called: dict[str, str | None] = {} + + fake_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 0) + + def _fake_create(mode: str | None = None, *, model: str | None = None): + called["mode"] = mode + called["model"] = model + return fake_counter + + monkeypatch.setattr( + "astrbot.core.context_compaction_scheduler.create_token_counter", + _fake_create, + ) + + resolved = scheduler._resolve_token_counter(provider) + assert resolved is fake_counter + assert called["mode"] == "tokenizer" + assert called["model"] == "gpt-4o" + + +@pytest.mark.asyncio +async def test_iter_candidate_conversations_does_not_apply_idle_filter_in_db_query() -> None: + scheduler = _build_scheduler( + 
{"periodic_context_compaction": {"enabled": True, "min_idle_minutes": 30}} + ) + + class _FakeDB: + def __init__(self) -> None: + self.updated_before_calls: list[datetime | None] = [] + + async def get_filtered_conversations( + self, + *, + page: int, + page_size: int, + updated_before: datetime | None, + min_messages: int, + ): + self.updated_before_calls.append(updated_before) + return [], 0 + + fake_db = _FakeDB() + scheduler.conversation_manager = SimpleNamespace(db=fake_db) # type: ignore[assignment] + + cfg = scheduler._load_config() + result = [ + conv + async for conv in scheduler._iter_candidate_conversations( + scan_page_size=40, + cfg=cfg, + ) + ] + + assert result == [] + assert fake_db.updated_before_calls == [None] + + +@pytest.mark.asyncio +async def test_compact_one_conversation_dry_run_reports_skipped(monkeypatch) -> None: + scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}}) + cfg = replace(scheduler._load_config(), dry_run=True) + + conv = SimpleNamespace( + conversation_id="conv-1", + user_id="user-1", + content=[], + token_usage=0, + updated_at=None, + ) + scheduler._resolve_provider = AsyncMock( # type: ignore[method-assign] + return_value=SimpleNamespace(get_model=lambda: "gpt-4o") + ) + scheduler._run_compaction_rounds = AsyncMock( # type: ignore[method-assign] + return_value=SimpleNamespace( + messages=[Message(role="user", content="after")], + changed=True, + rounds=1, + ) + ) + scheduler._resolve_token_counter = lambda _provider: SimpleNamespace( # type: ignore[method-assign] + count_tokens=lambda *_args, **_kwargs: 50 + ) + scheduler._persist_compacted_history = AsyncMock( # type: ignore[method-assign] + return_value=True + ) + monkeypatch.setattr( + CompactionPolicy, + "check_eligibility", + lambda self, _conv, _parser: ([Message(role="user", content="before")], 100), + ) + monkeypatch.setattr( + CompactionPolicy, + "resolve_trigger_tokens", + lambda self, _provider: 1, + ) + + outcome = await scheduler._compact_one_conversation(conv, cfg) + + assert outcome == "skipped" + scheduler._persist_compacted_history.assert_not_awaited() diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py new file mode 100644 index 0000000000..830ac7a1b7 --- /dev/null +++ b/tests/unit/test_context_memory.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS +from astrbot.core.context_memory import ( + ContextMemoryConfig, + normalize_context_memory_settings, +) +from astrbot.core.context_memory_experimental_backends import ( + ContextMemoryBackend, +) + + +def test_normalize_context_memory_settings_initializes_fresh_pinned_memories() -> None: + first = normalize_context_memory_settings(None) + second = normalize_context_memory_settings(None) + + first["pinned_memories"].append("A") + + assert first["pinned_memories"] == ["A"] + assert second["pinned_memories"] == [] + + +def test_context_memory_unified_backend_protocol_shape() -> None: + class _Backend: + async def evolve( + self, + *, + unified_msg_origin: str, + turns: list[str], + metadata=None, + ) -> dict: + return {"ok": True} + + async def retrieve( + self, + *, + unified_msg_origin: str, + query: str, + top_k: int, + ) -> list[str]: + return [] + + async def export_session(self, *, unified_msg_origin: str) -> dict: + return {} + + async def import_session( + self, + *, + unified_msg_origin: str, + payload: dict, + ) -> None: + return None + + assert isinstance(_Backend(), ContextMemoryBackend) + + 
+def test_context_memory_defaults_follow_single_source() -> None: + cfg = ContextMemoryConfig.from_raw(None) + + assert cfg.enabled == CONTEXT_MEMORY_DEFAULTS["enabled"] + assert cfg.inject_pinned_memory == CONTEXT_MEMORY_DEFAULTS["inject_pinned_memory"] + assert cfg.pinned_max_items == CONTEXT_MEMORY_DEFAULTS["pinned_max_items"] + assert ( + cfg.pinned_max_chars_per_item + == CONTEXT_MEMORY_DEFAULTS["pinned_max_chars_per_item"] + ) + assert cfg.retrieval_enabled == CONTEXT_MEMORY_DEFAULTS["retrieval_enabled"] + assert cfg.retrieval_backend == CONTEXT_MEMORY_DEFAULTS["retrieval_backend"] + assert cfg.retrieval_provider_id == CONTEXT_MEMORY_DEFAULTS["retrieval_provider_id"] + assert cfg.retrieval_top_k == CONTEXT_MEMORY_DEFAULTS["retrieval_top_k"] + assert cfg.pinned_memories == [] diff --git a/tests/unit/test_context_memory_command.py b/tests/unit/test_context_memory_command.py new file mode 100644 index 0000000000..8dfe85aa51 --- /dev/null +++ b/tests/unit/test_context_memory_command.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.builtin_stars.builtin_commands.commands.context_memory import ( + ContextMemoryCommands, +) + + +class DummyConfig(dict): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.save_calls = 0 + + def save_config(self) -> None: + self.save_calls += 1 + + +def _build_command_with_cfg(cfg: DummyConfig) -> ContextMemoryCommands: + context = SimpleNamespace(get_config=lambda umo=None: cfg) + return ContextMemoryCommands(context=context) + + +def _build_event() -> SimpleNamespace: + return SimpleNamespace(unified_msg_origin="umo:test", send=AsyncMock()) + + +@pytest.mark.asyncio +async def test_status_shows_defaults() -> None: + cfg = DummyConfig({"provider_settings": {}}) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.status(event) + + event.send.assert_awaited_once() + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert "上下文记忆状态" in text + assert "启用=否" in text + assert "顶层记忆条数=0" in text + + +@pytest.mark.asyncio +async def test_add_list_and_remove_memory() -> None: + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "enabled": True, + "inject_pinned_memory": True, + "pinned_memories": [], + "pinned_max_items": 4, + "pinned_max_chars_per_item": 120, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.add(event, "用户喜欢先给结论再解释。") + assert cfg.save_calls == 1 + stored = cfg["provider_settings"]["context_memory"]["pinned_memories"] + assert stored == ["用户喜欢先给结论再解释。"] + + await command.ls(event) + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert "手动顶层记忆列表" in text + assert "用户喜欢先给结论再解释" in text + + await command.rm(event, 1) + assert cfg["provider_settings"]["context_memory"]["pinned_memories"] == [] + + +@pytest.mark.asyncio +async def test_add_rejects_when_reaching_max_items() -> None: + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "pinned_memories": ["A"], + "pinned_max_items": 1, + "pinned_max_chars_per_item": 100, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.add(event, "B") + + # only normalization happened in-memory, no successful add/save + assert cfg["provider_settings"]["context_memory"]["pinned_memories"] == ["A"] + text = 
event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True)
+    assert "已达到顶层记忆最大条数" in text
+
+
+@pytest.mark.asyncio
+async def test_enable_and_retrieval_toggles() -> None:
+    cfg = DummyConfig({"provider_settings": {"context_memory": {"enabled": False}}})
+    command = _build_command_with_cfg(cfg)
+    event = _build_event()
+
+    await command.enable(event)
+    assert cfg["provider_settings"]["context_memory"]["enabled"] is True
+
+    await command.retrieval(event, "on")
+    assert cfg["provider_settings"]["context_memory"]["retrieval_enabled"] is True
+
+
+@pytest.mark.asyncio
+async def test_add_truncates_long_memory_item() -> None:
+    cfg = DummyConfig(
+        {
+            "provider_settings": {
+                "context_memory": {
+                    "pinned_memories": [],
+                    "pinned_max_items": 3,
+                    "pinned_max_chars_per_item": 10,
+                }
+            }
+        }
+    )
+    command = _build_command_with_cfg(cfg)
+    event = _build_event()
+
+    await command.add(event, "1234567890abcdef")
+
+    stored = cfg["provider_settings"]["context_memory"]["pinned_memories"]
+    assert stored == ["1234567890"]
+    text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True)
+    assert "已截断到 10 字符" in text
+
+
+@pytest.mark.asyncio
+async def test_ls_preview_length_uses_config_limit() -> None:
+    cfg = DummyConfig(
+        {
+            "provider_settings": {
+                "context_memory": {
+                    "pinned_memories": ["abcdefghijklmno"],
+                    "pinned_max_items": 3,
+                    "pinned_max_chars_per_item": 10,
+                }
+            }
+        }
+    )
+    command = _build_command_with_cfg(cfg)
+    event = _build_event()
+
+    await command.ls(event)
+
+    text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True)
+    assert "1. abcdefghij" in text
+    assert "1. abcdefghijk" not in text
+
+
+@pytest.mark.asyncio
+async def test_ls_preview_length_is_capped_by_display_limit() -> None:
+    long_text = "x" * 250
+    cfg = DummyConfig(
+        {
+            "provider_settings": {
+                "context_memory": {
+                    "pinned_memories": [long_text],
+                    "pinned_max_items": 3,
+                    "pinned_max_chars_per_item": 999,
+                }
+            }
+        }
+    )
+    command = _build_command_with_cfg(cfg)
+    event = _build_event()
+
+    await command.ls(event)
+
+    text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True)
+    assert ("1. " + ("x" * 180) + "...") in text
diff --git a/tests/unit/test_prompt_assembly_router.py b/tests/unit/test_prompt_assembly_router.py
new file mode 100644
index 0000000000..eaa46403d2
--- /dev/null
+++ b/tests/unit/test_prompt_assembly_router.py
@@ -0,0 +1,23 @@
+from astrbot.core.prompt_assembly_router import assemble_system_prompt
+
+
+def test_assemble_system_prompt_omits_empty_sections() -> None:
+    prompt = assemble_system_prompt(base_system_prompt="Base")
+    assert prompt == "Base"
+
+
+def test_assemble_system_prompt_orders_sections() -> None:
+    prompt = assemble_system_prompt(
+        base_system_prompt="Base",
+        retrieved_long_term_facts=["Fact A", "Fact B"],
+        summarized_history="Summary C",
+        pinned_memory_block="\nPinned D\n",
+    )
+
+    assert "Base" in prompt
+    assert "Fact A" in prompt
+    assert "Summary C" in prompt
+    assert "Pinned D" in prompt
+    assert prompt.index("Base") < prompt.index("Fact A")
+    assert prompt.index("Fact A") < prompt.index("Summary C")
+    assert prompt.index("Summary C") < prompt.index("Pinned D")
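Read together, the two router tests pin down the contract of assemble_system_prompt: empty or None sections are omitted entirely, so a bare base prompt round-trips unchanged, and non-empty sections are appended after the base prompt in a fixed order (retrieved facts, then summarized history, then pinned memories). A minimal sketch consistent with those assertions follows; the separator and the joining of the facts list are assumptions, since the tests constrain only omission and relative ordering.

# Sketch only: "\n\n" separators and the "\n".join of facts are assumptions;
# the tests above assert just the omission of empty sections and the ordering.
def assemble_system_prompt(
    *,
    base_system_prompt: str,
    retrieved_long_term_facts: list[str] | None = None,
    summarized_history: str | None = None,
    pinned_memory_block: str | None = None,
) -> str:
    sections = [base_system_prompt]
    if retrieved_long_term_facts:
        sections.append("\n".join(retrieved_long_term_facts))
    if summarized_history:
        sections.append(summarized_history)
    if pinned_memory_block and pinned_memory_block.strip():
        sections.append(pinned_memory_block.strip())
    # With only the base prompt present this returns it verbatim, matching
    # test_assemble_system_prompt_omits_empty_sections.
    return "\n\n".join(sections)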