From a65afe1da22202eac9efc111016bbfec741b4284 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 06:48:27 -0500 Subject: [PATCH 01/29] feat: add periodic llm context compaction scheduler --- astrbot/core/config/default.py | 215 +++++++ astrbot/core/context_compaction_scheduler.py | 555 ++++++++++++++++++ astrbot/core/core_lifecycle.py | 24 + .../unit/test_context_compaction_scheduler.py | 76 +++ 4 files changed, 870 insertions(+) create mode 100644 astrbot/core/context_compaction_scheduler.py create mode 100644 tests/unit/test_context_compaction_scheduler.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 58c1726814..5180abf5df 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -96,6 +96,24 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", + "periodic_context_compaction": { + "enabled": False, + "interval_minutes": 30, + "startup_delay_seconds": 120, + "max_conversations_per_run": 8, + "max_scan_per_run": 120, + "scan_page_size": 40, + "min_idle_minutes": 15, + "min_messages": 14, + "target_tokens": 4096, + "trigger_tokens": 6144, + "max_rounds": 3, + "truncate_turns": 1, + "keep_recent": 6, + "provider_id": "", + "instruction": "", + "dry_run": False, + }, "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, @@ -2509,6 +2527,59 @@ class ChatProviderTemplate(TypedDict): "prompt_prefix": { "type": "string", }, + "periodic_context_compaction": { + "type": "object", + "items": { + "enabled": { + "type": "bool", + }, + "interval_minutes": { + "type": "int", + }, + "startup_delay_seconds": { + "type": "int", + }, + "max_conversations_per_run": { + "type": "int", + }, + "max_scan_per_run": { + "type": "int", + }, + "scan_page_size": { + "type": "int", + }, + "min_idle_minutes": { + "type": "int", + }, + "min_messages": { + "type": "int", + }, + "target_tokens": { + "type": "int", + }, + "trigger_tokens": { + "type": "int", + }, + "max_rounds": { + "type": "int", + }, + "truncate_turns": { + "type": "int", + }, + "keep_recent": { + "type": "int", + }, + "provider_id": { + "type": "string", + }, + "instruction": { + "type": "string", + }, + "dry_run": { + "type": "bool", + }, + }, + }, "max_context_length": { "type": "int", }, @@ -3196,6 +3267,150 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.periodic_context_compaction.enabled": { + "description": "启用定时历史压缩", + "type": "bool", + "hint": "后台定时扫描会话历史,使用 LLM 摘要旧消息并回写对话历史,实现多轮 compact context。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.interval_minutes": { + "description": "定时间隔(分钟)", + "type": "int", + "hint": "每隔多少分钟执行一次压缩扫描。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.startup_delay_seconds": { + "description": "启动延迟(秒)", + "type": "int", + "hint": "AstrBot 启动后,等待指定秒数再执行首次压缩任务。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.max_conversations_per_run": { + "description": "单次最多压缩会话数", + "type": "int", + "hint": "每次任务最多实际压缩多少个会话。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + 
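+ # The scan knobs below cap how much work one sweep may do (pages read and
+ # conversations inspected), independent of how many actually get compacted.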
"provider_settings.periodic_context_compaction.max_scan_per_run": { + "description": "单次最多扫描会话数", + "type": "int", + "hint": "每次任务最多扫描多少会话(包括被跳过的会话)。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.scan_page_size": { + "description": "分页扫描大小", + "type": "int", + "hint": "扫描 conversations 表时每页读取条数。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.min_idle_minutes": { + "description": "最小静默时长(分钟)", + "type": "int", + "hint": "会话最近更新时间小于该值时跳过,避免压缩活跃会话。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.min_messages": { + "description": "最小消息条数", + "type": "int", + "hint": "少于该消息条数的会话不参与压缩。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.target_tokens": { + "description": "目标 Token 阈值", + "type": "int", + "hint": "压缩目标上下文大小(token 估算值)。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.trigger_tokens": { + "description": "触发 Token 阈值", + "type": "int", + "hint": "会话估算 token 超过此值才触发压缩。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.max_rounds": { + "description": "每会话最大压缩轮数", + "type": "int", + "hint": "单个会话一次任务内最多执行几轮摘要压缩(实现 multiple compact context)。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.truncate_turns": { + "description": "截断轮数(后备)", + "type": "int", + "hint": "LLM 压缩后仍超限时,按轮截断的每次丢弃轮数。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.keep_recent": { + "description": "保留最近轮数", + "type": "int", + "hint": "压缩时始终保留最近 N 轮消息。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.provider_id": { + "description": "压缩模型提供商 ID", + "type": "string", + "_special": "select_provider", + "hint": "留空时按会话当前模型执行压缩。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.instruction": { + "description": "定时压缩提示词", + "type": "text", + "hint": "留空时复用 provider_settings.llm_compress_instruction。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.dry_run": { + "description": "演练模式(不回写)", + "type": "bool", + "hint": "开启后只记录日志,不实际写回数据库。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": 
"local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py new file mode 100644 index 0000000000..7493575a57 --- /dev/null +++ b/astrbot/core/context_compaction_scheduler.py @@ -0,0 +1,555 @@ +from __future__ import annotations + +import asyncio +import json +import time +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from astrbot import logger +from astrbot.core.agent.context.config import ContextConfig +from astrbot.core.agent.context.manager import ContextManager +from astrbot.core.agent.context.token_counter import EstimateTokenCounter +from astrbot.core.agent.message import Message +from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.db.po import ConversationV2 +from astrbot.core.provider.entities import ProviderType +from astrbot.core.provider.provider import Provider + +if TYPE_CHECKING: + from astrbot.core.provider.manager import ProviderManager + + +@dataclass +class _CompactionStats: + scanned: int = 0 + compacted: int = 0 + skipped: int = 0 + failed: int = 0 + + +class PeriodicContextCompactionScheduler: + """Periodically compact conversation history and persist summarized history back to DB. + + This upgrades existing "compress-on-overflow" behavior into proactive, scheduled + conversation-body compaction to keep long sessions lightweight. + """ + + _DEFAULTS = { + "enabled": False, + "interval_minutes": 30, + "startup_delay_seconds": 120, + "max_conversations_per_run": 8, + "max_scan_per_run": 120, + "scan_page_size": 40, + "min_idle_minutes": 15, + "min_messages": 14, + "target_tokens": 4096, + "trigger_tokens": 6144, + "max_rounds": 3, + "truncate_turns": 1, + "keep_recent": 6, + "provider_id": "", + "instruction": "", + "dry_run": False, + } + + def __init__( + self, + config_manager: AstrBotConfigManager, + conversation_manager: ConversationManager, + provider_manager: ProviderManager, + ) -> None: + self.config_manager = config_manager + self.conversation_manager = conversation_manager + self.provider_manager = provider_manager + self._stop_event = asyncio.Event() + self._running_lock = asyncio.Lock() + self._token_counter = EstimateTokenCounter() + self._bootstrapped = False + + async def run(self) -> None: + logger.info("[ContextCompact] scheduler started") + while not self._stop_event.is_set(): + cfg = self._load_config() + wait_seconds = max(5, int(cfg["interval_minutes"])) * 60 + + if not cfg["enabled"]: + await self._sleep_or_stop(wait_seconds) + continue + + if not self._bootstrapped: + self._bootstrapped = True + startup_delay = max(0, int(cfg["startup_delay_seconds"])) + if startup_delay > 0: + logger.info( + "[ContextCompact] startup delay: %ss before first run", + startup_delay, + ) + await self._sleep_or_stop(startup_delay) + if self._stop_event.is_set(): + break + + try: + report = await self.run_once(reason="scheduled") + logger.info( + "[ContextCompact] run done(%s): scanned=%s compacted=%s skipped=%s failed=%s elapsed=%.2fs", + report.get("reason", "unknown"), + report.get("scanned", 0), + report.get("compacted", 0), + report.get("skipped", 0), + report.get("failed", 0), + report.get("elapsed_sec", 0.0), + ) + except Exception as exc: + logger.error( + "[ContextCompact] scheduler run error: %s", + exc, + exc_info=True, + ) + + 
await self._sleep_or_stop(wait_seconds) + + logger.info("[ContextCompact] scheduler stopped") + + async def stop(self) -> None: + self._stop_event.set() + + async def run_once(self, reason: str = "manual") -> dict[str, Any]: + """Run one compaction sweep. + + Exposed so future admin command/cron endpoints can trigger ad-hoc compaction. + """ + async with self._running_lock: + cfg = self._load_config() + started = time.monotonic() + stats = _CompactionStats() + + if not cfg["enabled"] and reason == "scheduled": + return { + "reason": reason, + "scanned": 0, + "compacted": 0, + "skipped": 0, + "failed": 0, + "elapsed_sec": 0.0, + "message": "disabled", + } + + max_to_compact = max(1, int(cfg["max_conversations_per_run"])) + max_to_scan = max(max_to_compact, int(cfg["max_scan_per_run"])) + scan_page_size = max(10, int(cfg["scan_page_size"])) + + page = 1 + while ( + not self._stop_event.is_set() + and stats.scanned < max_to_scan + and stats.compacted < max_to_compact + ): + conversations, total = ( + await self.conversation_manager.db.get_filtered_conversations( + page=page, + page_size=scan_page_size, + ) + ) + if not conversations: + break + + for conv in conversations: + if ( + self._stop_event.is_set() + or stats.scanned >= max_to_scan + or stats.compacted >= max_to_compact + ): + break + + stats.scanned += 1 + outcome = await self._compact_one_conversation(conv, cfg) + if outcome == "compacted": + stats.compacted += 1 + elif outcome == "skipped": + stats.skipped += 1 + else: + stats.failed += 1 + + if page * scan_page_size >= total: + break + page += 1 + + elapsed = time.monotonic() - started + return { + "reason": reason, + "scanned": stats.scanned, + "compacted": stats.compacted, + "skipped": stats.skipped, + "failed": stats.failed, + "elapsed_sec": elapsed, + } + + async def _sleep_or_stop(self, seconds: int) -> None: + try: + await asyncio.wait_for(self._stop_event.wait(), timeout=seconds) + except asyncio.TimeoutError: + return + + def _load_config(self) -> dict[str, Any]: + default_conf = self.config_manager.default_conf + provider_settings = default_conf.get("provider_settings", {}) + raw_cfg = provider_settings.get("periodic_context_compaction", {}) + if not isinstance(raw_cfg, dict): + raw_cfg = {} + + cfg = dict(self._DEFAULTS) + cfg.update(raw_cfg) + + # normalize + cfg["enabled"] = self._to_bool(cfg.get("enabled"), False) + cfg["interval_minutes"] = self._to_int(cfg.get("interval_minutes"), 30, 1) + cfg["startup_delay_seconds"] = self._to_int( + cfg.get("startup_delay_seconds"), + 120, + 0, + ) + cfg["max_conversations_per_run"] = self._to_int( + cfg.get("max_conversations_per_run"), + 8, + 1, + ) + cfg["max_scan_per_run"] = self._to_int( + cfg.get("max_scan_per_run"), + 120, + 1, + ) + cfg["scan_page_size"] = self._to_int(cfg.get("scan_page_size"), 40, 10) + cfg["min_idle_minutes"] = self._to_int(cfg.get("min_idle_minutes"), 15, 0) + cfg["min_messages"] = self._to_int(cfg.get("min_messages"), 14, 2) + cfg["target_tokens"] = self._to_int(cfg.get("target_tokens"), 4096, 512) + cfg["trigger_tokens"] = self._to_int( + cfg.get("trigger_tokens"), + max(int(cfg["target_tokens"] * 1.5), cfg["target_tokens"] + 1), + 512, + ) + if cfg["trigger_tokens"] <= cfg["target_tokens"]: + cfg["trigger_tokens"] = cfg["target_tokens"] + 1 + cfg["max_rounds"] = self._to_int(cfg.get("max_rounds"), 3, 1) + cfg["truncate_turns"] = self._to_int(cfg.get("truncate_turns"), 1, 1) + cfg["keep_recent"] = self._to_int(cfg.get("keep_recent"), 6, 0) + cfg["provider_id"] = str(cfg.get("provider_id", "") or 
"").strip() + cfg["instruction"] = str(cfg.get("instruction", "") or "").strip() + cfg["dry_run"] = self._to_bool(cfg.get("dry_run"), False) + + return cfg + + async def _compact_one_conversation( + self, + conv: ConversationV2, + cfg: dict[str, Any], + ) -> str: + history = conv.content + if not isinstance(history, list) or len(history) < cfg["min_messages"]: + return "skipped" + + if not self._is_idle_enough(conv.updated_at, cfg["min_idle_minutes"]): + return "skipped" + + messages = self._parse_history(history) + if len(messages) < cfg["min_messages"]: + return "skipped" + + trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 + before_tokens = self._token_counter.count_tokens(messages, trusted_usage) + if before_tokens < cfg["trigger_tokens"]: + return "skipped" + + provider = await self._resolve_provider(cfg, conv.user_id) + if not provider: + return "failed" + + compressed = messages + changed = False + rounds = 0 + for _ in range(cfg["max_rounds"]): + current_tokens = self._token_counter.count_tokens(compressed) + if current_tokens <= cfg["target_tokens"]: + break + + manager = ContextManager( + ContextConfig( + max_context_tokens=cfg["target_tokens"], + enforce_max_turns=-1, + truncate_turns=cfg["truncate_turns"], + llm_compress_keep_recent=cfg["keep_recent"], + llm_compress_instruction=self._resolve_instruction(cfg), + llm_compress_provider=provider, + ) + ) + + rounds += 1 + next_messages = await manager.process(compressed) + if self._messages_equal(compressed, next_messages): + break + + compressed = next_messages + changed = True + + if not changed: + return "skipped" + + after_tokens = self._token_counter.count_tokens(compressed) + if after_tokens >= before_tokens: + return "skipped" + + if cfg["dry_run"]: + logger.info( + "[ContextCompact] dry-run: cid=%s user=%s tokens=%s->%s rounds=%s", + conv.conversation_id, + conv.user_id, + before_tokens, + after_tokens, + rounds, + ) + return "compacted" + + try: + await self.conversation_manager.update_conversation( + unified_msg_origin=conv.user_id, + conversation_id=conv.conversation_id, + history=[msg.model_dump(exclude_none=True) for msg in compressed], + token_usage=after_tokens, + ) + except Exception as exc: + logger.error( + "[ContextCompact] update failed: cid=%s user=%s err=%s", + conv.conversation_id, + conv.user_id, + exc, + exc_info=True, + ) + return "failed" + + logger.info( + "[ContextCompact] compacted cid=%s user=%s tokens=%s->%s rounds=%s", + conv.conversation_id, + conv.user_id, + before_tokens, + after_tokens, + rounds, + ) + return "compacted" + + async def _resolve_provider( + self, + cfg: dict[str, Any], + umo: str, + ) -> Provider | None: + provider = None + + if cfg["provider_id"]: + provider = await self.provider_manager.get_provider_by_id(cfg["provider_id"]) + else: + provider = self.provider_manager.get_using_provider( + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + if provider is None: + provider = self.provider_manager.get_using_provider( + provider_type=ProviderType.CHAT_COMPLETION, + umo=None, + ) + + if not isinstance(provider, Provider): + logger.warning( + "[ContextCompact] provider unavailable for umo=%s provider_id=%s", + umo, + cfg["provider_id"], + ) + return None + return provider + + def _resolve_instruction(self, cfg: dict[str, Any]) -> str: + if cfg["instruction"]: + return cfg["instruction"] + + provider_settings = self.config_manager.default_conf.get("provider_settings", {}) + base_instruction = provider_settings.get("llm_compress_instruction", "") + 
if isinstance(base_instruction, str) and base_instruction.strip(): + return base_instruction.strip() + return "" + + @staticmethod + def _is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: + if min_idle_minutes <= 0: + return True + if updated_at is None: + return True + now = datetime.now(timezone.utc) + at = updated_at + if at.tzinfo is None: + at = at.replace(tzinfo=timezone.utc) + return (now - at).total_seconds() >= (min_idle_minutes * 60) + + def _parse_history(self, history: Iterable[Any]) -> list[Message]: + parsed: list[Message] = [] + for item in history: + if not isinstance(item, dict): + continue + + try: + parsed.append(Message.model_validate(item)) + continue + except Exception: + pass + + fallback = self._sanitize_message_dict(item) + if not fallback: + continue + try: + parsed.append(Message.model_validate(fallback)) + except Exception: + continue + + return parsed + + def _sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: + role = str(item.get("role", "")).strip().lower() + if role not in {"system", "user", "assistant", "tool"}: + return None + + result: dict[str, Any] = {"role": role} + + if role == "assistant" and isinstance(item.get("tool_calls"), list): + result["tool_calls"] = item["tool_calls"] + + if role == "tool" and item.get("tool_call_id"): + result["tool_call_id"] = str(item.get("tool_call_id")) + + content = item.get("content") + if content is None and role == "assistant" and result.get("tool_calls"): + result["content"] = None + return result + + result["content"] = self._sanitize_content(content, role) + + if result["content"] is None and not ( + role == "assistant" and result.get("tool_calls") + ): + return None + + return result + + def _sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: + if isinstance(content, str): + return content + + if isinstance(content, list): + parts: list[dict[str, Any]] = [] + fallback_texts: list[str] = [] + + for part in content: + if isinstance(part, str): + if part.strip(): + fallback_texts.append(part) + continue + if not isinstance(part, dict): + txt = self._safe_json(part) + if txt: + fallback_texts.append(txt) + continue + + part_type = str(part.get("type", "")).strip() + if part_type == "text": + text_val = part.get("text") + if text_val is not None: + parts.append({"type": "text", "text": str(text_val)}) + continue + if part_type == "image_url": + image_obj = part.get("image_url") + if isinstance(image_obj, dict) and image_obj.get("url"): + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": {"url": str(image_obj.get("url"))}, + } + if image_obj.get("id"): + image_part["image_url"]["id"] = str(image_obj.get("id")) + parts.append(image_part) + continue + if part_type == "audio_url": + audio_obj = part.get("audio_url") + if isinstance(audio_obj, dict) and audio_obj.get("url"): + audio_part: dict[str, Any] = { + "type": "audio_url", + "audio_url": {"url": str(audio_obj.get("url"))}, + } + if audio_obj.get("id"): + audio_part["audio_url"]["id"] = str(audio_obj.get("id")) + parts.append(audio_part) + continue + + if part_type == "think": + think = part.get("think") + if think: + fallback_texts.append(str(think)) + continue + + raw_text = part.get("text") or part.get("content") + if raw_text: + fallback_texts.append(str(raw_text)) + else: + dumped = self._safe_json(part) + if dumped: + fallback_texts.append(dumped) + + if fallback_texts: + parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) + + if parts: + return 
parts + return "" + + if content is None: + if role == "assistant": + return None + return "" + + dumped = self._safe_json(content) + return dumped if dumped is not None else str(content) + + @staticmethod + def _messages_equal(a: list[Message], b: list[Message]) -> bool: + if len(a) != len(b): + return False + return [m.model_dump(exclude_none=True) for m in a] == [ + m.model_dump(exclude_none=True) for m in b + ] + + @staticmethod + def _safe_json(value: Any) -> str | None: + try: + return json.dumps(value, ensure_ascii=False, default=str) + except Exception: + return None + + @staticmethod + def _to_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + @staticmethod + def _to_int(value: Any, default: int, min_value: int) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + return max(parsed, min_value) diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe6b1c351d..6c12719c8e 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -20,6 +20,9 @@ from astrbot.core import LogBroker, LogManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION +from astrbot.core.context_compaction_scheduler import ( + PeriodicContextCompactionScheduler, +) from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron import CronJobManager from astrbot.core.db import BaseDatabase @@ -59,6 +62,9 @@ def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: self.subagent_orchestrator: SubAgentOrchestrator | None = None self.cron_manager: CronJobManager | None = None self.temp_dir_cleaner: TempDirCleaner | None = None + self.context_compaction_scheduler: ( + PeriodicContextCompactionScheduler | None + ) = None # 设置代理 proxy_config = self.astrbot_config.get("http_proxy", "") @@ -166,6 +172,13 @@ async def initialize(self) -> None: # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) + # 初始化定时历史压缩调度器(基于 llm_compress) + self.context_compaction_scheduler = PeriodicContextCompactionScheduler( + config_manager=self.astrbot_config_mgr, + conversation_manager=self.conversation_manager, + provider_manager=self.provider_manager, + ) + # 初始化平台消息历史管理器 self.platform_message_history_manager = PlatformMessageHistoryManager(self.db) @@ -252,6 +265,12 @@ def _load(self) -> None: self.temp_dir_cleaner.run(), name="temp_dir_cleaner", ) + context_compaction_task = None + if self.context_compaction_scheduler: + context_compaction_task = asyncio.create_task( + self.context_compaction_scheduler.run(), + name="context_compaction_scheduler", + ) # 把插件中注册的所有协程函数注册到事件总线中并执行 extra_tasks = [] @@ -263,6 +282,8 @@ def _load(self) -> None: tasks_.append(cron_task) if temp_dir_cleaner_task: tasks_.append(temp_dir_cleaner_task) + if context_compaction_task: + tasks_.append(context_compaction_task) for task in tasks_: self.curr_tasks.append( asyncio.create_task(self._task_wrapper(task), name=task.get_name()), @@ -317,6 +338,9 @@ async def stop(self) -> None: if self.temp_dir_cleaner: await self.temp_dir_cleaner.stop() + if self.context_compaction_scheduler: + await self.context_compaction_scheduler.stop() + # 请求停止所有正在运行的异步任务 for task in self.curr_tasks: task.cancel() 
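
A note on the shutdown path above: the scheduler never calls asyncio.sleep.
_sleep_or_stop waits on the stop event with a timeout, so stop() takes effect
immediately rather than only after a full interval elapses. A minimal,
self-contained sketch of that wait pattern (StoppableScheduler and main are
illustrative names, not part of this patch):

    import asyncio

    class StoppableScheduler:
        def __init__(self) -> None:
            self._stop_event = asyncio.Event()

        async def _sleep_or_stop(self, seconds: float) -> None:
            try:
                # Returns early when stop() sets the event; raises
                # TimeoutError once the interval elapses without a stop.
                await asyncio.wait_for(self._stop_event.wait(), timeout=seconds)
            except asyncio.TimeoutError:
                return

        async def run(self) -> None:
            while not self._stop_event.is_set():
                print("tick")  # one compaction sweep would run here
                await self._sleep_or_stop(1.0)

        def stop(self) -> None:
            self._stop_event.set()

    async def main() -> None:
        scheduler = StoppableScheduler()
        task = asyncio.create_task(scheduler.run())
        await asyncio.sleep(2.5)  # let a few ticks happen
        scheduler.stop()
        await task  # run() exits promptly once the event is set

    asyncio.run(main())

This is also why core_lifecycle.py can await scheduler.stop() for a graceful
exit before falling back to cancelling the remaining tasks.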
diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py new file mode 100644 index 0000000000..c8fab93925 --- /dev/null +++ b/tests/unit/test_context_compaction_scheduler.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from types import SimpleNamespace + +from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler + + +class DummyConfigManager: + def __init__(self, default_conf: dict): + self.default_conf = default_conf + + +def _build_scheduler(cfg: dict) -> PeriodicContextCompactionScheduler: + manager = DummyConfigManager({"provider_settings": cfg}) + return PeriodicContextCompactionScheduler( + config_manager=manager, + conversation_manager=SimpleNamespace(), + provider_manager=SimpleNamespace(), + ) + + +def test_load_config_normalizes_values() -> None: + scheduler = _build_scheduler( + { + "periodic_context_compaction": { + "enabled": "true", + "interval_minutes": "0", + "target_tokens": 1024, + "trigger_tokens": 1000, + "max_rounds": "2", + } + } + ) + + cfg = scheduler._load_config() + + assert cfg["enabled"] is True + assert cfg["interval_minutes"] == 1 + assert cfg["target_tokens"] == 1024 + assert cfg["trigger_tokens"] == 1025 + assert cfg["max_rounds"] == 2 + + +def test_sanitize_message_dict_keeps_supported_parts() -> None: + scheduler = _build_scheduler({}) + + raw = { + "role": "assistant", + "content": [ + {"type": "think", "think": "internal reasoning"}, + {"type": "text", "text": "visible answer"}, + {"type": "image_url", "image_url": {"url": "https://x.test/a.png"}}, + {"type": "unknown", "foo": "bar"}, + ], + } + + sanitized = scheduler._sanitize_message_dict(raw) + + assert sanitized is not None + assert sanitized["role"] == "assistant" + assert isinstance(sanitized["content"], list) + content = sanitized["content"] + assert content[0]["type"] == "text" + assert "internal reasoning" in content[0]["text"] + assert any(part.get("type") == "image_url" for part in content) + + +def test_is_idle_enough_respects_threshold() -> None: + now = datetime.now(timezone.utc) + old = now - timedelta(minutes=30) + recent = now - timedelta(minutes=2) + + assert PeriodicContextCompactionScheduler._is_idle_enough(old, 10) is True + assert PeriodicContextCompactionScheduler._is_idle_enough(recent, 10) is False + assert PeriodicContextCompactionScheduler._is_idle_enough(None, 10) is True From 1396836ca79161096f21fdfa3a9ac6d52a22406d Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 08:29:28 -0500 Subject: [PATCH 02/29] feat: add ctxcompact admin commands and scheduler hardening --- .../builtin_commands/commands/__init__.py | 2 + .../commands/context_compaction.py | 99 ++++++++++++++ .../builtin_stars/builtin_commands/main.py | 23 ++++ astrbot/core/context_compaction_scheduler.py | 60 +++++++-- astrbot/core/core_lifecycle.py | 1 + tests/unit/test_context_compaction_command.py | 127 ++++++++++++++++++ .../unit/test_context_compaction_scheduler.py | 107 +++++++++++++++ 7 files changed, 410 insertions(+), 9 deletions(-) create mode 100644 astrbot/builtin_stars/builtin_commands/commands/context_compaction.py create mode 100644 tests/unit/test_context_compaction_command.py diff --git a/astrbot/builtin_stars/builtin_commands/commands/__init__.py b/astrbot/builtin_stars/builtin_commands/commands/__init__.py index 46d255965a..d56f0cae24 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/__init__.py +++ 
b/astrbot/builtin_stars/builtin_commands/commands/__init__.py @@ -2,6 +2,7 @@ from .admin import AdminCommands from .alter_cmd import AlterCmdCommands +from .context_compaction import ContextCompactionCommands from .conversation import ConversationCommands from .help import HelpCommand from .llm import LLMCommands @@ -17,6 +18,7 @@ "AdminCommands", "AlterCmdCommands", "ConversationCommands", + "ContextCompactionCommands", "HelpCommand", "LLMCommands", "PersonaCommands", diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py new file mode 100644 index 0000000000..37c0d6aa81 --- /dev/null +++ b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py @@ -0,0 +1,99 @@ +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler + + +class ContextCompactionCommands: + def __init__(self, context: star.Context) -> None: + self.context = context + + def _get_scheduler(self) -> PeriodicContextCompactionScheduler | None: + scheduler = getattr(self.context, "context_compaction_scheduler", None) + if isinstance(scheduler, PeriodicContextCompactionScheduler): + return scheduler + return None + + async def status(self, event: AstrMessageEvent) -> None: + scheduler = self._get_scheduler() + if not scheduler: + await event.send( + MessageChain().message("定时上下文压缩调度器不可用。"), + ) + return + + status = scheduler.get_status() + cfg = status.get("config", {}) + last = status.get("last_report") or {} + + lines = ["定时上下文压缩状态:"] + lines.append( + f"启用={self._yes_no(bool(cfg.get('enabled', False)))}" + f" | 运行中={self._yes_no(bool(status.get('running', False)))}" + f" | 停止请求={self._yes_no(bool(status.get('stop_requested', False)))}" + ) + lines.append( + f"间隔={cfg.get('interval_minutes', '?')}分钟" + f" | 每轮最多压缩={cfg.get('max_conversations_per_run', '?')}" + f" | 每轮最多扫描={cfg.get('max_scan_per_run', '?')}" + ) + lines.append( + f"触发Token={cfg.get('trigger_tokens', '?')}" + f" | 目标Token={cfg.get('target_tokens', '?')}" + f" | 最大轮次={cfg.get('max_rounds', '?')}" + ) + + if last: + lines.append( + f"最近任务[{last.get('reason', 'unknown')}]" + f" scanned={last.get('scanned', 0)}" + f" compacted={last.get('compacted', 0)}" + f" skipped={last.get('skipped', 0)}" + f" failed={last.get('failed', 0)}" + f" elapsed={last.get('elapsed_sec', 0.0):.2f}s" + ) + else: + lines.append("最近任务:暂无") + + if status.get("last_started_at"): + lines.append(f"最近开始:{status.get('last_started_at')}") + if status.get("last_finished_at"): + lines.append(f"最近结束:{status.get('last_finished_at')}") + if status.get("last_error"): + lines.append(f"最近错误:{status.get('last_error')}") + + await event.send(MessageChain().message("\n".join(lines))) + + async def run(self, event: AstrMessageEvent, limit: int | None = None) -> None: + scheduler = self._get_scheduler() + if not scheduler: + await event.send( + MessageChain().message("定时上下文压缩调度器不可用。"), + ) + return + + if limit is not None and limit < 1: + await event.send(MessageChain().message("limit 必须 >= 1。")) + return + + try: + report = await scheduler.run_once( + reason="manual_command", + max_conversations_override=limit, + ) + except Exception as exc: + await event.send(MessageChain().message(f"触发压缩失败: {exc}")) + return + + msg = ( + "手动触发完成:" + f"scanned={report.get('scanned', 0)} " + f"compacted={report.get('compacted', 0)} " + f"skipped={report.get('skipped', 0)} " + 
f"failed={report.get('failed', 0)} " + f"elapsed={report.get('elapsed_sec', 0.0):.2f}s" + ) + await event.send(MessageChain().message(msg)) + + @staticmethod + def _yes_no(value: bool) -> str: + return "是" if value else "否" diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index fb4a834035..d2d7fbc1b2 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -4,6 +4,7 @@ from .commands import ( AdminCommands, AlterCmdCommands, + ContextCompactionCommands, ConversationCommands, HelpCommand, LLMCommands, @@ -26,6 +27,7 @@ def __init__(self, context: star.Context) -> None: self.plugin_c = PluginCommands(self.context) self.admin_c = AdminCommands(self.context) self.conversation_c = ConversationCommands(self.context) + self.ctxcompact_c = ContextCompactionCommands(self.context) self.provider_c = ProviderCommands(self.context) self.persona_c = PersonaCommands(self.context) self.alter_cmd_c = AlterCmdCommands(self.context) @@ -127,6 +129,27 @@ async def provider( """查看或者切换 LLM Provider""" await self.provider_c.provider(event, idx, idx2) + @filter.permission_type(filter.PermissionType.ADMIN) + @filter.command_group("ctxcompact") + def ctxcompact(self) -> None: + """上下文定时压缩管理""" + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxcompact.command("status") + async def ctxcompact_status(self, event: AstrMessageEvent) -> None: + """查看定时上下文压缩状态""" + await self.ctxcompact_c.status(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxcompact.command("run") + async def ctxcompact_run( + self, + event: AstrMessageEvent, + limit: int | None = None, + ) -> None: + """手动触发一次上下文压缩(可选 limit 覆盖本次压缩会话数)""" + await self.ctxcompact_c.run(event, limit) + @filter.command("reset") async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 7493575a57..87f8a313db 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -70,12 +70,29 @@ def __init__( self._running_lock = asyncio.Lock() self._token_counter = EstimateTokenCounter() self._bootstrapped = False + self._last_report: dict[str, Any] | None = None + self._last_started_at: str | None = None + self._last_finished_at: str | None = None + self._last_error: str | None = None + + def get_status(self) -> dict[str, Any]: + cfg = self._load_config() + return { + "running": self._running_lock.locked(), + "bootstrapped": self._bootstrapped, + "stop_requested": self._stop_event.is_set(), + "config": cfg, + "last_started_at": self._last_started_at, + "last_finished_at": self._last_finished_at, + "last_error": self._last_error, + "last_report": self._last_report, + } async def run(self) -> None: logger.info("[ContextCompact] scheduler started") while not self._stop_event.is_set(): cfg = self._load_config() - wait_seconds = max(5, int(cfg["interval_minutes"])) * 60 + wait_seconds = self._resolve_wait_seconds(cfg) if not cfg["enabled"]: await self._sleep_or_stop(wait_seconds) @@ -105,6 +122,7 @@ async def run(self) -> None: report.get("elapsed_sec", 0.0), ) except Exception as exc: + self._last_error = str(exc) logger.error( "[ContextCompact] scheduler run error: %s", exc, @@ -118,18 +136,23 @@ async def run(self) -> None: async def stop(self) -> None: self._stop_event.set() - async def run_once(self, reason: str = "manual") -> dict[str, Any]: + async def 
run_once( + self, + reason: str = "manual", + max_conversations_override: int | None = None, + ) -> dict[str, Any]: """Run one compaction sweep. Exposed so future admin command/cron endpoints can trigger ad-hoc compaction. """ async with self._running_lock: + self._last_started_at = self._now_iso() cfg = self._load_config() started = time.monotonic() stats = _CompactionStats() if not cfg["enabled"] and reason == "scheduled": - return { + report = { "reason": reason, "scanned": 0, "compacted": 0, @@ -138,8 +161,14 @@ async def run_once(self, reason: str = "manual") -> dict[str, Any]: "elapsed_sec": 0.0, "message": "disabled", } + self._last_report = report + self._last_finished_at = self._now_iso() + self._last_error = None + return report max_to_compact = max(1, int(cfg["max_conversations_per_run"])) + if max_conversations_override is not None: + max_to_compact = max(1, int(max_conversations_override)) max_to_scan = max(max_to_compact, int(cfg["max_scan_per_run"])) scan_page_size = max(10, int(cfg["scan_page_size"])) @@ -180,7 +209,7 @@ async def run_once(self, reason: str = "manual") -> dict[str, Any]: page += 1 elapsed = time.monotonic() - started - return { + report = { "reason": reason, "scanned": stats.scanned, "compacted": stats.compacted, @@ -188,6 +217,10 @@ async def run_once(self, reason: str = "manual") -> dict[str, Any]: "failed": stats.failed, "elapsed_sec": elapsed, } + self._last_report = report + self._last_finished_at = self._now_iso() + self._last_error = None + return report async def _sleep_or_stop(self, seconds: int) -> None: try: @@ -195,6 +228,10 @@ async def _sleep_or_stop(self, seconds: int) -> None: except asyncio.TimeoutError: return + @staticmethod + def _resolve_wait_seconds(cfg: dict[str, Any]) -> int: + return max(1, int(cfg["interval_minutes"])) * 60 + def _load_config(self) -> dict[str, Any]: default_conf = self.config_manager.default_conf provider_settings = default_conf.get("provider_settings", {}) @@ -227,11 +264,12 @@ def _load_config(self) -> dict[str, Any]: cfg["min_idle_minutes"] = self._to_int(cfg.get("min_idle_minutes"), 15, 0) cfg["min_messages"] = self._to_int(cfg.get("min_messages"), 14, 2) cfg["target_tokens"] = self._to_int(cfg.get("target_tokens"), 4096, 512) - cfg["trigger_tokens"] = self._to_int( - cfg.get("trigger_tokens"), - max(int(cfg["target_tokens"] * 1.5), cfg["target_tokens"] + 1), - 512, - ) + trigger_default = max(int(cfg["target_tokens"] * 1.5), cfg["target_tokens"] + 1) + raw_trigger = raw_cfg.get("trigger_tokens") + if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): + cfg["trigger_tokens"] = trigger_default + else: + cfg["trigger_tokens"] = self._to_int(raw_trigger, trigger_default, 512) if cfg["trigger_tokens"] <= cfg["target_tokens"]: cfg["trigger_tokens"] = cfg["target_tokens"] + 1 cfg["max_rounds"] = self._to_int(cfg.get("max_rounds"), 3, 1) @@ -553,3 +591,7 @@ def _to_int(value: Any, default: int, min_value: int) -> int: except Exception: parsed = default return max(parsed, min_value) + + @staticmethod + def _now_iso() -> str: + return datetime.now(timezone.utc).isoformat() diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 6c12719c8e..ab1a3eb25d 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -206,6 +206,7 @@ async def initialize(self) -> None: self.cron_manager, self.subagent_orchestrator, ) + self.star_context.context_compaction_scheduler = self.context_compaction_scheduler # 初始化插件管理器 self.plugin_manager = 
PluginManager(self.star_context, self.astrbot_config) diff --git a/tests/unit/test_context_compaction_command.py b/tests/unit/test_context_compaction_command.py new file mode 100644 index 0000000000..560548f43b --- /dev/null +++ b/tests/unit/test_context_compaction_command.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.builtin_stars.builtin_commands.commands.context_compaction import ( + ContextCompactionCommands, +) +from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler + + +class DummyConfigManager: + def __init__(self, default_conf: dict): + self.default_conf = default_conf + + +def _build_scheduler() -> PeriodicContextCompactionScheduler: + cfg_mgr = DummyConfigManager( + { + "provider_settings": { + "periodic_context_compaction": { + "enabled": True, + "interval_minutes": 3, + "max_conversations_per_run": 2, + "max_scan_per_run": 10, + "target_tokens": 1024, + "trigger_tokens": 2048, + "max_rounds": 2, + } + } + } + ) + return PeriodicContextCompactionScheduler( + config_manager=cfg_mgr, + conversation_manager=SimpleNamespace(), + provider_manager=SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_status_when_scheduler_unavailable() -> None: + command = ContextCompactionCommands(context=SimpleNamespace()) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + assert "不可用" in chain.get_plain_text(with_other_comps_mark=True) + + +@pytest.mark.asyncio +async def test_status_with_runtime_report() -> None: + scheduler = _build_scheduler() + scheduler._last_report = { + "reason": "manual_command", + "scanned": 8, + "compacted": 2, + "skipped": 6, + "failed": 0, + "elapsed_sec": 1.2, + } + scheduler._last_started_at = "2026-03-19T12:00:00+00:00" + scheduler._last_finished_at = "2026-03-19T12:00:01+00:00" + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "定时上下文压缩状态" in text + assert "最近任务[manual_command]" in text + assert "compacted=2" in text + + +@pytest.mark.asyncio +async def test_run_with_invalid_limit() -> None: + scheduler = _build_scheduler() + scheduler.run_once = AsyncMock() + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.run(event, 0) + + scheduler.run_once.assert_not_awaited() + chain = event.send.await_args.args[0] + assert "limit 必须 >= 1" in chain.get_plain_text(with_other_comps_mark=True) + + +@pytest.mark.asyncio +async def test_run_triggers_scheduler_once() -> None: + scheduler = _build_scheduler() + scheduler.run_once = AsyncMock( + return_value={ + "scanned": 12, + "compacted": 3, + "skipped": 8, + "failed": 1, + "elapsed_sec": 2.5, + } + ) + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.run(event, 5) + + scheduler.run_once.assert_awaited_once_with( + reason="manual_command", + max_conversations_override=5, + ) + chain = event.send.await_args.args[0] + text = 
chain.get_plain_text(with_other_comps_mark=True) + assert "手动触发完成" in text + assert "compacted=3" in text diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index c8fab93925..f99191a626 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -3,6 +3,8 @@ from datetime import datetime, timedelta, timezone from types import SimpleNamespace +import pytest + from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler @@ -42,6 +44,111 @@ def test_load_config_normalizes_values() -> None: assert cfg["max_rounds"] == 2 +@pytest.mark.parametrize( + ("raw_enabled", "expected"), + [ + ("true", True), + ("false", False), + ("1", True), + ("0", False), + ("yes", True), + ("no", False), + ("unknown", False), + ], +) +def test_load_config_enabled_bool_parsing(raw_enabled: str, expected: bool) -> None: + scheduler = _build_scheduler( + { + "periodic_context_compaction": { + "enabled": raw_enabled, + } + } + ) + + cfg = scheduler._load_config() + assert cfg["enabled"] is expected + + +@pytest.mark.parametrize( + ("raw_cfg", "expected_interval", "expected_scan_page_size", "expected_min_messages"), + [ + ({"interval_minutes": 0, "scan_page_size": 1, "min_messages": 0}, 1, 10, 2), + ( + {"interval_minutes": -3, "scan_page_size": -5, "min_messages": -1}, + 1, + 10, + 2, + ), + ( + {"interval_minutes": "0", "scan_page_size": "1", "min_messages": "0"}, + 1, + 10, + 2, + ), + ], +) +def test_load_config_clamps_numeric_minimums( + raw_cfg: dict, + expected_interval: int, + expected_scan_page_size: int, + expected_min_messages: int, +) -> None: + scheduler = _build_scheduler({"periodic_context_compaction": raw_cfg}) + cfg = scheduler._load_config() + + assert cfg["interval_minutes"] == expected_interval + assert cfg["scan_page_size"] == expected_scan_page_size + assert cfg["min_messages"] == expected_min_messages + + +@pytest.mark.parametrize( + ("raw_cfg", "expected_target", "expected_trigger"), + [ + ({"target_tokens": 1024}, 1024, 1536), + ({"target_tokens": 1024, "trigger_tokens": None}, 1024, 1536), + ({"target_tokens": 1024, "trigger_tokens": 512}, 1024, 1025), + ({"target_tokens": 1024, "trigger_tokens": "512"}, 1024, 1025), + ({"target_tokens": 1024, "trigger_tokens": 2048}, 1024, 2048), + ({"target_tokens": 10}, 512, 768), + ], +) +def test_load_config_token_threshold_normalization( + raw_cfg: dict, + expected_target: int, + expected_trigger: int, +) -> None: + scheduler = _build_scheduler({"periodic_context_compaction": raw_cfg}) + cfg = scheduler._load_config() + + assert cfg["target_tokens"] == expected_target + assert cfg["trigger_tokens"] == expected_trigger + + +@pytest.mark.parametrize("raw_value", [None, 1, "not-a-dict", []]) +def test_load_config_falls_back_for_non_dict(raw_value) -> None: + scheduler = _build_scheduler({"periodic_context_compaction": raw_value}) + cfg = scheduler._load_config() + + assert cfg == scheduler._DEFAULTS + + +def test_resolve_wait_seconds_uses_normalized_interval() -> None: + cfg = {"interval_minutes": 1} + assert PeriodicContextCompactionScheduler._resolve_wait_seconds(cfg) == 60 + + +def test_get_status_returns_runtime_snapshot() -> None: + scheduler = _build_scheduler( + {"periodic_context_compaction": {"enabled": True, "interval_minutes": 3}} + ) + status = scheduler.get_status() + + assert status["running"] is False + assert status["config"]["enabled"] is True + assert status["config"]["interval_minutes"] == 3 + 
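+ # A freshly constructed scheduler has produced no run report yet.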
assert status["last_report"] is None + + def test_sanitize_message_dict_keeps_supported_parts() -> None: scheduler = _build_scheduler({}) From 7d4d02e6014ed457a8d6440d8cfcbe6f6ce03e50 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 08:37:26 -0500 Subject: [PATCH 03/29] refactor: split context compaction orchestration helpers --- astrbot/core/context_compaction_scheduler.py | 375 ++++++++++++------- 1 file changed, 243 insertions(+), 132 deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 87f8a313db..55d105c67a 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -31,6 +31,20 @@ class _CompactionStats: failed: int = 0 +@dataclass +class _EligibilityResult: + eligible: bool + messages: list[Message] + before_tokens: int + + +@dataclass +class _RoundResult: + messages: list[Message] + changed: bool + rounds: int + + class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. @@ -233,37 +247,54 @@ def _resolve_wait_seconds(cfg: dict[str, Any]) -> int: return max(1, int(cfg["interval_minutes"])) * 60 def _load_config(self) -> dict[str, Any]: - default_conf = self.config_manager.default_conf - provider_settings = default_conf.get("provider_settings", {}) - raw_cfg = provider_settings.get("periodic_context_compaction", {}) - if not isinstance(raw_cfg, dict): - raw_cfg = {} + raw_cfg = self._load_raw_config() cfg = dict(self._DEFAULTS) cfg.update(raw_cfg) # normalize cfg["enabled"] = self._to_bool(cfg.get("enabled"), False) - cfg["interval_minutes"] = self._to_int(cfg.get("interval_minutes"), 30, 1) - cfg["startup_delay_seconds"] = self._to_int( - cfg.get("startup_delay_seconds"), - 120, - 0, - ) - cfg["max_conversations_per_run"] = self._to_int( - cfg.get("max_conversations_per_run"), - 8, - 1, - ) - cfg["max_scan_per_run"] = self._to_int( - cfg.get("max_scan_per_run"), - 120, - 1, - ) - cfg["scan_page_size"] = self._to_int(cfg.get("scan_page_size"), 40, 10) - cfg["min_idle_minutes"] = self._to_int(cfg.get("min_idle_minutes"), 15, 0) - cfg["min_messages"] = self._to_int(cfg.get("min_messages"), 14, 2) - cfg["target_tokens"] = self._to_int(cfg.get("target_tokens"), 4096, 512) + self._normalize_int(cfg, "interval_minutes", 30, 1) + self._normalize_int(cfg, "startup_delay_seconds", 120, 0) + self._normalize_int(cfg, "max_conversations_per_run", 8, 1) + self._normalize_int(cfg, "max_scan_per_run", 120, 1) + self._normalize_int(cfg, "scan_page_size", 40, 10) + self._normalize_int(cfg, "min_idle_minutes", 15, 0) + self._normalize_int(cfg, "min_messages", 14, 2) + self._normalize_int(cfg, "target_tokens", 4096, 512) + self._normalize_trigger_tokens(cfg, raw_cfg) + self._normalize_int(cfg, "max_rounds", 3, 1) + self._normalize_int(cfg, "truncate_turns", 1, 1) + self._normalize_int(cfg, "keep_recent", 6, 0) + cfg["provider_id"] = str(cfg.get("provider_id", "") or "").strip() + cfg["instruction"] = str(cfg.get("instruction", "") or "").strip() + cfg["dry_run"] = self._to_bool(cfg.get("dry_run"), False) + + return cfg + + def _load_raw_config(self) -> dict[str, Any]: + default_conf = self.config_manager.default_conf + provider_settings = default_conf.get("provider_settings", {}) + raw_cfg = provider_settings.get("periodic_context_compaction", {}) + if isinstance(raw_cfg, dict): + return raw_cfg + return {} + + def _normalize_int( + self, + cfg: dict[str, Any], + key: str, + default: int, + 
min_value: int, + ) -> int: + cfg[key] = self._to_int(cfg.get(key), default, min_value) + return cfg[key] + + def _normalize_trigger_tokens( + self, + cfg: dict[str, Any], + raw_cfg: dict[str, Any], + ) -> int: trigger_default = max(int(cfg["target_tokens"] * 1.5), cfg["target_tokens"] + 1) raw_trigger = raw_cfg.get("trigger_tokens") if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): @@ -272,59 +303,97 @@ def _load_config(self) -> dict[str, Any]: cfg["trigger_tokens"] = self._to_int(raw_trigger, trigger_default, 512) if cfg["trigger_tokens"] <= cfg["target_tokens"]: cfg["trigger_tokens"] = cfg["target_tokens"] + 1 - cfg["max_rounds"] = self._to_int(cfg.get("max_rounds"), 3, 1) - cfg["truncate_turns"] = self._to_int(cfg.get("truncate_turns"), 1, 1) - cfg["keep_recent"] = self._to_int(cfg.get("keep_recent"), 6, 0) - cfg["provider_id"] = str(cfg.get("provider_id", "") or "").strip() - cfg["instruction"] = str(cfg.get("instruction", "") or "").strip() - cfg["dry_run"] = self._to_bool(cfg.get("dry_run"), False) - - return cfg + return cfg["trigger_tokens"] async def _compact_one_conversation( self, conv: ConversationV2, cfg: dict[str, Any], ) -> str: + eligibility = self._check_eligibility(conv, cfg) + if not eligibility.eligible: + return "skipped" + + provider = await self._resolve_provider(cfg, conv.user_id) + if not provider: + return "failed" + + round_result = await self._run_compaction_rounds( + messages=eligibility.messages, + provider=provider, + cfg=cfg, + ) + if not round_result.changed: + return "skipped" + + after_tokens = self._token_counter.count_tokens(round_result.messages) + if after_tokens >= eligibility.before_tokens: + return "skipped" + + if cfg["dry_run"]: + self._log_dry_run(conv, eligibility.before_tokens, after_tokens, round_result) + return "compacted" + + persisted = await self._persist_compacted_history( + conv=conv, + compressed=round_result.messages, + after_tokens=after_tokens, + ) + if not persisted: + return "failed" + + self._log_compacted( + conv, + eligibility.before_tokens, + after_tokens, + round_result, + ) + return "compacted" + + def _check_eligibility( + self, + conv: ConversationV2, + cfg: dict[str, Any], + ) -> _EligibilityResult: history = conv.content if not isinstance(history, list) or len(history) < cfg["min_messages"]: - return "skipped" + return _EligibilityResult(eligible=False, messages=[], before_tokens=0) if not self._is_idle_enough(conv.updated_at, cfg["min_idle_minutes"]): - return "skipped" + return _EligibilityResult(eligible=False, messages=[], before_tokens=0) messages = self._parse_history(history) if len(messages) < cfg["min_messages"]: - return "skipped" + return _EligibilityResult(eligible=False, messages=[], before_tokens=0) trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 before_tokens = self._token_counter.count_tokens(messages, trusted_usage) if before_tokens < cfg["trigger_tokens"]: - return "skipped" + return _EligibilityResult(eligible=False, messages=[], before_tokens=0) - provider = await self._resolve_provider(cfg, conv.user_id) - if not provider: - return "failed" + return _EligibilityResult( + eligible=True, + messages=messages, + before_tokens=before_tokens, + ) + async def _run_compaction_rounds( + self, + messages: list[Message], + provider: Provider, + cfg: dict[str, Any], + ) -> _RoundResult: compressed = messages changed = False rounds = 0 + instruction = self._resolve_instruction(cfg) + for _ in range(cfg["max_rounds"]): current_tokens = 
self._token_counter.count_tokens(compressed) if current_tokens <= cfg["target_tokens"]: break - manager = ContextManager( - ContextConfig( - max_context_tokens=cfg["target_tokens"], - enforce_max_turns=-1, - truncate_turns=cfg["truncate_turns"], - llm_compress_keep_recent=cfg["keep_recent"], - llm_compress_instruction=self._resolve_instruction(cfg), - llm_compress_provider=provider, - ) - ) - + manager = self._build_context_manager(cfg, provider, instruction) rounds += 1 next_messages = await manager.process(compressed) if self._messages_equal(compressed, next_messages): @@ -333,24 +402,31 @@ async def _compact_one_conversation( compressed = next_messages changed = True - if not changed: - return "skipped" - - after_tokens = self._token_counter.count_tokens(compressed) - if after_tokens >= before_tokens: - return "skipped" + return _RoundResult(messages=compressed, changed=changed, rounds=rounds) - if cfg["dry_run"]: - logger.info( - "[ContextCompact] dry-run: cid=%s user=%s tokens=%s->%s rounds=%s", - conv.conversation_id, - conv.user_id, - before_tokens, - after_tokens, - rounds, + @staticmethod + def _build_context_manager( + cfg: dict[str, Any], + provider: Provider, + instruction: str, + ) -> ContextManager: + return ContextManager( + ContextConfig( + max_context_tokens=cfg["target_tokens"], + enforce_max_turns=-1, + truncate_turns=cfg["truncate_turns"], + llm_compress_keep_recent=cfg["keep_recent"], + llm_compress_instruction=instruction, + llm_compress_provider=provider, ) - return "compacted" + ) + async def _persist_compacted_history( + self, + conv: ConversationV2, + compressed: list[Message], + after_tokens: int, + ) -> bool: try: await self.conversation_manager.update_conversation( unified_msg_origin=conv.user_id, @@ -366,17 +442,40 @@ async def _compact_one_conversation( exc, exc_info=True, ) - return "failed" + return False + return True + @staticmethod + def _log_dry_run( + conv: ConversationV2, + before_tokens: int, + after_tokens: int, + round_result: _RoundResult, + ) -> None: + logger.info( + "[ContextCompact] dry-run: cid=%s user=%s tokens=%s->%s rounds=%s", + conv.conversation_id, + conv.user_id, + before_tokens, + after_tokens, + round_result.rounds, + ) + + @staticmethod + def _log_compacted( + conv: ConversationV2, + before_tokens: int, + after_tokens: int, + round_result: _RoundResult, + ) -> None: logger.info( "[ContextCompact] compacted cid=%s user=%s tokens=%s->%s rounds=%s", conv.conversation_id, conv.user_id, before_tokens, after_tokens, - rounds, + round_result.rounds, ) - return "compacted" async def _resolve_provider( self, @@ -483,69 +582,7 @@ def _sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: return content if isinstance(content, list): - parts: list[dict[str, Any]] = [] - fallback_texts: list[str] = [] - - for part in content: - if isinstance(part, str): - if part.strip(): - fallback_texts.append(part) - continue - if not isinstance(part, dict): - txt = self._safe_json(part) - if txt: - fallback_texts.append(txt) - continue - - part_type = str(part.get("type", "")).strip() - if part_type == "text": - text_val = part.get("text") - if text_val is not None: - parts.append({"type": "text", "text": str(text_val)}) - continue - if part_type == "image_url": - image_obj = part.get("image_url") - if isinstance(image_obj, dict) and image_obj.get("url"): - image_part: dict[str, Any] = { - "type": "image_url", - "image_url": {"url": str(image_obj.get("url"))}, - } - if image_obj.get("id"): - image_part["image_url"]["id"] = 
str(image_obj.get("id")) - parts.append(image_part) - continue - if part_type == "audio_url": - audio_obj = part.get("audio_url") - if isinstance(audio_obj, dict) and audio_obj.get("url"): - audio_part: dict[str, Any] = { - "type": "audio_url", - "audio_url": {"url": str(audio_obj.get("url"))}, - } - if audio_obj.get("id"): - audio_part["audio_url"]["id"] = str(audio_obj.get("id")) - parts.append(audio_part) - continue - - if part_type == "think": - think = part.get("think") - if think: - fallback_texts.append(str(think)) - continue - - raw_text = part.get("text") or part.get("content") - if raw_text: - fallback_texts.append(str(raw_text)) - else: - dumped = self._safe_json(part) - if dumped: - fallback_texts.append(dumped) - - if fallback_texts: - parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) - - if parts: - return parts - return "" + return self._sanitize_list_content(content) if content is None: if role == "assistant": @@ -555,6 +592,80 @@ def _sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: dumped = self._safe_json(content) return dumped if dumped is not None else str(content) + def _sanitize_list_content(self, content: list[Any]) -> str | list[dict]: + parts: list[dict[str, Any]] = [] + fallback_texts: list[str] = [] + + for part in content: + if isinstance(part, str): + if part.strip(): + fallback_texts.append(part) + continue + if not isinstance(part, dict): + txt = self._safe_json(part) + if txt: + fallback_texts.append(txt) + continue + self._sanitize_content_part(part, parts, fallback_texts) + + if fallback_texts: + parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) + + if parts: + return parts + return "" + + def _sanitize_content_part( + self, + part: dict[str, Any], + parts: list[dict[str, Any]], + fallback_texts: list[str], + ) -> None: + part_type = str(part.get("type", "")).strip() + if part_type == "text": + text_val = part.get("text") + if text_val is not None: + parts.append({"type": "text", "text": str(text_val)}) + return + + if part_type == "image_url": + image_obj = part.get("image_url") + if isinstance(image_obj, dict) and image_obj.get("url"): + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": {"url": str(image_obj.get("url"))}, + } + if image_obj.get("id"): + image_part["image_url"]["id"] = str(image_obj.get("id")) + parts.append(image_part) + return + + if part_type == "audio_url": + audio_obj = part.get("audio_url") + if isinstance(audio_obj, dict) and audio_obj.get("url"): + audio_part: dict[str, Any] = { + "type": "audio_url", + "audio_url": {"url": str(audio_obj.get("url"))}, + } + if audio_obj.get("id"): + audio_part["audio_url"]["id"] = str(audio_obj.get("id")) + parts.append(audio_part) + return + + if part_type == "think": + think = part.get("think") + if think: + fallback_texts.append(str(think)) + return + + raw_text = part.get("text") or part.get("content") + if raw_text: + fallback_texts.append(str(raw_text)) + else: + dumped = self._safe_json(part) + if dumped: + fallback_texts.append(dumped) + @staticmethod def _messages_equal(a: list[Message], b: list[Message]) -> bool: if len(a) != len(b): From 3236a17b0f744134fe441e3859fbdef55978474d Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 09:20:48 -0500 Subject: [PATCH 04/29] fix: refine compaction run loop and dry-run outcome --- astrbot/core/context_compaction_scheduler.py | 381 ++++++++++-------- .../unit/test_context_compaction_scheduler.py | 39 ++ 2 files changed, 252 insertions(+), 168 
deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 55d105c67a..5fd839fb46 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -3,7 +3,7 @@ import asyncio import json import time -from collections.abc import Iterable +from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any @@ -45,6 +45,153 @@ class _RoundResult: rounds: int +class _MessageHistoryParser: + def parse(self, history: Iterable[Any]) -> list[Message]: + parsed: list[Message] = [] + for item in history: + if not isinstance(item, dict): + continue + + try: + parsed.append(Message.model_validate(item)) + continue + except Exception: + pass + + fallback = self.sanitize_message_dict(item) + if not fallback: + continue + try: + parsed.append(Message.model_validate(fallback)) + except Exception: + continue + + return parsed + + def sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: + role = str(item.get("role", "")).strip().lower() + if role not in {"system", "user", "assistant", "tool"}: + return None + + result: dict[str, Any] = {"role": role} + + if role == "assistant" and isinstance(item.get("tool_calls"), list): + result["tool_calls"] = item["tool_calls"] + + if role == "tool" and item.get("tool_call_id"): + result["tool_call_id"] = str(item.get("tool_call_id")) + + content = item.get("content") + if content is None and role == "assistant" and result.get("tool_calls"): + result["content"] = None + return result + + result["content"] = self.sanitize_content(content, role) + + if result["content"] is None and not ( + role == "assistant" and result.get("tool_calls") + ): + return None + + return result + + def sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: + if isinstance(content, str): + return content + + if isinstance(content, list): + return self.sanitize_list_content(content) + + if content is None: + if role == "assistant": + return None + return "" + + dumped = self.safe_json(content) + return dumped if dumped is not None else str(content) + + def sanitize_list_content(self, content: list[Any]) -> str | list[dict]: + parts: list[dict[str, Any]] = [] + fallback_texts: list[str] = [] + + for part in content: + if isinstance(part, str): + if part.strip(): + fallback_texts.append(part) + continue + if not isinstance(part, dict): + txt = self.safe_json(part) + if txt: + fallback_texts.append(txt) + continue + self.sanitize_content_part(part, parts, fallback_texts) + + if fallback_texts: + parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) + + if parts: + return parts + return "" + + def sanitize_content_part( + self, + part: dict[str, Any], + parts: list[dict[str, Any]], + fallback_texts: list[str], + ) -> None: + part_type = str(part.get("type", "")).strip() + if part_type == "text": + text_val = part.get("text") + if text_val is not None: + parts.append({"type": "text", "text": str(text_val)}) + return + + if part_type == "image_url": + image_obj = part.get("image_url") + if isinstance(image_obj, dict) and image_obj.get("url"): + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": {"url": str(image_obj.get("url"))}, + } + if image_obj.get("id"): + image_part["image_url"]["id"] = str(image_obj.get("id")) + parts.append(image_part) + return + + if part_type == "audio_url": + audio_obj = part.get("audio_url") 
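+            # same contract as the image branch above: the part survives only
+            # with a non-empty url, and an optional id is carried over as a string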
+ if isinstance(audio_obj, dict) and audio_obj.get("url"): + audio_part: dict[str, Any] = { + "type": "audio_url", + "audio_url": {"url": str(audio_obj.get("url"))}, + } + if audio_obj.get("id"): + audio_part["audio_url"]["id"] = str(audio_obj.get("id")) + parts.append(audio_part) + return + + if part_type == "think": + think = part.get("think") + if think: + fallback_texts.append(str(think)) + return + + raw_text = part.get("text") or part.get("content") + if raw_text: + fallback_texts.append(str(raw_text)) + else: + dumped = self.safe_json(part) + if dumped: + fallback_texts.append(dumped) + + @staticmethod + def safe_json(value: Any) -> str | None: + try: + return json.dumps(value, ensure_ascii=False, default=str) + except Exception: + return None + + class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. @@ -83,6 +230,7 @@ def __init__( self._stop_event = asyncio.Event() self._running_lock = asyncio.Lock() self._token_counter = EstimateTokenCounter() + self._history_parser = _MessageHistoryParser() self._bootstrapped = False self._last_report: dict[str, Any] | None = None self._last_started_at: str | None = None @@ -180,47 +328,22 @@ async def run_once( self._last_error = None return report - max_to_compact = max(1, int(cfg["max_conversations_per_run"])) - if max_conversations_override is not None: - max_to_compact = max(1, int(max_conversations_override)) - max_to_scan = max(max_to_compact, int(cfg["max_scan_per_run"])) - scan_page_size = max(10, int(cfg["scan_page_size"])) - - page = 1 - while ( - not self._stop_event.is_set() - and stats.scanned < max_to_scan - and stats.compacted < max_to_compact - ): - conversations, total = ( - await self.conversation_manager.db.get_filtered_conversations( - page=page, - page_size=scan_page_size, - ) - ) - if not conversations: - break - - for conv in conversations: - if ( - self._stop_event.is_set() - or stats.scanned >= max_to_scan - or stats.compacted >= max_to_compact - ): - break - - stats.scanned += 1 - outcome = await self._compact_one_conversation(conv, cfg) - if outcome == "compacted": - stats.compacted += 1 - elif outcome == "skipped": - stats.skipped += 1 - else: - stats.failed += 1 + max_to_compact, max_to_scan, scan_page_size = self._resolve_run_limits( + cfg, + max_conversations_override, + ) - if page * scan_page_size >= total: + async for conv in self._iter_candidate_conversations(scan_page_size): + if ( + self._stop_event.is_set() + or stats.scanned >= max_to_scan + or stats.compacted >= max_to_compact + ): break - page += 1 + + stats.scanned += 1 + outcome = await self._compact_one_conversation(conv, cfg) + self._record_outcome(stats, outcome) elapsed = time.monotonic() - started report = { @@ -236,6 +359,49 @@ async def run_once( self._last_error = None return report + @staticmethod + def _resolve_run_limits( + cfg: dict[str, Any], + max_conversations_override: int | None, + ) -> tuple[int, int, int]: + max_to_compact = max(1, int(cfg["max_conversations_per_run"])) + if max_conversations_override is not None: + max_to_compact = max(1, int(max_conversations_override)) + max_to_scan = max(max_to_compact, int(cfg["max_scan_per_run"])) + scan_page_size = max(10, int(cfg["scan_page_size"])) + return max_to_compact, max_to_scan, scan_page_size + + async def _iter_candidate_conversations( + self, + scan_page_size: int, + ) -> AsyncIterator[ConversationV2]: + page = 1 + while not self._stop_event.is_set(): + conversations, total = await 
self.conversation_manager.db.get_filtered_conversations( + page=page, + page_size=scan_page_size, + ) + if not conversations: + break + + for conv in conversations: + if self._stop_event.is_set(): + return + yield conv + + if page * scan_page_size >= total: + break + page += 1 + + @staticmethod + def _record_outcome(stats: _CompactionStats, outcome: str) -> None: + if outcome == "compacted": + stats.compacted += 1 + elif outcome == "skipped": + stats.skipped += 1 + else: + stats.failed += 1 + async def _sleep_or_stop(self, seconds: int) -> None: try: await asyncio.wait_for(self._stop_event.wait(), timeout=seconds) @@ -332,7 +498,7 @@ async def _compact_one_conversation( if cfg["dry_run"]: self._log_dry_run(conv, eligibility.before_tokens, after_tokens, round_result) - return "compacted" + return "skipped" persisted = await self._persist_compacted_history( conv=conv, @@ -529,91 +695,16 @@ def _is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: return (now - at).total_seconds() >= (min_idle_minutes * 60) def _parse_history(self, history: Iterable[Any]) -> list[Message]: - parsed: list[Message] = [] - for item in history: - if not isinstance(item, dict): - continue - - try: - parsed.append(Message.model_validate(item)) - continue - except Exception: - pass - - fallback = self._sanitize_message_dict(item) - if not fallback: - continue - try: - parsed.append(Message.model_validate(fallback)) - except Exception: - continue - - return parsed + return self._history_parser.parse(history) def _sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: - role = str(item.get("role", "")).strip().lower() - if role not in {"system", "user", "assistant", "tool"}: - return None - - result: dict[str, Any] = {"role": role} - - if role == "assistant" and isinstance(item.get("tool_calls"), list): - result["tool_calls"] = item["tool_calls"] - - if role == "tool" and item.get("tool_call_id"): - result["tool_call_id"] = str(item.get("tool_call_id")) - - content = item.get("content") - if content is None and role == "assistant" and result.get("tool_calls"): - result["content"] = None - return result - - result["content"] = self._sanitize_content(content, role) - - if result["content"] is None and not ( - role == "assistant" and result.get("tool_calls") - ): - return None - - return result + return self._history_parser.sanitize_message_dict(item) def _sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: - if isinstance(content, str): - return content - - if isinstance(content, list): - return self._sanitize_list_content(content) - - if content is None: - if role == "assistant": - return None - return "" - - dumped = self._safe_json(content) - return dumped if dumped is not None else str(content) + return self._history_parser.sanitize_content(content, role) def _sanitize_list_content(self, content: list[Any]) -> str | list[dict]: - parts: list[dict[str, Any]] = [] - fallback_texts: list[str] = [] - - for part in content: - if isinstance(part, str): - if part.strip(): - fallback_texts.append(part) - continue - if not isinstance(part, dict): - txt = self._safe_json(part) - if txt: - fallback_texts.append(txt) - continue - self._sanitize_content_part(part, parts, fallback_texts) - - if fallback_texts: - parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) - - if parts: - return parts - return "" + return self._history_parser.sanitize_list_content(content) def _sanitize_content_part( self, @@ -621,50 +712,7 @@ def _sanitize_content_part( 
parts: list[dict[str, Any]], fallback_texts: list[str], ) -> None: - part_type = str(part.get("type", "")).strip() - if part_type == "text": - text_val = part.get("text") - if text_val is not None: - parts.append({"type": "text", "text": str(text_val)}) - return - - if part_type == "image_url": - image_obj = part.get("image_url") - if isinstance(image_obj, dict) and image_obj.get("url"): - image_part: dict[str, Any] = { - "type": "image_url", - "image_url": {"url": str(image_obj.get("url"))}, - } - if image_obj.get("id"): - image_part["image_url"]["id"] = str(image_obj.get("id")) - parts.append(image_part) - return - - if part_type == "audio_url": - audio_obj = part.get("audio_url") - if isinstance(audio_obj, dict) and audio_obj.get("url"): - audio_part: dict[str, Any] = { - "type": "audio_url", - "audio_url": {"url": str(audio_obj.get("url"))}, - } - if audio_obj.get("id"): - audio_part["audio_url"]["id"] = str(audio_obj.get("id")) - parts.append(audio_part) - return - - if part_type == "think": - think = part.get("think") - if think: - fallback_texts.append(str(think)) - return - - raw_text = part.get("text") or part.get("content") - if raw_text: - fallback_texts.append(str(raw_text)) - else: - dumped = self._safe_json(part) - if dumped: - fallback_texts.append(dumped) + self._history_parser.sanitize_content_part(part, parts, fallback_texts) @staticmethod def _messages_equal(a: list[Message], b: list[Message]) -> bool: @@ -676,10 +724,7 @@ def _messages_equal(a: list[Message], b: list[Message]) -> bool: @staticmethod def _safe_json(value: Any) -> str | None: - try: - return json.dumps(value, ensure_ascii=False, default=str) - except Exception: - return None + return _MessageHistoryParser.safe_json(value) @staticmethod def _to_bool(value: Any, default: bool) -> bool: diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index f99191a626..1e4a9c9533 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -2,9 +2,11 @@ from datetime import datetime, timedelta, timezone from types import SimpleNamespace +from unittest.mock import AsyncMock import pytest +from astrbot.core.agent.message import Message from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler @@ -181,3 +183,40 @@ def test_is_idle_enough_respects_threshold() -> None: assert PeriodicContextCompactionScheduler._is_idle_enough(old, 10) is True assert PeriodicContextCompactionScheduler._is_idle_enough(recent, 10) is False assert PeriodicContextCompactionScheduler._is_idle_enough(None, 10) is True + + +@pytest.mark.asyncio +async def test_compact_one_conversation_dry_run_reports_skipped() -> None: + scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}}) + cfg = scheduler._load_config() + cfg["dry_run"] = True + + conv = SimpleNamespace( + conversation_id="conv-1", + user_id="user-1", + content=[], + token_usage=0, + updated_at=None, + ) + scheduler._check_eligibility = lambda _conv, _cfg: SimpleNamespace( # type: ignore[method-assign] + eligible=True, + messages=[Message(role="user", content="before")], + before_tokens=100, + ) + scheduler._resolve_provider = AsyncMock(return_value=object()) # type: ignore[method-assign] + scheduler._run_compaction_rounds = AsyncMock( # type: ignore[method-assign] + return_value=SimpleNamespace( + messages=[Message(role="user", content="after")], + changed=True, + rounds=1, + ) + ) + scheduler._token_counter = 
SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 50) + scheduler._persist_compacted_history = AsyncMock( # type: ignore[method-assign] + return_value=True + ) + + outcome = await scheduler._compact_one_conversation(conv, cfg) + + assert outcome == "skipped" + scheduler._persist_compacted_history.assert_not_awaited() From 803c202299036b3f90b2a07ce0cd7238217a4cd9 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 09:37:23 -0500 Subject: [PATCH 05/29] fix: address review follow-ups for compaction scheduler --- astrbot/core/config/default.py | 38 ++++++------ astrbot/core/context_compaction_scheduler.py | 58 ++++++------------- astrbot/core/db/sqlite.py | 13 +++++ tests/unit/test_context_compaction_command.py | 48 +++++++++++++++ .../unit/test_context_compaction_scheduler.py | 2 +- 5 files changed, 99 insertions(+), 60 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5180abf5df..d97d7a7332 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -18,6 +18,25 @@ "line", ] +PERIODIC_CONTEXT_COMPACTION_DEFAULTS = { + "enabled": False, + "interval_minutes": 30, + "startup_delay_seconds": 120, + "max_conversations_per_run": 8, + "max_scan_per_run": 120, + "scan_page_size": 40, + "min_idle_minutes": 15, + "min_messages": 14, + "target_tokens": 4096, + "trigger_tokens": 6144, + "max_rounds": 3, + "truncate_turns": 1, + "keep_recent": 6, + "provider_id": "", + "instruction": "", + "dry_run": False, +} + # 默认配置 DEFAULT_CONFIG = { "config_version": 2, @@ -96,24 +115,7 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", - "periodic_context_compaction": { - "enabled": False, - "interval_minutes": 30, - "startup_delay_seconds": 120, - "max_conversations_per_run": 8, - "max_scan_per_run": 120, - "scan_page_size": 40, - "min_idle_minutes": 15, - "min_messages": 14, - "target_tokens": 4096, - "trigger_tokens": 6144, - "max_rounds": 3, - "truncate_turns": 1, - "keep_recent": 6, - "provider_id": "", - "instruction": "", - "dry_run": False, - }, + "periodic_context_compaction": dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS), "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 5fd839fb46..cb03d8207d 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -5,7 +5,7 @@ import time from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any from astrbot import logger @@ -14,6 +14,7 @@ from astrbot.core.agent.context.token_counter import EstimateTokenCounter from astrbot.core.agent.message import Message from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.config.default import PERIODIC_CONTEXT_COMPACTION_DEFAULTS from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db.po import ConversationV2 from astrbot.core.provider.entities import ProviderType @@ -199,24 +200,7 @@ class PeriodicContextCompactionScheduler: conversation-body compaction to keep long sessions lightweight. 
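    Dry-run mode only logs planned compactions and writes nothing back.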
""" - _DEFAULTS = { - "enabled": False, - "interval_minutes": 30, - "startup_delay_seconds": 120, - "max_conversations_per_run": 8, - "max_scan_per_run": 120, - "scan_page_size": 40, - "min_idle_minutes": 15, - "min_messages": 14, - "target_tokens": 4096, - "trigger_tokens": 6144, - "max_rounds": 3, - "truncate_turns": 1, - "keep_recent": 6, - "provider_id": "", - "instruction": "", - "dry_run": False, - } + _DEFAULTS = dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS) def __init__( self, @@ -333,7 +317,10 @@ async def run_once( max_conversations_override, ) - async for conv in self._iter_candidate_conversations(scan_page_size): + async for conv in self._iter_candidate_conversations( + scan_page_size=scan_page_size, + cfg=cfg, + ): if ( self._stop_event.is_set() or stats.scanned >= max_to_scan @@ -374,12 +361,21 @@ def _resolve_run_limits( async def _iter_candidate_conversations( self, scan_page_size: int, + cfg: dict[str, Any], ) -> AsyncIterator[ConversationV2]: + updated_before: datetime | None = None + if cfg["min_idle_minutes"] > 0: + updated_before = datetime.now(timezone.utc) - timedelta( + minutes=int(cfg["min_idle_minutes"]), + ) + page = 1 while not self._stop_event.is_set(): conversations, total = await self.conversation_manager.db.get_filtered_conversations( page=page, page_size=scan_page_size, + updated_before=updated_before, + min_messages=cfg["min_messages"], ) if not conversations: break @@ -528,7 +524,7 @@ def _check_eligibility( if not self._is_idle_enough(conv.updated_at, cfg["min_idle_minutes"]): return _EligibilityResult(eligible=False, messages=[], before_tokens=0) - messages = self._parse_history(history) + messages = self._history_parser.parse(history) if len(messages) < cfg["min_messages"]: return _EligibilityResult(eligible=False, messages=[], before_tokens=0) @@ -694,26 +690,6 @@ def _is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: at = at.replace(tzinfo=timezone.utc) return (now - at).total_seconds() >= (min_idle_minutes * 60) - def _parse_history(self, history: Iterable[Any]) -> list[Message]: - return self._history_parser.parse(history) - - def _sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: - return self._history_parser.sanitize_message_dict(item) - - def _sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: - return self._history_parser.sanitize_content(content, role) - - def _sanitize_list_content(self, content: list[Any]) -> str | list[dict]: - return self._history_parser.sanitize_list_content(content) - - def _sanitize_content_part( - self, - part: dict[str, Any], - parts: list[dict[str, Any]], - fallback_texts: list[str], - ) -> None: - self._history_parser.sanitize_content_part(part, parts, fallback_texts) - @staticmethod def _messages_equal(a: list[Message], b: list[Message]) -> bool: if len(a) != len(b): diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index c8e50909d5..b19f4d47f3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -243,6 +243,19 @@ async def get_filtered_conversations( base_query = base_query.where( col(ConversationV2.platform_id).in_(kwargs["platforms"]), ) + if "updated_before" in kwargs and kwargs["updated_before"] is not None: + updated_before = kwargs["updated_before"] + base_query = base_query.where( + or_( + col(ConversationV2.updated_at).is_(None), + col(ConversationV2.updated_at) <= updated_before, + ), + ) + if "min_messages" in kwargs and kwargs["min_messages"]: + min_messages = max(1, 
int(kwargs["min_messages"])) + base_query = base_query.where( + func.json_array_length(col(ConversationV2.content)) >= min_messages, + ) # Get total count matching the filters count_query = select(func.count()).select_from(base_query.subquery()) diff --git a/tests/unit/test_context_compaction_command.py b/tests/unit/test_context_compaction_command.py index 560548f43b..04eee1beb5 100644 --- a/tests/unit/test_context_compaction_command.py +++ b/tests/unit/test_context_compaction_command.py @@ -80,6 +80,33 @@ async def test_status_with_runtime_report() -> None: assert "compacted=2" in text +@pytest.mark.asyncio +async def test_status_includes_last_error_line() -> None: + scheduler = _build_scheduler() + scheduler._last_report = { + "reason": "manual_command", + "scanned": 1, + "compacted": 0, + "skipped": 1, + "failed": 0, + "elapsed_sec": 0.3, + } + scheduler._last_error = "mock error" + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "最近任务[manual_command]" in text + assert "最近错误:mock error" in text + + @pytest.mark.asyncio async def test_run_with_invalid_limit() -> None: scheduler = _build_scheduler() @@ -125,3 +152,24 @@ async def test_run_triggers_scheduler_once() -> None: text = chain.get_plain_text(with_other_comps_mark=True) assert "手动触发完成" in text assert "compacted=3" in text + + +@pytest.mark.asyncio +async def test_run_reports_error_when_scheduler_raises() -> None: + scheduler = _build_scheduler() + scheduler.run_once = AsyncMock(side_effect=RuntimeError("mock boom")) + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.run(event, 2) + + scheduler.run_once.assert_awaited_once_with( + reason="manual_command", + max_conversations_override=2, + ) + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert text.startswith("触发压缩失败:") diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index 1e4a9c9533..9525710238 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -164,7 +164,7 @@ def test_sanitize_message_dict_keeps_supported_parts() -> None: ], } - sanitized = scheduler._sanitize_message_dict(raw) + sanitized = scheduler._history_parser.sanitize_message_dict(raw) assert sanitized is not None assert sanitized["role"] == "assistant" From cb023d35ca6f5a1b2299ac7634f3117f69d122ea Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 09:55:01 -0500 Subject: [PATCH 06/29] refactor: type compaction config and extend status tests --- astrbot/core/context_compaction_scheduler.py | 178 ++++++++++-------- tests/unit/test_context_compaction_command.py | 21 +++ .../unit/test_context_compaction_scheduler.py | 39 ++-- 3 files changed, 145 insertions(+), 93 deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index cb03d8207d..35adec114a 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -4,7 +4,7 @@ import json import time from collections.abc import AsyncIterator, Iterable -from dataclasses import 
dataclass +from dataclasses import asdict, dataclass from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any @@ -46,6 +46,26 @@ class _RoundResult: rounds: int +@dataclass(frozen=True) +class CompactionConfig: + enabled: bool + interval_minutes: int + startup_delay_seconds: int + max_conversations_per_run: int + max_scan_per_run: int + scan_page_size: int + min_idle_minutes: int + min_messages: int + target_tokens: int + trigger_tokens: int + max_rounds: int + truncate_turns: int + keep_recent: int + provider_id: str + instruction: str + dry_run: bool + + class _MessageHistoryParser: def parse(self, history: Iterable[Any]) -> list[Message]: parsed: list[Message] = [] @@ -227,7 +247,7 @@ def get_status(self) -> dict[str, Any]: "running": self._running_lock.locked(), "bootstrapped": self._bootstrapped, "stop_requested": self._stop_event.is_set(), - "config": cfg, + "config": asdict(cfg), "last_started_at": self._last_started_at, "last_finished_at": self._last_finished_at, "last_error": self._last_error, @@ -240,13 +260,13 @@ async def run(self) -> None: cfg = self._load_config() wait_seconds = self._resolve_wait_seconds(cfg) - if not cfg["enabled"]: + if not cfg.enabled: await self._sleep_or_stop(wait_seconds) continue if not self._bootstrapped: self._bootstrapped = True - startup_delay = max(0, int(cfg["startup_delay_seconds"])) + startup_delay = max(0, int(cfg.startup_delay_seconds)) if startup_delay > 0: logger.info( "[ContextCompact] startup delay: %ss before first run", @@ -297,7 +317,7 @@ async def run_once( started = time.monotonic() stats = _CompactionStats() - if not cfg["enabled"] and reason == "scheduled": + if not cfg.enabled and reason == "scheduled": report = { "reason": reason, "scanned": 0, @@ -348,25 +368,25 @@ async def run_once( @staticmethod def _resolve_run_limits( - cfg: dict[str, Any], + cfg: CompactionConfig, max_conversations_override: int | None, ) -> tuple[int, int, int]: - max_to_compact = max(1, int(cfg["max_conversations_per_run"])) + max_to_compact = max(1, int(cfg.max_conversations_per_run)) if max_conversations_override is not None: max_to_compact = max(1, int(max_conversations_override)) - max_to_scan = max(max_to_compact, int(cfg["max_scan_per_run"])) - scan_page_size = max(10, int(cfg["scan_page_size"])) + max_to_scan = max(max_to_compact, int(cfg.max_scan_per_run)) + scan_page_size = max(10, int(cfg.scan_page_size)) return max_to_compact, max_to_scan, scan_page_size async def _iter_candidate_conversations( self, scan_page_size: int, - cfg: dict[str, Any], + cfg: CompactionConfig, ) -> AsyncIterator[ConversationV2]: updated_before: datetime | None = None - if cfg["min_idle_minutes"] > 0: + if cfg.min_idle_minutes > 0: updated_before = datetime.now(timezone.utc) - timedelta( - minutes=int(cfg["min_idle_minutes"]), + minutes=int(cfg.min_idle_minutes), ) page = 1 @@ -375,7 +395,7 @@ async def _iter_candidate_conversations( page=page, page_size=scan_page_size, updated_before=updated_before, - min_messages=cfg["min_messages"], + min_messages=cfg.min_messages, ) if not conversations: break @@ -405,34 +425,52 @@ async def _sleep_or_stop(self, seconds: int) -> None: return @staticmethod - def _resolve_wait_seconds(cfg: dict[str, Any]) -> int: - return max(1, int(cfg["interval_minutes"])) * 60 + def _resolve_wait_seconds(cfg: CompactionConfig) -> int: + return max(1, int(cfg.interval_minutes)) * 60 - def _load_config(self) -> dict[str, Any]: + def _load_config(self) -> CompactionConfig: raw_cfg = self._load_raw_config() cfg = 
dict(self._DEFAULTS) cfg.update(raw_cfg) - # normalize - cfg["enabled"] = self._to_bool(cfg.get("enabled"), False) - self._normalize_int(cfg, "interval_minutes", 30, 1) - self._normalize_int(cfg, "startup_delay_seconds", 120, 0) - self._normalize_int(cfg, "max_conversations_per_run", 8, 1) - self._normalize_int(cfg, "max_scan_per_run", 120, 1) - self._normalize_int(cfg, "scan_page_size", 40, 10) - self._normalize_int(cfg, "min_idle_minutes", 15, 0) - self._normalize_int(cfg, "min_messages", 14, 2) - self._normalize_int(cfg, "target_tokens", 4096, 512) - self._normalize_trigger_tokens(cfg, raw_cfg) - self._normalize_int(cfg, "max_rounds", 3, 1) - self._normalize_int(cfg, "truncate_turns", 1, 1) - self._normalize_int(cfg, "keep_recent", 6, 0) - cfg["provider_id"] = str(cfg.get("provider_id", "") or "").strip() - cfg["instruction"] = str(cfg.get("instruction", "") or "").strip() - cfg["dry_run"] = self._to_bool(cfg.get("dry_run"), False) - - return cfg + enabled = self._to_bool(cfg.get("enabled"), False) + interval_minutes = self._to_int(cfg.get("interval_minutes"), 30, 1) + startup_delay_seconds = self._to_int(cfg.get("startup_delay_seconds"), 120, 0) + max_conversations_per_run = self._to_int( + cfg.get("max_conversations_per_run"), 8, 1 + ) + max_scan_per_run = self._to_int(cfg.get("max_scan_per_run"), 120, 1) + scan_page_size = self._to_int(cfg.get("scan_page_size"), 40, 10) + min_idle_minutes = self._to_int(cfg.get("min_idle_minutes"), 15, 0) + min_messages = self._to_int(cfg.get("min_messages"), 14, 2) + target_tokens = self._to_int(cfg.get("target_tokens"), 4096, 512) + trigger_tokens = self._resolve_trigger_tokens(target_tokens, raw_cfg) + max_rounds = self._to_int(cfg.get("max_rounds"), 3, 1) + truncate_turns = self._to_int(cfg.get("truncate_turns"), 1, 1) + keep_recent = self._to_int(cfg.get("keep_recent"), 6, 0) + provider_id = str(cfg.get("provider_id", "") or "").strip() + instruction = str(cfg.get("instruction", "") or "").strip() + dry_run = self._to_bool(cfg.get("dry_run"), False) + + return CompactionConfig( + enabled=enabled, + interval_minutes=interval_minutes, + startup_delay_seconds=startup_delay_seconds, + max_conversations_per_run=max_conversations_per_run, + max_scan_per_run=max_scan_per_run, + scan_page_size=scan_page_size, + min_idle_minutes=min_idle_minutes, + min_messages=min_messages, + target_tokens=target_tokens, + trigger_tokens=trigger_tokens, + max_rounds=max_rounds, + truncate_turns=truncate_turns, + keep_recent=keep_recent, + provider_id=provider_id, + instruction=instruction, + dry_run=dry_run, + ) def _load_raw_config(self) -> dict[str, Any]: default_conf = self.config_manager.default_conf @@ -442,35 +480,21 @@ def _load_raw_config(self) -> dict[str, Any]: return raw_cfg return {} - def _normalize_int( - self, - cfg: dict[str, Any], - key: str, - default: int, - min_value: int, - ) -> int: - cfg[key] = self._to_int(cfg.get(key), default, min_value) - return cfg[key] - - def _normalize_trigger_tokens( - self, - cfg: dict[str, Any], - raw_cfg: dict[str, Any], - ) -> int: - trigger_default = max(int(cfg["target_tokens"] * 1.5), cfg["target_tokens"] + 1) + def _resolve_trigger_tokens(self, target_tokens: int, raw_cfg: dict[str, Any]) -> int: + trigger_default = max(int(target_tokens * 1.5), target_tokens + 1) raw_trigger = raw_cfg.get("trigger_tokens") if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): - cfg["trigger_tokens"] = trigger_default + trigger_tokens = trigger_default else: - cfg["trigger_tokens"] = 
self._to_int(raw_trigger, trigger_default, 512) - if cfg["trigger_tokens"] <= cfg["target_tokens"]: - cfg["trigger_tokens"] = cfg["target_tokens"] + 1 - return cfg["trigger_tokens"] + trigger_tokens = self._to_int(raw_trigger, trigger_default, 512) + if trigger_tokens <= target_tokens: + trigger_tokens = target_tokens + 1 + return trigger_tokens async def _compact_one_conversation( self, conv: ConversationV2, - cfg: dict[str, Any], + cfg: CompactionConfig, ) -> str: eligibility = self._check_eligibility(conv, cfg) if not eligibility.eligible: @@ -492,7 +516,7 @@ async def _compact_one_conversation( if after_tokens >= eligibility.before_tokens: return "skipped" - if cfg["dry_run"]: + if cfg.dry_run: self._log_dry_run(conv, eligibility.before_tokens, after_tokens, round_result) return "skipped" @@ -515,22 +539,22 @@ async def _compact_one_conversation( def _check_eligibility( self, conv: ConversationV2, - cfg: dict[str, Any], + cfg: CompactionConfig, ) -> _EligibilityResult: history = conv.content - if not isinstance(history, list) or len(history) < cfg["min_messages"]: + if not isinstance(history, list) or len(history) < cfg.min_messages: return _EligibilityResult(eligible=False, messages=[], before_tokens=0) - if not self._is_idle_enough(conv.updated_at, cfg["min_idle_minutes"]): + if not self._is_idle_enough(conv.updated_at, cfg.min_idle_minutes): return _EligibilityResult(eligible=False, messages=[], before_tokens=0) messages = self._history_parser.parse(history) - if len(messages) < cfg["min_messages"]: + if len(messages) < cfg.min_messages: return _EligibilityResult(eligible=False, messages=[], before_tokens=0) trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 before_tokens = self._token_counter.count_tokens(messages, trusted_usage) - if before_tokens < cfg["trigger_tokens"]: + if before_tokens < cfg.trigger_tokens: return _EligibilityResult(eligible=False, messages=[], before_tokens=0) return _EligibilityResult( @@ -543,16 +567,16 @@ async def _run_compaction_rounds( self, messages: list[Message], provider: Provider, - cfg: dict[str, Any], + cfg: CompactionConfig, ) -> _RoundResult: compressed = messages changed = False rounds = 0 instruction = self._resolve_instruction(cfg) - for _ in range(cfg["max_rounds"]): + for _ in range(cfg.max_rounds): current_tokens = self._token_counter.count_tokens(compressed) - if current_tokens <= cfg["target_tokens"]: + if current_tokens <= cfg.target_tokens: break manager = self._build_context_manager(cfg, provider, instruction) @@ -568,16 +592,16 @@ async def _run_compaction_rounds( @staticmethod def _build_context_manager( - cfg: dict[str, Any], + cfg: CompactionConfig, provider: Provider, instruction: str, ) -> ContextManager: return ContextManager( ContextConfig( - max_context_tokens=cfg["target_tokens"], + max_context_tokens=cfg.target_tokens, enforce_max_turns=-1, - truncate_turns=cfg["truncate_turns"], - llm_compress_keep_recent=cfg["keep_recent"], + truncate_turns=cfg.truncate_turns, + llm_compress_keep_recent=cfg.keep_recent, llm_compress_instruction=instruction, llm_compress_provider=provider, ) @@ -641,13 +665,13 @@ def _log_compacted( async def _resolve_provider( self, - cfg: dict[str, Any], + cfg: CompactionConfig, umo: str, ) -> Provider | None: provider = None - if cfg["provider_id"]: - provider = await self.provider_manager.get_provider_by_id(cfg["provider_id"]) + if cfg.provider_id: + provider = await self.provider_manager.get_provider_by_id(cfg.provider_id) else: provider = 
self.provider_manager.get_using_provider( provider_type=ProviderType.CHAT_COMPLETION, @@ -663,14 +687,14 @@ async def _resolve_provider( logger.warning( "[ContextCompact] provider unavailable for umo=%s provider_id=%s", umo, - cfg["provider_id"], + cfg.provider_id, ) return None return provider - def _resolve_instruction(self, cfg: dict[str, Any]) -> str: - if cfg["instruction"]: - return cfg["instruction"] + def _resolve_instruction(self, cfg: CompactionConfig) -> str: + if cfg.instruction: + return cfg.instruction provider_settings = self.config_manager.default_conf.get("provider_settings", {}) base_instruction = provider_settings.get("llm_compress_instruction", "") diff --git a/tests/unit/test_context_compaction_command.py b/tests/unit/test_context_compaction_command.py index 04eee1beb5..13dce3cd81 100644 --- a/tests/unit/test_context_compaction_command.py +++ b/tests/unit/test_context_compaction_command.py @@ -80,6 +80,27 @@ async def test_status_with_runtime_report() -> None: assert "compacted=2" in text +@pytest.mark.asyncio +async def test_status_with_no_report() -> None: + scheduler = _build_scheduler() + scheduler._last_report = None + scheduler._last_error = None + + command = ContextCompactionCommands( + context=SimpleNamespace(context_compaction_scheduler=scheduler) + ) + event = SimpleNamespace(send=AsyncMock()) + + await command.status(event) + + event.send.assert_awaited_once() + chain = event.send.await_args.args[0] + text = chain.get_plain_text(with_other_comps_mark=True) + assert "定时上下文压缩状态" in text + assert "启用=是" in text + assert "最近任务:暂无" in text + + @pytest.mark.asyncio async def test_status_includes_last_error_line() -> None: scheduler = _build_scheduler() diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index 9525710238..e8d6c6e795 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import replace from datetime import datetime, timedelta, timezone from types import SimpleNamespace from unittest.mock import AsyncMock @@ -7,7 +8,10 @@ import pytest from astrbot.core.agent.message import Message -from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler +from astrbot.core.context_compaction_scheduler import ( + CompactionConfig, + PeriodicContextCompactionScheduler, +) class DummyConfigManager: @@ -39,11 +43,11 @@ def test_load_config_normalizes_values() -> None: cfg = scheduler._load_config() - assert cfg["enabled"] is True - assert cfg["interval_minutes"] == 1 - assert cfg["target_tokens"] == 1024 - assert cfg["trigger_tokens"] == 1025 - assert cfg["max_rounds"] == 2 + assert cfg.enabled is True + assert cfg.interval_minutes == 1 + assert cfg.target_tokens == 1024 + assert cfg.trigger_tokens == 1025 + assert cfg.max_rounds == 2 @pytest.mark.parametrize( @@ -68,7 +72,7 @@ def test_load_config_enabled_bool_parsing(raw_enabled: str, expected: bool) -> N ) cfg = scheduler._load_config() - assert cfg["enabled"] is expected + assert cfg.enabled is expected @pytest.mark.parametrize( @@ -98,9 +102,9 @@ def test_load_config_clamps_numeric_minimums( scheduler = _build_scheduler({"periodic_context_compaction": raw_cfg}) cfg = scheduler._load_config() - assert cfg["interval_minutes"] == expected_interval - assert cfg["scan_page_size"] == expected_scan_page_size - assert cfg["min_messages"] == expected_min_messages + assert cfg.interval_minutes == 
expected_interval + assert cfg.scan_page_size == expected_scan_page_size + assert cfg.min_messages == expected_min_messages @pytest.mark.parametrize( @@ -122,8 +126,8 @@ def test_load_config_token_threshold_normalization( scheduler = _build_scheduler({"periodic_context_compaction": raw_cfg}) cfg = scheduler._load_config() - assert cfg["target_tokens"] == expected_target - assert cfg["trigger_tokens"] == expected_trigger + assert cfg.target_tokens == expected_target + assert cfg.trigger_tokens == expected_trigger @pytest.mark.parametrize("raw_value", [None, 1, "not-a-dict", []]) @@ -131,11 +135,15 @@ def test_load_config_falls_back_for_non_dict(raw_value) -> None: scheduler = _build_scheduler({"periodic_context_compaction": raw_value}) cfg = scheduler._load_config() - assert cfg == scheduler._DEFAULTS + expected = CompactionConfig(**scheduler._DEFAULTS) + assert cfg == expected def test_resolve_wait_seconds_uses_normalized_interval() -> None: - cfg = {"interval_minutes": 1} + cfg = replace( + CompactionConfig(**PeriodicContextCompactionScheduler._DEFAULTS), + interval_minutes=1, + ) assert PeriodicContextCompactionScheduler._resolve_wait_seconds(cfg) == 60 @@ -188,8 +196,7 @@ def test_is_idle_enough_respects_threshold() -> None: @pytest.mark.asyncio async def test_compact_one_conversation_dry_run_reports_skipped() -> None: scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}}) - cfg = scheduler._load_config() - cfg["dry_run"] = True + cfg = replace(scheduler._load_config(), dry_run=True) conv = SimpleNamespace( conversation_id="conv-1", From 6e7456607fa7493bf0c25205d1803d39117a5615 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 10:11:39 -0500 Subject: [PATCH 07/29] refactor: extract message history parser from scheduler --- astrbot/core/agent/message_history_parser.py | 154 ++++++++++++++++++ astrbot/core/context_compaction_scheduler.py | 158 +------------------ 2 files changed, 157 insertions(+), 155 deletions(-) create mode 100644 astrbot/core/agent/message_history_parser.py diff --git a/astrbot/core/agent/message_history_parser.py b/astrbot/core/agent/message_history_parser.py new file mode 100644 index 0000000000..9a08fb7b6d --- /dev/null +++ b/astrbot/core/agent/message_history_parser.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import json +from collections.abc import Iterable +from typing import Any + +from astrbot.core.agent.message import Message + + +class MessageHistoryParser: + def parse(self, history: Iterable[Any]) -> list[Message]: + parsed: list[Message] = [] + for item in history: + if not isinstance(item, dict): + continue + + try: + parsed.append(Message.model_validate(item)) + continue + except Exception: + pass + + fallback = self.sanitize_message_dict(item) + if not fallback: + continue + try: + parsed.append(Message.model_validate(fallback)) + except Exception: + continue + + return parsed + + def sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: + role = str(item.get("role", "")).strip().lower() + if role not in {"system", "user", "assistant", "tool"}: + return None + + result: dict[str, Any] = {"role": role} + + if role == "assistant" and isinstance(item.get("tool_calls"), list): + result["tool_calls"] = item["tool_calls"] + + if role == "tool" and item.get("tool_call_id"): + result["tool_call_id"] = str(item.get("tool_call_id")) + + content = item.get("content") + if content is None and role == "assistant" and result.get("tool_calls"): + result["content"] = None + return result + 
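+        # any other content shape is normalized below; a None result is only
+        # kept for assistant turns that still carry tool_calls (checked just below)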
+ result["content"] = self.sanitize_content(content, role) + + if result["content"] is None and not ( + role == "assistant" and result.get("tool_calls") + ): + return None + + return result + + def sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: + if isinstance(content, str): + return content + + if isinstance(content, list): + return self.sanitize_list_content(content) + + if content is None: + if role == "assistant": + return None + return "" + + dumped = self.safe_json(content) + return dumped if dumped is not None else str(content) + + def sanitize_list_content(self, content: list[Any]) -> str | list[dict]: + parts: list[dict[str, Any]] = [] + fallback_texts: list[str] = [] + + for part in content: + if isinstance(part, str): + if part.strip(): + fallback_texts.append(part) + continue + if not isinstance(part, dict): + txt = self.safe_json(part) + if txt: + fallback_texts.append(txt) + continue + self.sanitize_content_part(part, parts, fallback_texts) + + if fallback_texts: + parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) + + if parts: + return parts + return "" + + def sanitize_content_part( + self, + part: dict[str, Any], + parts: list[dict[str, Any]], + fallback_texts: list[str], + ) -> None: + part_type = str(part.get("type", "")).strip() + if part_type == "text": + text_val = part.get("text") + if text_val is not None: + parts.append({"type": "text", "text": str(text_val)}) + return + + if part_type == "image_url": + image_obj = part.get("image_url") + if isinstance(image_obj, dict) and image_obj.get("url"): + image_part: dict[str, Any] = { + "type": "image_url", + "image_url": {"url": str(image_obj.get("url"))}, + } + if image_obj.get("id"): + image_part["image_url"]["id"] = str(image_obj.get("id")) + parts.append(image_part) + return + + if part_type == "audio_url": + audio_obj = part.get("audio_url") + if isinstance(audio_obj, dict) and audio_obj.get("url"): + audio_part: dict[str, Any] = { + "type": "audio_url", + "audio_url": {"url": str(audio_obj.get("url"))}, + } + if audio_obj.get("id"): + audio_part["audio_url"]["id"] = str(audio_obj.get("id")) + parts.append(audio_part) + return + + if part_type == "think": + think = part.get("think") + if think: + fallback_texts.append(str(think)) + return + + raw_text = part.get("text") or part.get("content") + if raw_text: + fallback_texts.append(str(raw_text)) + else: + dumped = self.safe_json(part) + if dumped: + fallback_texts.append(dumped) + + @staticmethod + def safe_json(value: Any) -> str | None: + try: + return json.dumps(value, ensure_ascii=False, default=str) + except Exception: + return None diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 35adec114a..a064f89c63 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -1,9 +1,8 @@ from __future__ import annotations import asyncio -import json import time -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator from dataclasses import asdict, dataclass from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any @@ -13,6 +12,7 @@ from astrbot.core.agent.context.manager import ContextManager from astrbot.core.agent.context.token_counter import EstimateTokenCounter from astrbot.core.agent.message import Message +from astrbot.core.agent.message_history_parser import MessageHistoryParser from astrbot.core.astrbot_config_mgr import 
AstrBotConfigManager from astrbot.core.config.default import PERIODIC_CONTEXT_COMPACTION_DEFAULTS from astrbot.core.conversation_mgr import ConversationManager @@ -65,154 +65,6 @@ class CompactionConfig: instruction: str dry_run: bool - -class _MessageHistoryParser: - def parse(self, history: Iterable[Any]) -> list[Message]: - parsed: list[Message] = [] - for item in history: - if not isinstance(item, dict): - continue - - try: - parsed.append(Message.model_validate(item)) - continue - except Exception: - pass - - fallback = self.sanitize_message_dict(item) - if not fallback: - continue - try: - parsed.append(Message.model_validate(fallback)) - except Exception: - continue - - return parsed - - def sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: - role = str(item.get("role", "")).strip().lower() - if role not in {"system", "user", "assistant", "tool"}: - return None - - result: dict[str, Any] = {"role": role} - - if role == "assistant" and isinstance(item.get("tool_calls"), list): - result["tool_calls"] = item["tool_calls"] - - if role == "tool" and item.get("tool_call_id"): - result["tool_call_id"] = str(item.get("tool_call_id")) - - content = item.get("content") - if content is None and role == "assistant" and result.get("tool_calls"): - result["content"] = None - return result - - result["content"] = self.sanitize_content(content, role) - - if result["content"] is None and not ( - role == "assistant" and result.get("tool_calls") - ): - return None - - return result - - def sanitize_content(self, content: Any, role: str) -> str | list[dict] | None: - if isinstance(content, str): - return content - - if isinstance(content, list): - return self.sanitize_list_content(content) - - if content is None: - if role == "assistant": - return None - return "" - - dumped = self.safe_json(content) - return dumped if dumped is not None else str(content) - - def sanitize_list_content(self, content: list[Any]) -> str | list[dict]: - parts: list[dict[str, Any]] = [] - fallback_texts: list[str] = [] - - for part in content: - if isinstance(part, str): - if part.strip(): - fallback_texts.append(part) - continue - if not isinstance(part, dict): - txt = self.safe_json(part) - if txt: - fallback_texts.append(txt) - continue - self.sanitize_content_part(part, parts, fallback_texts) - - if fallback_texts: - parts.insert(0, {"type": "text", "text": "\n".join(fallback_texts)}) - - if parts: - return parts - return "" - - def sanitize_content_part( - self, - part: dict[str, Any], - parts: list[dict[str, Any]], - fallback_texts: list[str], - ) -> None: - part_type = str(part.get("type", "")).strip() - if part_type == "text": - text_val = part.get("text") - if text_val is not None: - parts.append({"type": "text", "text": str(text_val)}) - return - - if part_type == "image_url": - image_obj = part.get("image_url") - if isinstance(image_obj, dict) and image_obj.get("url"): - image_part: dict[str, Any] = { - "type": "image_url", - "image_url": {"url": str(image_obj.get("url"))}, - } - if image_obj.get("id"): - image_part["image_url"]["id"] = str(image_obj.get("id")) - parts.append(image_part) - return - - if part_type == "audio_url": - audio_obj = part.get("audio_url") - if isinstance(audio_obj, dict) and audio_obj.get("url"): - audio_part: dict[str, Any] = { - "type": "audio_url", - "audio_url": {"url": str(audio_obj.get("url"))}, - } - if audio_obj.get("id"): - audio_part["audio_url"]["id"] = str(audio_obj.get("id")) - parts.append(audio_part) - return - - if part_type == "think": - think = 
part.get("think") - if think: - fallback_texts.append(str(think)) - return - - raw_text = part.get("text") or part.get("content") - if raw_text: - fallback_texts.append(str(raw_text)) - else: - dumped = self.safe_json(part) - if dumped: - fallback_texts.append(dumped) - - @staticmethod - def safe_json(value: Any) -> str | None: - try: - return json.dumps(value, ensure_ascii=False, default=str) - except Exception: - return None - - class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. @@ -234,7 +86,7 @@ def __init__( self._stop_event = asyncio.Event() self._running_lock = asyncio.Lock() self._token_counter = EstimateTokenCounter() - self._history_parser = _MessageHistoryParser() + self._history_parser = MessageHistoryParser() self._bootstrapped = False self._last_report: dict[str, Any] | None = None self._last_started_at: str | None = None @@ -722,10 +574,6 @@ def _messages_equal(a: list[Message], b: list[Message]) -> bool: m.model_dump(exclude_none=True) for m in b ] - @staticmethod - def _safe_json(value: Any) -> str | None: - return _MessageHistoryParser.safe_json(value) - @staticmethod def _to_bool(value: Any, default: bool) -> bool: if isinstance(value, bool): From 5b143060a3510c06866ded9b0fc8cd0bb86092c0 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 12:59:43 -0500 Subject: [PATCH 08/29] refactor: extract compaction config loader helpers --- astrbot/core/context_compaction_scheduler.py | 171 ++++++++++--------- 1 file changed, 88 insertions(+), 83 deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index a064f89c63..1fcc2263e7 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -65,6 +65,90 @@ class CompactionConfig: instruction: str dry_run: bool + +def _to_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + +def _to_int(value: Any, default: int, min_value: int) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + return max(parsed, min_value) + + +def _resolve_trigger_tokens(target_tokens: int, raw_cfg: dict[str, Any]) -> int: + trigger_default = max(int(target_tokens * 1.5), target_tokens + 1) + raw_trigger = raw_cfg.get("trigger_tokens") + if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): + trigger_tokens = trigger_default + else: + trigger_tokens = _to_int(raw_trigger, trigger_default, 512) + if trigger_tokens <= target_tokens: + trigger_tokens = target_tokens + 1 + return trigger_tokens + + +def load_compaction_config( + default_conf: dict[str, Any], + defaults: dict[str, Any] | None = None, +) -> CompactionConfig: + provider_settings = default_conf.get("provider_settings", {}) + raw_cfg = provider_settings.get("periodic_context_compaction", {}) + if not isinstance(raw_cfg, dict): + raw_cfg = {} + + cfg = dict(defaults or PERIODIC_CONTEXT_COMPACTION_DEFAULTS) + cfg.update(raw_cfg) + + enabled = _to_bool(cfg.get("enabled"), False) + interval_minutes = _to_int(cfg.get("interval_minutes"), 30, 1) + startup_delay_seconds = _to_int(cfg.get("startup_delay_seconds"), 120, 0) + max_conversations_per_run = 
_to_int(cfg.get("max_conversations_per_run"), 8, 1) + max_scan_per_run = _to_int(cfg.get("max_scan_per_run"), 120, 1) + scan_page_size = _to_int(cfg.get("scan_page_size"), 40, 10) + min_idle_minutes = _to_int(cfg.get("min_idle_minutes"), 15, 0) + min_messages = _to_int(cfg.get("min_messages"), 14, 2) + target_tokens = _to_int(cfg.get("target_tokens"), 4096, 512) + trigger_tokens = _resolve_trigger_tokens(target_tokens, raw_cfg) + max_rounds = _to_int(cfg.get("max_rounds"), 3, 1) + truncate_turns = _to_int(cfg.get("truncate_turns"), 1, 1) + keep_recent = _to_int(cfg.get("keep_recent"), 6, 0) + provider_id = str(cfg.get("provider_id", "") or "").strip() + instruction = str(cfg.get("instruction", "") or "").strip() + dry_run = _to_bool(cfg.get("dry_run"), False) + + return CompactionConfig( + enabled=enabled, + interval_minutes=interval_minutes, + startup_delay_seconds=startup_delay_seconds, + max_conversations_per_run=max_conversations_per_run, + max_scan_per_run=max_scan_per_run, + scan_page_size=scan_page_size, + min_idle_minutes=min_idle_minutes, + min_messages=min_messages, + target_tokens=target_tokens, + trigger_tokens=trigger_tokens, + max_rounds=max_rounds, + truncate_turns=truncate_turns, + keep_recent=keep_recent, + provider_id=provider_id, + instruction=instruction, + dry_run=dry_run, + ) + + class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. @@ -281,68 +365,11 @@ def _resolve_wait_seconds(cfg: CompactionConfig) -> int: return max(1, int(cfg.interval_minutes)) * 60 def _load_config(self) -> CompactionConfig: - raw_cfg = self._load_raw_config() - - cfg = dict(self._DEFAULTS) - cfg.update(raw_cfg) - - enabled = self._to_bool(cfg.get("enabled"), False) - interval_minutes = self._to_int(cfg.get("interval_minutes"), 30, 1) - startup_delay_seconds = self._to_int(cfg.get("startup_delay_seconds"), 120, 0) - max_conversations_per_run = self._to_int( - cfg.get("max_conversations_per_run"), 8, 1 - ) - max_scan_per_run = self._to_int(cfg.get("max_scan_per_run"), 120, 1) - scan_page_size = self._to_int(cfg.get("scan_page_size"), 40, 10) - min_idle_minutes = self._to_int(cfg.get("min_idle_minutes"), 15, 0) - min_messages = self._to_int(cfg.get("min_messages"), 14, 2) - target_tokens = self._to_int(cfg.get("target_tokens"), 4096, 512) - trigger_tokens = self._resolve_trigger_tokens(target_tokens, raw_cfg) - max_rounds = self._to_int(cfg.get("max_rounds"), 3, 1) - truncate_turns = self._to_int(cfg.get("truncate_turns"), 1, 1) - keep_recent = self._to_int(cfg.get("keep_recent"), 6, 0) - provider_id = str(cfg.get("provider_id", "") or "").strip() - instruction = str(cfg.get("instruction", "") or "").strip() - dry_run = self._to_bool(cfg.get("dry_run"), False) - - return CompactionConfig( - enabled=enabled, - interval_minutes=interval_minutes, - startup_delay_seconds=startup_delay_seconds, - max_conversations_per_run=max_conversations_per_run, - max_scan_per_run=max_scan_per_run, - scan_page_size=scan_page_size, - min_idle_minutes=min_idle_minutes, - min_messages=min_messages, - target_tokens=target_tokens, - trigger_tokens=trigger_tokens, - max_rounds=max_rounds, - truncate_turns=truncate_turns, - keep_recent=keep_recent, - provider_id=provider_id, - instruction=instruction, - dry_run=dry_run, + return load_compaction_config( + default_conf=self.config_manager.default_conf, + defaults=self._DEFAULTS, ) - def _load_raw_config(self) -> dict[str, Any]: - default_conf = self.config_manager.default_conf - provider_settings = 
default_conf.get("provider_settings", {}) - raw_cfg = provider_settings.get("periodic_context_compaction", {}) - if isinstance(raw_cfg, dict): - return raw_cfg - return {} - - def _resolve_trigger_tokens(self, target_tokens: int, raw_cfg: dict[str, Any]) -> int: - trigger_default = max(int(target_tokens * 1.5), target_tokens + 1) - raw_trigger = raw_cfg.get("trigger_tokens") - if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): - trigger_tokens = trigger_default - else: - trigger_tokens = self._to_int(raw_trigger, trigger_default, 512) - if trigger_tokens <= target_tokens: - trigger_tokens = target_tokens + 1 - return trigger_tokens - async def _compact_one_conversation( self, conv: ConversationV2, @@ -425,13 +452,13 @@ async def _run_compaction_rounds( changed = False rounds = 0 instruction = self._resolve_instruction(cfg) + manager = self._build_context_manager(cfg, provider, instruction) for _ in range(cfg.max_rounds): current_tokens = self._token_counter.count_tokens(compressed) if current_tokens <= cfg.target_tokens: break - manager = self._build_context_manager(cfg, provider, instruction) rounds += 1 next_messages = await manager.process(compressed) if self._messages_equal(compressed, next_messages): @@ -574,28 +601,6 @@ def _messages_equal(a: list[Message], b: list[Message]) -> bool: m.model_dump(exclude_none=True) for m in b ] - @staticmethod - def _to_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {"1", "true", "yes", "on"}: - return True - if lowered in {"0", "false", "no", "off"}: - return False - return default - - @staticmethod - def _to_int(value: Any, default: int, min_value: int) -> int: - try: - parsed = int(value) - except Exception: - parsed = default - return max(parsed, min_value) - @staticmethod def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() From 5e460c715cc5ab197a70a2fbc4d6b785c09395fc Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 13:09:10 -0500 Subject: [PATCH 09/29] refactor: streamline compaction config and eligibility flow --- astrbot/core/context_compaction_scheduler.py | 187 ++++++++---------- .../unit/test_context_compaction_scheduler.py | 7 +- 2 files changed, 86 insertions(+), 108 deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 1fcc2263e7..fa4c2a2d3a 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -32,13 +32,6 @@ class _CompactionStats: failed: int = 0 -@dataclass -class _EligibilityResult: - eligible: bool - messages: list[Message] - before_tokens: int - - @dataclass class _RoundResult: messages: list[Message] @@ -46,6 +39,9 @@ class _RoundResult: rounds: int +EligibilityInfo = tuple[list[Message], int] + + @dataclass(frozen=True) class CompactionConfig: enabled: bool @@ -65,88 +61,74 @@ class CompactionConfig: instruction: str dry_run: bool + @classmethod + def from_default_conf( + cls, + default_conf: dict[str, Any], + defaults: dict[str, Any], + ) -> CompactionConfig: + provider_settings = default_conf.get("provider_settings", {}) or {} + raw_cfg = provider_settings.get("periodic_context_compaction", {}) or {} + if not isinstance(raw_cfg, dict): + raw_cfg = {} + + cfg = dict(defaults) + cfg.update(raw_cfg) + + target_tokens = cls._to_int(cfg.get("target_tokens"), 4096, 512) + 
raw_trigger = raw_cfg.get("trigger_tokens") + trigger_default = max(int(target_tokens * 1.5), target_tokens + 1) + if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): + trigger_tokens = trigger_default + else: + trigger_tokens = cls._to_int(raw_trigger, trigger_default, 512) + if trigger_tokens <= target_tokens: + trigger_tokens = target_tokens + 1 + + return cls( + enabled=cls._to_bool(cfg.get("enabled"), False), + interval_minutes=cls._to_int(cfg.get("interval_minutes"), 30, 1), + startup_delay_seconds=cls._to_int(cfg.get("startup_delay_seconds"), 120, 0), + max_conversations_per_run=cls._to_int( + cfg.get("max_conversations_per_run"), + 8, + 1, + ), + max_scan_per_run=cls._to_int(cfg.get("max_scan_per_run"), 120, 1), + scan_page_size=cls._to_int(cfg.get("scan_page_size"), 40, 10), + min_idle_minutes=cls._to_int(cfg.get("min_idle_minutes"), 15, 0), + min_messages=cls._to_int(cfg.get("min_messages"), 14, 2), + target_tokens=target_tokens, + trigger_tokens=trigger_tokens, + max_rounds=cls._to_int(cfg.get("max_rounds"), 3, 1), + truncate_turns=cls._to_int(cfg.get("truncate_turns"), 1, 1), + keep_recent=cls._to_int(cfg.get("keep_recent"), 6, 0), + provider_id=str(cfg.get("provider_id", "") or "").strip(), + instruction=str(cfg.get("instruction", "") or "").strip(), + dry_run=cls._to_bool(cfg.get("dry_run"), False), + ) -def _to_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {"1", "true", "yes", "on"}: - return True - if lowered in {"0", "false", "no", "off"}: - return False - return default - - -def _to_int(value: Any, default: int, min_value: int) -> int: - try: - parsed = int(value) - except Exception: - parsed = default - return max(parsed, min_value) - - -def _resolve_trigger_tokens(target_tokens: int, raw_cfg: dict[str, Any]) -> int: - trigger_default = max(int(target_tokens * 1.5), target_tokens + 1) - raw_trigger = raw_cfg.get("trigger_tokens") - if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): - trigger_tokens = trigger_default - else: - trigger_tokens = _to_int(raw_trigger, trigger_default, 512) - if trigger_tokens <= target_tokens: - trigger_tokens = target_tokens + 1 - return trigger_tokens - - -def load_compaction_config( - default_conf: dict[str, Any], - defaults: dict[str, Any] | None = None, -) -> CompactionConfig: - provider_settings = default_conf.get("provider_settings", {}) - raw_cfg = provider_settings.get("periodic_context_compaction", {}) - if not isinstance(raw_cfg, dict): - raw_cfg = {} - - cfg = dict(defaults or PERIODIC_CONTEXT_COMPACTION_DEFAULTS) - cfg.update(raw_cfg) - - enabled = _to_bool(cfg.get("enabled"), False) - interval_minutes = _to_int(cfg.get("interval_minutes"), 30, 1) - startup_delay_seconds = _to_int(cfg.get("startup_delay_seconds"), 120, 0) - max_conversations_per_run = _to_int(cfg.get("max_conversations_per_run"), 8, 1) - max_scan_per_run = _to_int(cfg.get("max_scan_per_run"), 120, 1) - scan_page_size = _to_int(cfg.get("scan_page_size"), 40, 10) - min_idle_minutes = _to_int(cfg.get("min_idle_minutes"), 15, 0) - min_messages = _to_int(cfg.get("min_messages"), 14, 2) - target_tokens = _to_int(cfg.get("target_tokens"), 4096, 512) - trigger_tokens = _resolve_trigger_tokens(target_tokens, raw_cfg) - max_rounds = _to_int(cfg.get("max_rounds"), 3, 1) - truncate_turns = _to_int(cfg.get("truncate_turns"), 1, 1) - keep_recent = 
_to_int(cfg.get("keep_recent"), 6, 0) - provider_id = str(cfg.get("provider_id", "") or "").strip() - instruction = str(cfg.get("instruction", "") or "").strip() - dry_run = _to_bool(cfg.get("dry_run"), False) - - return CompactionConfig( - enabled=enabled, - interval_minutes=interval_minutes, - startup_delay_seconds=startup_delay_seconds, - max_conversations_per_run=max_conversations_per_run, - max_scan_per_run=max_scan_per_run, - scan_page_size=scan_page_size, - min_idle_minutes=min_idle_minutes, - min_messages=min_messages, - target_tokens=target_tokens, - trigger_tokens=trigger_tokens, - max_rounds=max_rounds, - truncate_turns=truncate_turns, - keep_recent=keep_recent, - provider_id=provider_id, - instruction=instruction, - dry_run=dry_run, - ) + @staticmethod + def _to_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + @staticmethod + def _to_int(value: Any, default: int, min_value: int) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + return max(parsed, min_value) class PeriodicContextCompactionScheduler: @@ -365,7 +347,7 @@ def _resolve_wait_seconds(cfg: CompactionConfig) -> int: return max(1, int(cfg.interval_minutes)) * 60 def _load_config(self) -> CompactionConfig: - return load_compaction_config( + return CompactionConfig.from_default_conf( default_conf=self.config_manager.default_conf, defaults=self._DEFAULTS, ) @@ -376,15 +358,16 @@ async def _compact_one_conversation( cfg: CompactionConfig, ) -> str: eligibility = self._check_eligibility(conv, cfg) - if not eligibility.eligible: + if eligibility is None: return "skipped" + messages, before_tokens = eligibility provider = await self._resolve_provider(cfg, conv.user_id) if not provider: return "failed" round_result = await self._run_compaction_rounds( - messages=eligibility.messages, + messages=messages, provider=provider, cfg=cfg, ) @@ -392,11 +375,11 @@ async def _compact_one_conversation( return "skipped" after_tokens = self._token_counter.count_tokens(round_result.messages) - if after_tokens >= eligibility.before_tokens: + if after_tokens >= before_tokens: return "skipped" if cfg.dry_run: - self._log_dry_run(conv, eligibility.before_tokens, after_tokens, round_result) + self._log_dry_run(conv, before_tokens, after_tokens, round_result) return "skipped" persisted = await self._persist_compacted_history( @@ -409,7 +392,7 @@ async def _compact_one_conversation( self._log_compacted( conv, - eligibility.before_tokens, + before_tokens, after_tokens, round_result, ) @@ -419,28 +402,24 @@ def _check_eligibility( self, conv: ConversationV2, cfg: CompactionConfig, - ) -> _EligibilityResult: + ) -> EligibilityInfo | None: history = conv.content if not isinstance(history, list) or len(history) < cfg.min_messages: - return _EligibilityResult(eligible=False, messages=[], before_tokens=0) + return None if not self._is_idle_enough(conv.updated_at, cfg.min_idle_minutes): - return _EligibilityResult(eligible=False, messages=[], before_tokens=0) + return None messages = self._history_parser.parse(history) if len(messages) < cfg.min_messages: - return _EligibilityResult(eligible=False, messages=[], before_tokens=0) + return None trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 before_tokens = 
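Replacing _EligibilityResult with EligibilityInfo | None moves the eligible/ineligible split into the type itself: None means skip, anything else unpacks directly at the call site. The calling convention in isolation, with a stubbed checker:

    EligibilityInfo = tuple[list[str], int]   # (messages, before_tokens)

    def check(history: list[str], min_messages: int = 14) -> EligibilityInfo | None:
        if len(history) < min_messages:
            return None                       # any failed gate short-circuits
        return history, len(history) * 10     # stand-in token estimate

    eligibility = check(["hi"] * 20)
    if eligibility is None:
        outcome = "skipped"
    else:
        messages, before_tokens = eligibility
        assert before_tokens == 200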
self._token_counter.count_tokens(messages, trusted_usage) if before_tokens < cfg.trigger_tokens: - return _EligibilityResult(eligible=False, messages=[], before_tokens=0) + return None - return _EligibilityResult( - eligible=True, - messages=messages, - before_tokens=before_tokens, - ) + return messages, before_tokens async def _run_compaction_rounds( self, diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index e8d6c6e795..9b005090da 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -205,10 +205,9 @@ async def test_compact_one_conversation_dry_run_reports_skipped() -> None: token_usage=0, updated_at=None, ) - scheduler._check_eligibility = lambda _conv, _cfg: SimpleNamespace( # type: ignore[method-assign] - eligible=True, - messages=[Message(role="user", content="before")], - before_tokens=100, + scheduler._check_eligibility = lambda _conv, _cfg: ( # type: ignore[method-assign] + [Message(role="user", content="before")], + 100, ) scheduler._resolve_provider = AsyncMock(return_value=object()) # type: ignore[method-assign] scheduler._run_compaction_rounds = AsyncMock( # type: ignore[method-assign] From f120099d416bbddc464e9bc3489593b65a2feff0 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 20:58:08 -0500 Subject: [PATCH 10/29] refactor: unify run status and support object properties schema --- astrbot/core/config/default.py | 2 +- astrbot/core/context_compaction_scheduler.py | 98 ++++++++++++++++---- astrbot/dashboard/routes/config.py | 5 +- astrbot/dashboard/routes/util.py | 6 +- 4 files changed, 89 insertions(+), 22 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d97d7a7332..15ba9a258c 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2531,7 +2531,7 @@ class ChatProviderTemplate(TypedDict): }, "periodic_context_compaction": { "type": "object", - "items": { + "properties": { "enabled": { "type": "bool", }, diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index fa4c2a2d3a..0b3761fad9 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -42,6 +42,14 @@ class _RoundResult: EligibilityInfo = tuple[list[Message], int] +@dataclass +class _RunStatus: + started_at: str | None = None + finished_at: str | None = None + error: str | None = None + report: dict[str, Any] | None = None + + @dataclass(frozen=True) class CompactionConfig: enabled: bool @@ -154,10 +162,39 @@ def __init__( self._token_counter = EstimateTokenCounter() self._history_parser = MessageHistoryParser() self._bootstrapped = False - self._last_report: dict[str, Any] | None = None - self._last_started_at: str | None = None - self._last_finished_at: str | None = None - self._last_error: str | None = None + self._last_status = _RunStatus() + + @property + def _last_report(self) -> dict[str, Any] | None: + return self._last_status.report + + @_last_report.setter + def _last_report(self, value: dict[str, Any] | None) -> None: + self._last_status.report = value + + @property + def _last_started_at(self) -> str | None: + return self._last_status.started_at + + @_last_started_at.setter + def _last_started_at(self, value: str | None) -> None: + self._last_status.started_at = value + + @property + def _last_finished_at(self) -> str | None: + return self._last_status.finished_at + + 
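These property/setter pairs are a transitional shim: run state now lives on a single _RunStatus, while existing call sites and tests keep the legacy _last_* attribute names (PATCH 11 deletes the shims once callers touch _last_status directly). The pattern in isolation:

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class _RunStatus:
        report: dict[str, Any] | None = None

    class Scheduler:
        def __init__(self) -> None:
            self._last_status = _RunStatus()

        @property
        def _last_report(self) -> dict[str, Any] | None:
            return self._last_status.report    # old name, new storage

        @_last_report.setter
        def _last_report(self, value: dict[str, Any] | None) -> None:
            self._last_status.report = value

    s = Scheduler()
    s._last_report = {"scanned": 3}            # legacy writes still work
    assert s._last_status.report == {"scanned": 3}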
@_last_finished_at.setter + def _last_finished_at(self, value: str | None) -> None: + self._last_status.finished_at = value + + @property + def _last_error(self) -> str | None: + return self._last_status.error + + @_last_error.setter + def _last_error(self, value: str | None) -> None: + self._last_status.error = value def get_status(self) -> dict[str, Any]: cfg = self._load_config() @@ -166,10 +203,11 @@ def get_status(self) -> dict[str, Any]: "bootstrapped": self._bootstrapped, "stop_requested": self._stop_event.is_set(), "config": asdict(cfg), - "last_started_at": self._last_started_at, - "last_finished_at": self._last_finished_at, - "last_error": self._last_error, - "last_report": self._last_report, + "last_started_at": self._last_status.started_at, + "last_finished_at": self._last_status.finished_at, + "last_error": self._last_status.error, + "last_report": self._last_status.report, + "last_status": asdict(self._last_status), } async def run(self) -> None: @@ -195,7 +233,7 @@ async def run(self) -> None: break try: - report = await self.run_once(reason="scheduled") + report = await self.run_once(reason="scheduled", cfg=cfg) logger.info( "[ContextCompact] run done(%s): scanned=%s compacted=%s skipped=%s failed=%s elapsed=%.2fs", report.get("reason", "unknown"), @@ -206,7 +244,12 @@ async def run(self) -> None: report.get("elapsed_sec", 0.0), ) except Exception as exc: - self._last_error = str(exc) + self._last_status.error = str(exc) + self._last_status.finished_at = self._now_iso() + if self._last_status.started_at is None: + self._last_status.started_at = self._last_status.finished_at + if self._last_status.report is None: + self._last_status.report = {} logger.error( "[ContextCompact] scheduler run error: %s", exc, @@ -224,14 +267,16 @@ async def run_once( self, reason: str = "manual", max_conversations_override: int | None = None, + cfg: CompactionConfig | None = None, ) -> dict[str, Any]: """Run one compaction sweep. Exposed so future admin command/cron endpoints can trigger ad-hoc compaction. 
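        Passing ``cfg`` lets the scheduler loop reuse the configuration it
        already loaded for the current tick; manual callers can omit it and a
        fresh config is loaded under the run lock.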
""" async with self._running_lock: - self._last_started_at = self._now_iso() - cfg = self._load_config() + started_at = self._now_iso() + if cfg is None: + cfg = self._load_config() started = time.monotonic() stats = _CompactionStats() @@ -245,9 +290,11 @@ async def run_once( "elapsed_sec": 0.0, "message": "disabled", } - self._last_report = report - self._last_finished_at = self._now_iso() - self._last_error = None + self._set_last_status( + started_at=started_at, + report=report, + error=None, + ) return report max_to_compact, max_to_scan, scan_page_size = self._resolve_run_limits( @@ -279,9 +326,11 @@ async def run_once( "failed": stats.failed, "elapsed_sec": elapsed, } - self._last_report = report - self._last_finished_at = self._now_iso() - self._last_error = None + self._set_last_status( + started_at=started_at, + report=report, + error=None, + ) return report @staticmethod @@ -336,6 +385,19 @@ def _record_outcome(stats: _CompactionStats, outcome: str) -> None: else: stats.failed += 1 + def _set_last_status( + self, + started_at: str, + report: dict[str, Any], + error: str | None, + ) -> None: + self._last_status = _RunStatus( + started_at=started_at, + finished_at=self._now_iso(), + error=error, + report=report, + ) + async def _sleep_or_stop(self, seconds: int) -> None: try: await asyncio.wait_for(self._stop_event.wait(), timeout=seconds) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index bcd7e075c7..76ffc341e0 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -160,7 +160,10 @@ def validate(data: dict, metadata: dict = schema, path="") -> None: for item in value: validate(item, meta["items"], path=f"{path}{key}.") elif meta["type"] == "object" and isinstance(value, dict): - validate(value, meta["items"], path=f"{path}{key}.") + object_schema = meta.get("items") + if not isinstance(object_schema, dict): + object_schema = meta.get("properties", {}) + validate(value, object_schema, path=f"{path}{key}.") if meta["type"] == "int" and not isinstance(value, int): casted = try_cast(value, "int") diff --git a/astrbot/dashboard/routes/util.py b/astrbot/dashboard/routes/util.py index 1056198158..c3858c7924 100644 --- a/astrbot/dashboard/routes/util.py +++ b/astrbot/dashboard/routes/util.py @@ -15,7 +15,7 @@ def get_schema_item(schema: dict | None, key_path: str) -> dict | None: 同时支持: - 扁平 schema(直接 key 命中) - - 嵌套 object schema({type: "object", items: {...}}) + - 嵌套 object schema({type: "object", items/properties: {...}}) """ if not isinstance(schema, dict) or not key_path: @@ -33,7 +33,9 @@ def get_schema_item(schema: dict | None, key_path: str) -> dict | None: return meta if not isinstance(meta, dict) or meta.get("type") != "object": return None - current = meta.get("items", {}) + current = meta.get("items") + if not isinstance(current, dict): + current = meta.get("properties", {}) return None From c5152cadc0706e62d3d73ec5da34612ad94d39d3 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 21:18:46 -0500 Subject: [PATCH 11/29] refactor: simplify compaction status and config plumbing --- astrbot/core/context_compaction_scheduler.py | 99 +++++++------------ tests/unit/test_context_compaction_command.py | 14 +-- .../unit/test_context_compaction_scheduler.py | 11 +-- 3 files changed, 42 insertions(+), 82 deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 0b3761fad9..1fe507df25 100644 --- 
a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -73,8 +73,8 @@ class CompactionConfig: def from_default_conf( cls, default_conf: dict[str, Any], - defaults: dict[str, Any], ) -> CompactionConfig: + defaults = PERIODIC_CONTEXT_COMPACTION_DEFAULTS provider_settings = default_conf.get("provider_settings", {}) or {} raw_cfg = provider_settings.get("periodic_context_compaction", {}) or {} if not isinstance(raw_cfg, dict): @@ -146,8 +146,6 @@ class PeriodicContextCompactionScheduler: conversation-body compaction to keep long sessions lightweight. """ - _DEFAULTS = dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS) - def __init__( self, config_manager: AstrBotConfigManager, @@ -164,38 +162,6 @@ def __init__( self._bootstrapped = False self._last_status = _RunStatus() - @property - def _last_report(self) -> dict[str, Any] | None: - return self._last_status.report - - @_last_report.setter - def _last_report(self, value: dict[str, Any] | None) -> None: - self._last_status.report = value - - @property - def _last_started_at(self) -> str | None: - return self._last_status.started_at - - @_last_started_at.setter - def _last_started_at(self, value: str | None) -> None: - self._last_status.started_at = value - - @property - def _last_finished_at(self) -> str | None: - return self._last_status.finished_at - - @_last_finished_at.setter - def _last_finished_at(self, value: str | None) -> None: - self._last_status.finished_at = value - - @property - def _last_error(self) -> str | None: - return self._last_status.error - - @_last_error.setter - def _last_error(self, value: str | None) -> None: - self._last_status.error = value - def get_status(self) -> dict[str, Any]: cfg = self._load_config() return { @@ -214,7 +180,7 @@ async def run(self) -> None: logger.info("[ContextCompact] scheduler started") while not self._stop_event.is_set(): cfg = self._load_config() - wait_seconds = self._resolve_wait_seconds(cfg) + wait_seconds = max(1, int(cfg.interval_minutes)) * 60 if not cfg.enabled: await self._sleep_or_stop(wait_seconds) @@ -244,10 +210,13 @@ async def run(self) -> None: report.get("elapsed_sec", 0.0), ) except Exception as exc: - self._last_status.error = str(exc) - self._last_status.finished_at = self._now_iso() + finished = self._now_iso() + self._update_last_status( + finished_at=finished, + error=str(exc), + ) if self._last_status.started_at is None: - self._last_status.started_at = self._last_status.finished_at + self._last_status.started_at = finished if self._last_status.report is None: self._last_status.report = {} logger.error( @@ -275,6 +244,8 @@ async def run_once( """ async with self._running_lock: started_at = self._now_iso() + self._last_status.started_at = started_at + self._last_status.finished_at = None if cfg is None: cfg = self._load_config() started = time.monotonic() @@ -290,8 +261,9 @@ async def run_once( "elapsed_sec": 0.0, "message": "disabled", } - self._set_last_status( + self._update_last_status( started_at=started_at, + finished_at=self._now_iso(), report=report, error=None, ) @@ -315,7 +287,12 @@ async def run_once( stats.scanned += 1 outcome = await self._compact_one_conversation(conv, cfg) - self._record_outcome(stats, outcome) + if outcome == "compacted": + stats.compacted += 1 + elif outcome == "skipped": + stats.skipped += 1 + else: + stats.failed += 1 elapsed = time.monotonic() - started report = { @@ -326,8 +303,9 @@ async def run_once( "failed": stats.failed, "elapsed_sec": elapsed, } - self._set_last_status( + 
self._update_last_status( started_at=started_at, + finished_at=self._now_iso(), report=report, error=None, ) @@ -376,27 +354,21 @@ async def _iter_candidate_conversations( break page += 1 - @staticmethod - def _record_outcome(stats: _CompactionStats, outcome: str) -> None: - if outcome == "compacted": - stats.compacted += 1 - elif outcome == "skipped": - stats.skipped += 1 - else: - stats.failed += 1 - - def _set_last_status( + def _update_last_status( self, - started_at: str, - report: dict[str, Any], - error: str | None, + *, + started_at: str | None = None, + finished_at: str | None = None, + error: str | None = None, + report: dict[str, Any] | None = None, ) -> None: - self._last_status = _RunStatus( - started_at=started_at, - finished_at=self._now_iso(), - error=error, - report=report, - ) + if started_at is not None: + self._last_status.started_at = started_at + if finished_at is not None: + self._last_status.finished_at = finished_at + self._last_status.error = error + if report is not None: + self._last_status.report = report async def _sleep_or_stop(self, seconds: int) -> None: try: @@ -404,14 +376,9 @@ async def _sleep_or_stop(self, seconds: int) -> None: except asyncio.TimeoutError: return - @staticmethod - def _resolve_wait_seconds(cfg: CompactionConfig) -> int: - return max(1, int(cfg.interval_minutes)) * 60 - def _load_config(self) -> CompactionConfig: return CompactionConfig.from_default_conf( default_conf=self.config_manager.default_conf, - defaults=self._DEFAULTS, ) async def _compact_one_conversation( diff --git a/tests/unit/test_context_compaction_command.py b/tests/unit/test_context_compaction_command.py index 13dce3cd81..35459dc99f 100644 --- a/tests/unit/test_context_compaction_command.py +++ b/tests/unit/test_context_compaction_command.py @@ -54,7 +54,7 @@ async def test_status_when_scheduler_unavailable() -> None: @pytest.mark.asyncio async def test_status_with_runtime_report() -> None: scheduler = _build_scheduler() - scheduler._last_report = { + scheduler._last_status.report = { "reason": "manual_command", "scanned": 8, "compacted": 2, @@ -62,8 +62,8 @@ async def test_status_with_runtime_report() -> None: "failed": 0, "elapsed_sec": 1.2, } - scheduler._last_started_at = "2026-03-19T12:00:00+00:00" - scheduler._last_finished_at = "2026-03-19T12:00:01+00:00" + scheduler._last_status.started_at = "2026-03-19T12:00:00+00:00" + scheduler._last_status.finished_at = "2026-03-19T12:00:01+00:00" command = ContextCompactionCommands( context=SimpleNamespace(context_compaction_scheduler=scheduler) @@ -83,8 +83,8 @@ async def test_status_with_runtime_report() -> None: @pytest.mark.asyncio async def test_status_with_no_report() -> None: scheduler = _build_scheduler() - scheduler._last_report = None - scheduler._last_error = None + scheduler._last_status.report = None + scheduler._last_status.error = None command = ContextCompactionCommands( context=SimpleNamespace(context_compaction_scheduler=scheduler) @@ -104,7 +104,7 @@ async def test_status_with_no_report() -> None: @pytest.mark.asyncio async def test_status_includes_last_error_line() -> None: scheduler = _build_scheduler() - scheduler._last_report = { + scheduler._last_status.report = { "reason": "manual_command", "scanned": 1, "compacted": 0, @@ -112,7 +112,7 @@ async def test_status_includes_last_error_line() -> None: "failed": 0, "elapsed_sec": 0.3, } - scheduler._last_error = "mock error" + scheduler._last_status.error = "mock error" command = ContextCompactionCommands( 
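One subtlety in _update_last_status above: error is assigned unconditionally, so a successful run clears any stale error while the other fields change only when passed. The semantics in isolation:

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class _RunStatus:
        started_at: str | None = None
        error: str | None = None
        report: dict[str, Any] | None = None

    def update(
        status: _RunStatus,
        *,
        started_at: str | None = None,
        error: str | None = None,
        report: dict[str, Any] | None = None,
    ) -> None:
        if started_at is not None:
            status.started_at = started_at
        status.error = error                  # deliberate: omitting error resets it
        if report is not None:
            status.report = report

    s = _RunStatus(started_at="t0", error="old failure")
    update(s, report={"scanned": 1})
    assert s.started_at == "t0" and s.error is None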
context=SimpleNamespace(context_compaction_scheduler=scheduler) diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index 9b005090da..607a418944 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -8,6 +8,7 @@ import pytest from astrbot.core.agent.message import Message +from astrbot.core.config.default import PERIODIC_CONTEXT_COMPACTION_DEFAULTS from astrbot.core.context_compaction_scheduler import ( CompactionConfig, PeriodicContextCompactionScheduler, @@ -135,18 +136,10 @@ def test_load_config_falls_back_for_non_dict(raw_value) -> None: scheduler = _build_scheduler({"periodic_context_compaction": raw_value}) cfg = scheduler._load_config() - expected = CompactionConfig(**scheduler._DEFAULTS) + expected = CompactionConfig(**PERIODIC_CONTEXT_COMPACTION_DEFAULTS) assert cfg == expected -def test_resolve_wait_seconds_uses_normalized_interval() -> None: - cfg = replace( - CompactionConfig(**PeriodicContextCompactionScheduler._DEFAULTS), - interval_minutes=1, - ) - assert PeriodicContextCompactionScheduler._resolve_wait_seconds(cfg) == 60 - - def test_get_status_returns_runtime_snapshot() -> None: scheduler = _build_scheduler( {"periodic_context_compaction": {"enabled": True, "interval_minutes": 3}} From 514b23154f173dbed062f52293408a22b49cdb8a Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 21:34:01 -0500 Subject: [PATCH 12/29] feat: improve periodic compaction trigger and scheduler limits --- .../commands/context_compaction.py | 8 +- astrbot/core/agent/message_history_parser.py | 21 ++-- astrbot/core/config/default.py | 17 ++- astrbot/core/context_compaction_scheduler.py | 70 +++++++++--- .../unit/test_context_compaction_scheduler.py | 102 ++++++++++++++++-- 5 files changed, 189 insertions(+), 29 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py index 37c0d6aa81..afb3a6b12f 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py +++ b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py @@ -24,6 +24,12 @@ async def status(self, event: AstrMessageEvent) -> None: status = scheduler.get_status() cfg = status.get("config", {}) last = status.get("last_report") or {} + trigger_tokens = cfg.get("trigger_tokens", "?") + trigger_ratio = cfg.get("trigger_min_context_ratio", "?") + if isinstance(trigger_tokens, int) and trigger_tokens <= 0: + trigger_text = f"自动({trigger_ratio}x模型上下文)" + else: + trigger_text = str(trigger_tokens) lines = ["定时上下文压缩状态:"] lines.append( @@ -37,7 +43,7 @@ async def status(self, event: AstrMessageEvent) -> None: f" | 每轮最多扫描={cfg.get('max_scan_per_run', '?')}" ) lines.append( - f"触发Token={cfg.get('trigger_tokens', '?')}" + f"触发Token={trigger_text}" f" | 目标Token={cfg.get('target_tokens', '?')}" f" | 最大轮次={cfg.get('max_rounds', '?')}" ) diff --git a/astrbot/core/agent/message_history_parser.py b/astrbot/core/agent/message_history_parser.py index 9a08fb7b6d..e7650db50c 100644 --- a/astrbot/core/agent/message_history_parser.py +++ b/astrbot/core/agent/message_history_parser.py @@ -14,22 +14,27 @@ def parse(self, history: Iterable[Any]) -> list[Message]: if not isinstance(item, dict): continue - try: - parsed.append(Message.model_validate(item)) + msg = self._try_validate(item) + if msg is not None: + parsed.append(msg) continue - except Exception: - pass fallback = 
self.sanitize_message_dict(item) if not fallback: continue - try: - parsed.append(Message.model_validate(fallback)) - except Exception: - continue + msg = self._try_validate(fallback) + if msg is not None: + parsed.append(msg) return parsed + @staticmethod + def _try_validate(data: dict[str, Any]) -> Message | None: + try: + return Message.model_validate(data) + except Exception: + return None + def sanitize_message_dict(self, item: dict[str, Any]) -> dict[str, Any] | None: role = str(item.get("role", "")).strip().lower() if role not in {"system", "user", "assistant", "tool"}: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 15ba9a258c..74b10224c6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -28,7 +28,8 @@ "min_idle_minutes": 15, "min_messages": 14, "target_tokens": 4096, - "trigger_tokens": 6144, + "trigger_tokens": 0, + "trigger_min_context_ratio": 0.3, "max_rounds": 3, "truncate_turns": 1, "keep_recent": 6, @@ -2562,6 +2563,9 @@ class ChatProviderTemplate(TypedDict): "trigger_tokens": { "type": "int", }, + "trigger_min_context_ratio": { + "type": "float", + }, "max_rounds": { "type": "int", }, @@ -3352,7 +3356,16 @@ class ChatProviderTemplate(TypedDict): "provider_settings.periodic_context_compaction.trigger_tokens": { "description": "触发 Token 阈值", "type": "int", - "hint": "会话估算 token 超过此值才触发压缩。", + "hint": "会话估算 token 超过此值才触发压缩。<=0 表示自动按模型最大上下文比例计算。", + "condition": { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.periodic_context_compaction.trigger_min_context_ratio": { + "description": "自动触发比例", + "type": "float", + "hint": "当触发 Token 阈值 <= 0 时生效。默认 0.3(即模型最大上下文的 30%)。支持填写 0~1 或 0~100(百分比)。", "condition": { "provider_settings.periodic_context_compaction.enabled": True, "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 1fe507df25..cc60018632 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -19,6 +19,7 @@ from astrbot.core.db.po import ConversationV2 from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.provider import Provider +from astrbot.core.utils.llm_metadata import LLM_METADATAS if TYPE_CHECKING: from astrbot.core.provider.manager import ProviderManager @@ -62,6 +63,7 @@ class CompactionConfig: min_messages: int target_tokens: int trigger_tokens: int + trigger_min_context_ratio: float max_rounds: int truncate_turns: int keep_recent: int @@ -84,14 +86,11 @@ def from_default_conf( cfg.update(raw_cfg) target_tokens = cls._to_int(cfg.get("target_tokens"), 4096, 512) - raw_trigger = raw_cfg.get("trigger_tokens") - trigger_default = max(int(target_tokens * 1.5), target_tokens + 1) - if raw_trigger is None or (isinstance(raw_trigger, str) and not raw_trigger): - trigger_tokens = trigger_default - else: - trigger_tokens = cls._to_int(raw_trigger, trigger_default, 512) - if trigger_tokens <= target_tokens: - trigger_tokens = target_tokens + 1 + trigger_tokens = cls._to_int(cfg.get("trigger_tokens"), 0, 0) + trigger_min_context_ratio = cls._to_ratio( + cfg.get("trigger_min_context_ratio"), + 0.3, + ) return cls( enabled=cls._to_bool(cfg.get("enabled"), False), @@ -108,6 +107,7 @@ def from_default_conf( min_messages=cls._to_int(cfg.get("min_messages"), 14, 2), target_tokens=target_tokens, trigger_tokens=trigger_tokens, + 
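+            # _to_ratio (defined below) accepts either a 0~1 fraction or a
+            # 1~100 percentage: "30" -> 0.30, 0.25 -> 0.25, 500 clamps to 1.0,
+            # -1 clamps to 0.0.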
trigger_min_context_ratio=trigger_min_context_ratio, max_rounds=cls._to_int(cfg.get("max_rounds"), 3, 1), truncate_turns=cls._to_int(cfg.get("truncate_turns"), 1, 1), keep_recent=cls._to_int(cfg.get("keep_recent"), 6, 0), @@ -138,6 +138,16 @@ def _to_int(value: Any, default: int, min_value: int) -> int: parsed = default return max(parsed, min_value) + @staticmethod + def _to_ratio(value: Any, default: float) -> float: + try: + parsed = float(value) + except Exception: + parsed = default + if parsed > 1.0 and parsed <= 100.0: + parsed = parsed / 100.0 + return min(max(parsed, 0.0), 1.0) + class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. @@ -316,10 +326,11 @@ def _resolve_run_limits( cfg: CompactionConfig, max_conversations_override: int | None, ) -> tuple[int, int, int]: + max_to_scan = max(1, int(cfg.max_scan_per_run)) max_to_compact = max(1, int(cfg.max_conversations_per_run)) if max_conversations_override is not None: max_to_compact = max(1, int(max_conversations_override)) - max_to_scan = max(max_to_compact, int(cfg.max_scan_per_run)) + max_to_compact = min(max_to_compact, max_to_scan) scan_page_size = max(10, int(cfg.scan_page_size)) return max_to_compact, max_to_scan, scan_page_size @@ -395,6 +406,10 @@ async def _compact_one_conversation( if not provider: return "failed" + trigger_tokens = self._resolve_trigger_tokens(cfg, provider) + if before_tokens < trigger_tokens: + return "skipped" + round_result = await self._run_compaction_rounds( messages=messages, provider=provider, @@ -445,11 +460,42 @@ def _check_eligibility( trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 before_tokens = self._token_counter.count_tokens(messages, trusted_usage) - if before_tokens < cfg.trigger_tokens: - return None - return messages, before_tokens + def _resolve_trigger_tokens(self, cfg: CompactionConfig, provider: Provider) -> int: + if cfg.trigger_tokens > 0: + return cfg.trigger_tokens + + max_context_tokens = self._resolve_provider_max_context(provider) + if max_context_tokens > 0: + return max(1, int(max_context_tokens * cfg.trigger_min_context_ratio)) + + return max(int(cfg.target_tokens * 1.5), cfg.target_tokens + 1) + + @staticmethod + def _resolve_provider_max_context(provider: Provider) -> int: + configured = provider.provider_config.get("max_context_tokens", 0) + try: + configured_tokens = int(configured) + except Exception: + configured_tokens = 0 + if configured_tokens > 0: + return configured_tokens + + model = provider.get_model() + model_info = LLM_METADATAS.get(model) + if not isinstance(model_info, dict): + return 0 + limit = model_info.get("limit") + if not isinstance(limit, dict): + return 0 + context = limit.get("context") + try: + context_tokens = int(context) + except Exception: + context_tokens = 0 + return max(context_tokens, 0) + async def _run_compaction_rounds( self, messages: list[Message], diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index 607a418944..27d7a758b2 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -47,7 +47,8 @@ def test_load_config_normalizes_values() -> None: assert cfg.enabled is True assert cfg.interval_minutes == 1 assert cfg.target_tokens == 1024 - assert cfg.trigger_tokens == 1025 + assert cfg.trigger_tokens == 1000 + assert cfg.trigger_min_context_ratio == pytest.approx(0.3) assert cfg.max_rounds == 2 @@ -111,12 
+112,12 @@ def test_load_config_clamps_numeric_minimums( @pytest.mark.parametrize( ("raw_cfg", "expected_target", "expected_trigger"), [ - ({"target_tokens": 1024}, 1024, 1536), - ({"target_tokens": 1024, "trigger_tokens": None}, 1024, 1536), - ({"target_tokens": 1024, "trigger_tokens": 512}, 1024, 1025), - ({"target_tokens": 1024, "trigger_tokens": "512"}, 1024, 1025), + ({"target_tokens": 1024}, 1024, 0), + ({"target_tokens": 1024, "trigger_tokens": None}, 1024, 0), + ({"target_tokens": 1024, "trigger_tokens": 512}, 1024, 512), + ({"target_tokens": 1024, "trigger_tokens": "512"}, 1024, 512), ({"target_tokens": 1024, "trigger_tokens": 2048}, 1024, 2048), - ({"target_tokens": 10}, 512, 768), + ({"target_tokens": 10}, 512, 0), ], ) def test_load_config_token_threshold_normalization( @@ -131,6 +132,24 @@ def test_load_config_token_threshold_normalization( assert cfg.trigger_tokens == expected_trigger +@pytest.mark.parametrize( + ("raw_ratio", "expected"), + [ + (0.3, 0.3), + ("30", 0.3), + ("0.25", 0.25), + (-1, 0.0), + (500, 1.0), + ], +) +def test_load_config_trigger_ratio_normalization(raw_ratio, expected: float) -> None: + scheduler = _build_scheduler( + {"periodic_context_compaction": {"trigger_min_context_ratio": raw_ratio}} + ) + cfg = scheduler._load_config() + assert cfg.trigger_min_context_ratio == pytest.approx(expected) + + @pytest.mark.parametrize("raw_value", [None, 1, "not-a-dict", []]) def test_load_config_falls_back_for_non_dict(raw_value) -> None: scheduler = _build_scheduler({"periodic_context_compaction": raw_value}) @@ -186,6 +205,76 @@ def test_is_idle_enough_respects_threshold() -> None: assert PeriodicContextCompactionScheduler._is_idle_enough(None, 10) is True +def test_resolve_run_limits_treats_max_scan_as_upper_bound() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + max_conversations_per_run=8, + max_scan_per_run=3, + scan_page_size=5, + ) + + max_to_compact, max_to_scan, page_size = scheduler._resolve_run_limits(cfg, None) + assert max_to_compact == 3 + assert max_to_scan == 3 + assert page_size == 10 + + max_to_compact, max_to_scan, _ = scheduler._resolve_run_limits(cfg, 20) + assert max_to_compact == 3 + assert max_to_scan == 3 + + +def test_resolve_trigger_tokens_prefers_manual_value() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + target_tokens=1024, + trigger_tokens=1500, + trigger_min_context_ratio=0.3, + ) + provider = SimpleNamespace( + provider_config={"max_context_tokens": 32768}, + get_model=lambda: "unknown-model", + ) + + resolved = scheduler._resolve_trigger_tokens(cfg, provider) + assert resolved == 1500 + + +def test_resolve_trigger_tokens_uses_ratio_when_auto_mode() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + target_tokens=1024, + trigger_tokens=0, + trigger_min_context_ratio=0.3, + ) + provider = SimpleNamespace( + provider_config={"max_context_tokens": 32768}, + get_model=lambda: "unknown-model", + ) + + resolved = scheduler._resolve_trigger_tokens(cfg, provider) + assert resolved == 9830 + + +def test_resolve_trigger_tokens_falls_back_when_provider_context_unknown() -> None: + scheduler = _build_scheduler({}) + cfg = replace( + scheduler._load_config(), + target_tokens=1024, + trigger_tokens=0, + trigger_min_context_ratio=0.3, + ) + provider = SimpleNamespace( + provider_config={"max_context_tokens": 0}, + get_model=lambda: "unknown-model", + ) + + resolved = scheduler._resolve_trigger_tokens(cfg, provider) + 
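The 1536 in the assertion below is the last-resort branch of _resolve_trigger_tokens: no manual threshold and no resolvable model context, so the trigger falls back to the target-token heuristic:

    target_tokens = 1024
    fallback = max(int(target_tokens * 1.5), target_tokens + 1)
    assert fallback == 1536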
assert resolved == 1536 + + @pytest.mark.asyncio async def test_compact_one_conversation_dry_run_reports_skipped() -> None: scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}}) @@ -210,6 +299,7 @@ async def test_compact_one_conversation_dry_run_reports_skipped() -> None: rounds=1, ) ) + scheduler._resolve_trigger_tokens = lambda _cfg, _provider: 1 # type: ignore[method-assign] scheduler._token_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 50) scheduler._persist_compacted_history = AsyncMock( # type: ignore[method-assign] return_value=True From 3ff3533f92e6d866ac0740006d8ed9a02ec90162 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 21:56:45 -0500 Subject: [PATCH 13/29] feat: add configurable token counting and post-tool compaction --- astrbot/core/agent/context/config.py | 4 + astrbot/core/agent/context/manager.py | 7 +- astrbot/core/agent/context/token_counter.py | 110 ++++++++++++++++++ .../agent/runners/tool_loop_agent_runner.py | 15 ++- astrbot/core/astr_main_agent.py | 6 + astrbot/core/config/default.py | 26 +++++ .../method/agent_sub_stages/internal.py | 9 ++ tests/agent/test_context_manager.py | 12 ++ tests/agent/test_token_counter.py | 17 ++- tests/test_tool_loop_agent_runner.py | 49 ++++++++ 10 files changed, 251 insertions(+), 4 deletions(-) diff --git a/astrbot/core/agent/context/config.py b/astrbot/core/agent/context/config.py index b8fd8eb968..7d43aaed64 100644 --- a/astrbot/core/agent/context/config.py +++ b/astrbot/core/agent/context/config.py @@ -29,6 +29,10 @@ class ContextConfig: """Number of recent messages to keep during LLM-based compression.""" llm_compress_provider: "Provider | None" = None """LLM provider used for compression tasks. If None, truncation strategy is used.""" + token_counter_mode: str = "estimate" + """Token counting mode: estimate, tokenizer, auto.""" + token_counter_model: str | None = None + """Optional model name for tokenizer-based token counting.""" custom_token_counter: TokenCounter | None = None """Custom token counting method. 
If None, the default method is used.""" custom_compressor: ContextCompressor | None = None diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 216a3e7e15..927b78a1fe 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -3,7 +3,7 @@ from ..message import Message from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor from .config import ContextConfig -from .token_counter import EstimateTokenCounter +from .token_counter import create_token_counter from .truncator import ContextTruncator @@ -25,7 +25,10 @@ def __init__( """ self.config = config - self.token_counter = config.custom_token_counter or EstimateTokenCounter() + self.token_counter = config.custom_token_counter or create_token_counter( + config.token_counter_mode, + model=config.token_counter_model, + ) self.truncator = ContextTruncator() if config.custom_compressor: diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index 7c60cb23ec..2f28d77854 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -1,6 +1,9 @@ import json +from collections.abc import Callable from typing import Protocol, runtime_checkable +from astrbot import logger + from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart @@ -76,3 +79,110 @@ def _estimate_tokens(self, text: str) -> int: chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"]) other_count = len(text) - chinese_count return int(chinese_count * 0.6 + other_count * 0.3) + + +class TokenizerTokenCounter: + """Tokenizer-based token counter. + + Uses `tiktoken` when available and falls back to estimate mode if encoding + is unavailable. 
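+
+    Typical wiring goes through the ``create_token_counter`` factory below,
+    e.g. ``create_token_counter("auto", model=provider.get_model())``, which
+    returns this class only when ``tiktoken`` imports cleanly.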
+ """ + + def __init__(self, model: str | None = None) -> None: + self._estimate = EstimateTokenCounter() + self._encode: Callable[[str], int] | None = None + self._available = False + self._init_encoder(model) + + @property + def available(self) -> bool: + return self._available + + def _init_encoder(self, model: str | None) -> None: + try: + import tiktoken # type: ignore + except Exception: + self._available = False + self._encode = None + return + + try: + if model: + encoding = tiktoken.encoding_for_model(model) + else: + encoding = tiktoken.get_encoding("cl100k_base") + except Exception: + encoding = tiktoken.get_encoding("cl100k_base") + + self._available = True + self._encode = lambda text: len(encoding.encode(text)) + + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: + if trusted_token_usage > 0: + return trusted_token_usage + if not self._available: + return self._estimate.count_tokens(messages) + + total = 0 + for msg in messages: + content = msg.content + if isinstance(content, str): + total += self._encode_len(content) + elif isinstance(content, list): + for part in content: + if isinstance(part, TextPart): + total += self._encode_len(part.text) + elif isinstance(part, ThinkPart): + total += self._encode_len(part.think) + elif isinstance(part, ImageURLPart): + total += IMAGE_TOKEN_ESTIMATE + elif isinstance(part, AudioURLPart): + total += AUDIO_TOKEN_ESTIMATE + + if msg.tool_calls: + for tc in msg.tool_calls: + tc_str = json.dumps( + tc if isinstance(tc, dict) else tc.model_dump(), + ensure_ascii=False, + default=str, + ) + total += self._encode_len(tc_str) + + return total + + def _encode_len(self, text: str) -> int: + if not self._encode: + return self._estimate._estimate_tokens(text) + try: + return self._encode(text) + except Exception: + return self._estimate._estimate_tokens(text) + + +def create_token_counter( + mode: str | None = None, + *, + model: str | None = None, +) -> TokenCounter: + normalized = str(mode or "estimate").strip().lower() + + if normalized == "estimate": + return EstimateTokenCounter() + + if normalized in {"tokenizer", "auto"}: + tokenizer_counter = TokenizerTokenCounter(model=model) + if tokenizer_counter.available: + return tokenizer_counter + if normalized == "tokenizer": + logger.warning( + "context_token_counter_mode=tokenizer but `tiktoken` is unavailable; fallback to estimate." 
+ ) + return EstimateTokenCounter() + + logger.warning( + "Unknown context_token_counter_mode=%s, fallback to estimate.", + normalized, + ) + return EstimateTokenCounter() diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 743b280070..e6964c57a2 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -104,6 +104,10 @@ async def reset( llm_compress_provider: Provider | None = None, # truncate by turns compressor truncate_turns: int = 1, + # context token counting mode + token_counter_mode: str = "estimate", + # run context compression immediately after tool execution + compact_context_after_tool_call: bool = False, # customize custom_token_counter: TokenCounter | None = None, custom_compressor: ContextCompressor | None = None, @@ -118,11 +122,13 @@ async def reset( self.llm_compress_keep_recent = llm_compress_keep_recent self.llm_compress_provider = llm_compress_provider self.truncate_turns = truncate_turns + self.token_counter_mode = token_counter_mode + self.compact_context_after_tool_call = compact_context_after_tool_call self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor # we will do compress when: # 1. before requesting LLM - # TODO: 2. after LLM output a tool call + # 2. optionally after tool execution, controlled by config self.context_config = ContextConfig( # <=0 will never do compress max_context_tokens=provider.provider_config.get("max_context_tokens", 0), @@ -132,6 +138,8 @@ async def reset( llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider=self.llm_compress_provider, + token_counter_mode=self.token_counter_mode, + token_counter_model=provider.get_model(), custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) @@ -618,6 +626,11 @@ async def step(self): self.req.append_tool_calls_result(tool_calls_result) + if self.compact_context_after_tool_call: + self.run_context.messages = await self.context_manager.process( + self.run_context.messages, + ) + async def step_until_done( self, max_step: int ) -> T.AsyncGenerator[AgentResponse, None]: diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 10b67253fe..3685b2a285 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -119,6 +119,10 @@ class MainAgentBuildConfig: """The number of most recent turns to keep during llm_compress strategy.""" llm_compress_provider_id: str = "" """The provider ID for the LLM used in context compression.""" + context_token_counter_mode: str = "estimate" + """Token counting mode for context compaction: estimate, tokenizer, auto.""" + compact_context_after_tool_call: bool = False + """Whether to run context compaction check immediately after tool execution.""" max_context_length: int = -1 """The maximum number of turns to keep in context. -1 means no limit. 
This enforce max turns before compression""" @@ -1203,6 +1207,8 @@ async def build_main_agent( llm_compress_instruction=config.llm_compress_instruction, llm_compress_keep_recent=config.llm_compress_keep_recent, llm_compress_provider=_get_compress_provider(config, plugin_context), + token_counter_mode=config.context_token_counter_mode, + compact_context_after_tool_call=config.compact_context_after_tool_call, truncate_turns=config.dequeue_context_length, enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 74b10224c6..9d77e46864 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -116,6 +116,8 @@ ), "llm_compress_keep_recent": 6, "llm_compress_provider_id": "", + "context_token_counter_mode": "estimate", + "compact_context_after_tool_call": False, "periodic_context_compaction": dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS), "max_context_length": -1, "dequeue_context_length": 1, @@ -2530,6 +2532,12 @@ class ChatProviderTemplate(TypedDict): "prompt_prefix": { "type": "string", }, + "context_token_counter_mode": { + "type": "string", + }, + "compact_context_after_tool_call": { + "type": "bool", + }, "periodic_context_compaction": { "type": "object", "properties": { @@ -3273,6 +3281,24 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.context_token_counter_mode": { + "description": "Token 计数模式", + "type": "string", + "options": ["estimate", "tokenizer", "auto"], + "labels": ["估算", "Tokenizer", "自动(优先 Tokenizer)"], + "hint": "用于上下文压缩触发判断。tokenizer 模式会优先使用 tiktoken,不可用时回退估算。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_after_tool_call": { + "description": "工具调用后立即检查压缩", + "type": "bool", + "hint": "开启后,每次工具执行回写上下文后都会立刻触发一次上下文压缩检查。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.periodic_context_compaction.enabled": { "description": "启用定时历史压缩", "type": "bool", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 523d758a0a..614c39a06e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -91,6 +91,13 @@ async def initialize(self, ctx: PipelineContext) -> None: self.llm_compress_provider_id: str = settings.get( "llm_compress_provider_id", "" ) + self.context_token_counter_mode: str = str( + settings.get("context_token_counter_mode", "estimate") + ) + self.compact_context_after_tool_call: bool = settings.get( + "compact_context_after_tool_call", + False, + ) self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -125,6 +132,8 @@ async def initialize(self, ctx: PipelineContext) -> None: llm_compress_instruction=self.llm_compress_instruction, llm_compress_keep_recent=self.llm_compress_keep_recent, llm_compress_provider_id=self.llm_compress_provider_id, + context_token_counter_mode=self.context_token_counter_mode, + compact_context_after_tool_call=self.compact_context_after_tool_call, max_context_length=self.max_context_length, dequeue_context_length=self.dequeue_context_length, llm_safety_mode=self.llm_safety_mode, diff --git 
a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index 0b955ff401..d5c6c8c1f1 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -94,6 +94,18 @@ def test_init_with_truncate_compressor(self): assert isinstance(manager.compressor, TruncateByTurnsCompressor) + @patch("astrbot.core.agent.context.manager.create_token_counter") + def test_init_uses_token_counter_mode(self, mock_create_token_counter): + """Test token counter mode wiring into ContextManager.""" + fake_counter = MagicMock() + mock_create_token_counter.return_value = fake_counter + config = ContextConfig(token_counter_mode="auto", token_counter_model="gpt-4") + + manager = ContextManager(config) + + mock_create_token_counter.assert_called_once_with("auto", model="gpt-4") + assert manager.token_counter is fake_counter + # ==================== Empty and Edge Cases ==================== @pytest.mark.asyncio diff --git a/tests/agent/test_token_counter.py b/tests/agent/test_token_counter.py index c68b056e66..9f73a2519a 100644 --- a/tests/agent/test_token_counter.py +++ b/tests/agent/test_token_counter.py @@ -4,6 +4,8 @@ AUDIO_TOKEN_ESTIMATE, IMAGE_TOKEN_ESTIMATE, EstimateTokenCounter, + TokenizerTokenCounter, + create_token_counter, ) from astrbot.core.agent.message import ( AudioURLPart, @@ -13,7 +15,6 @@ ThinkPart, ) - counter = EstimateTokenCounter() @@ -101,3 +102,17 @@ def test_tool_calls_counted(self): # 文本 + tool call JSON 都应被计算 text_only = counter.count_tokens([_msg("assistant", "calling tool")]) assert tokens > text_only + + +class TestCounterFactory: + def test_create_estimate_mode(self): + created = create_token_counter("estimate") + assert isinstance(created, EstimateTokenCounter) + + def test_create_unknown_mode_fallback(self): + created = create_token_counter("unknown-mode") + assert isinstance(created, EstimateTokenCounter) + + def test_create_tokenizer_mode_returns_valid_counter_type(self): + created = create_token_counter("tokenizer", model="gpt-4") + assert isinstance(created, (TokenizerTokenCounter, EstimateTokenCounter)) diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 38c601cee5..ff6069b1d7 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -536,6 +536,55 @@ async def test_follow_up_ticket_not_consumed_when_no_next_tool_call( assert ticket.consumed is False +@pytest.mark.asyncio +async def test_compact_context_after_tool_call_enabled( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_provider.should_call_tools = True + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + compact_context_after_tool_call=True, + ) + + runner.context_manager.process = AsyncMock( # type: ignore[method-assign] + side_effect=lambda messages, trusted_token_usage=0: messages, + ) + + async for _ in runner.step(): + pass + + assert runner.context_manager.process.await_count == 2 + + +@pytest.mark.asyncio +async def test_compact_context_after_tool_call_disabled_by_default( + runner, mock_provider, provider_request, mock_tool_executor, mock_hooks +): + mock_provider.should_call_tools = True + await runner.reset( + provider=mock_provider, + request=provider_request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + 
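+    # Stub process() to count compression passes: only the pre-LLM pass should
+    # run here, because compact_context_after_tool_call defaults to False.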
runner.context_manager.process = AsyncMock( # type: ignore[method-assign] + side_effect=lambda messages, trusted_token_usage=0: messages, + ) + + async for _ in runner.step(): + pass + + assert runner.context_manager.process.await_count == 1 + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) From d27d010d2788865b31ef00ea50cb3c42cf2b6ea2 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 23:05:11 -0500 Subject: [PATCH 14/29] feat(context): add pinned top-memory interface and ctxmem admin commands --- .../builtin_commands/commands/__init__.py | 2 + .../commands/context_memory.py | 192 ++++++++++++++++++ .../builtin_stars/builtin_commands/main.py | 50 +++++ astrbot/core/astr_main_agent.py | 27 +++ astrbot/core/config/default.py | 119 +++++++++++ astrbot/core/context_memory.py | 167 +++++++++++++++ tests/unit/test_astr_main_agent.py | 53 +++++ tests/unit/test_context_memory_command.py | 136 +++++++++++++ 8 files changed, 746 insertions(+) create mode 100644 astrbot/builtin_stars/builtin_commands/commands/context_memory.py create mode 100644 astrbot/core/context_memory.py create mode 100644 tests/unit/test_context_memory_command.py diff --git a/astrbot/builtin_stars/builtin_commands/commands/__init__.py b/astrbot/builtin_stars/builtin_commands/commands/__init__.py index d56f0cae24..d96e52da6e 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/__init__.py +++ b/astrbot/builtin_stars/builtin_commands/commands/__init__.py @@ -3,6 +3,7 @@ from .admin import AdminCommands from .alter_cmd import AlterCmdCommands from .context_compaction import ContextCompactionCommands +from .context_memory import ContextMemoryCommands from .conversation import ConversationCommands from .help import HelpCommand from .llm import LLMCommands @@ -19,6 +20,7 @@ "AlterCmdCommands", "ConversationCommands", "ContextCompactionCommands", + "ContextMemoryCommands", "HelpCommand", "LLMCommands", "PersonaCommands", diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_memory.py b/astrbot/builtin_stars/builtin_commands/commands/context_memory.py new file mode 100644 index 0000000000..f200fe65f1 --- /dev/null +++ b/astrbot/builtin_stars/builtin_commands/commands/context_memory.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from typing import Any + +from astrbot.api import star +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.core.context_memory import ensure_context_memory_settings + + +class ContextMemoryCommands: + def __init__(self, context: star.Context) -> None: + self.context = context + + def _get_provider_settings(self, event: AstrMessageEvent) -> tuple[Any, dict[str, Any]]: + cfg = self.context.get_config(umo=event.unified_msg_origin) + provider_settings = cfg.get("provider_settings", {}) + if not isinstance(provider_settings, dict): + provider_settings = {} + cfg["provider_settings"] = provider_settings + return cfg, provider_settings + + @staticmethod + def _save_config(cfg: Any) -> None: + save_func = getattr(cfg, "save_config", None) + if callable(save_func): + save_func() + + @staticmethod + def _parse_switch(value: str) -> bool | None: + normalized = value.strip().lower() + if normalized in {"1", "true", "on", "yes", "enable", "enabled"}: + return True + if normalized in {"0", "false", "off", "no", "disable", "disabled"}: + return False + return None + + async def status(self, event: AstrMessageEvent) -> None: + _, provider_settings = self._get_provider_settings(event) + cm_cfg = 
ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list): + pinned = [] + + lines = ["上下文记忆状态:"] + lines.append( + "启用=" + + ("是" if bool(cm_cfg.get("enabled", False)) else "否") + + " | 注入顶层记忆=" + + ("是" if bool(cm_cfg.get("inject_pinned_memory", True)) else "否") + ) + lines.append( + f"顶层记忆条数={len(pinned)}" + f" | 最大条数={cm_cfg.get('pinned_max_items', '?')}" + f" | 单条最大字符={cm_cfg.get('pinned_max_chars_per_item', '?')}" + ) + lines.append( + "检索增强(开发中)=" + + ("是" if bool(cm_cfg.get("retrieval_enabled", False)) else "否") + + f" | backend={cm_cfg.get('retrieval_backend', '') or '-'}" + + f" | top_k={cm_cfg.get('retrieval_top_k', '?')}" + ) + await event.send(MessageChain().message("\n".join(lines))) + + async def ls(self, event: AstrMessageEvent) -> None: + _, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list) or not pinned: + await event.send(MessageChain().message("当前没有手动顶层记忆。")) + return + + lines = ["手动顶层记忆列表:"] + for idx, text in enumerate(pinned, start=1): + text_str = str(text) + if len(text_str) > 180: + text_str = text_str[:180] + "..." + lines.append(f"{idx}. {text_str}") + await event.send(MessageChain().message("\n".join(lines))) + + async def add(self, event: AstrMessageEvent, text: str) -> None: + content = str(text or "").strip() + if not content: + await event.send(MessageChain().message("用法: /ctxmem add <记忆内容>")) + return + + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list): + pinned = [] + cm_cfg["pinned_memories"] = pinned + + max_items = int(cm_cfg.get("pinned_max_items", 8) or 8) + if len(pinned) >= max_items: + await event.send( + MessageChain().message( + f"已达到顶层记忆最大条数({max_items}),请先使用 /ctxmem rm <序号> 或 /ctxmem clear。", + ) + ) + return + + max_chars = int(cm_cfg.get("pinned_max_chars_per_item", 400) or 400) + truncated = False + if len(content) > max_chars: + content = content[:max_chars] + truncated = True + + pinned.append(content) + self._save_config(cfg) + + msg = f"已添加顶层记忆 #{len(pinned)}。" + if truncated: + msg += f" 内容超过上限,已截断到 {max_chars} 字符。" + await event.send(MessageChain().message(msg)) + + async def rm(self, event: AstrMessageEvent, index: int) -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + if not isinstance(pinned, list) or not pinned: + await event.send(MessageChain().message("当前没有可删除的顶层记忆。")) + return + + if index < 1 or index > len(pinned): + await event.send( + MessageChain().message(f"序号超出范围。请输入 1~{len(pinned)}。") + ) + return + + removed = str(pinned.pop(index - 1)) + self._save_config(cfg) + preview = removed if len(removed) <= 80 else removed[:80] + "..." 
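+        # Cap the echoed preview at 80 characters so the confirmation stays short.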
+ await event.send(MessageChain().message(f"已删除顶层记忆 #{index}: {preview}")) + + async def clear(self, event: AstrMessageEvent) -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + pinned = cm_cfg.get("pinned_memories", []) + count = len(pinned) if isinstance(pinned, list) else 0 + cm_cfg["pinned_memories"] = [] + self._save_config(cfg) + await event.send(MessageChain().message(f"已清空顶层记忆,共 {count} 条。")) + + async def enable(self, event: AstrMessageEvent, value: str = "") -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + enabled = bool(cm_cfg.get("enabled", False)) + + value = str(value or "").strip() + if value: + parsed = self._parse_switch(value) + if parsed is None: + await event.send( + MessageChain().message("参数错误。用法: /ctxmem enable [on|off]") + ) + return + enabled = parsed + else: + enabled = not enabled + + cm_cfg["enabled"] = enabled + self._save_config(cfg) + await event.send( + MessageChain().message( + "上下文记忆注入已" + ("开启。" if enabled else "关闭。") + ) + ) + + async def retrieval(self, event: AstrMessageEvent, value: str = "") -> None: + cfg, provider_settings = self._get_provider_settings(event) + cm_cfg = ensure_context_memory_settings(provider_settings) + enabled = bool(cm_cfg.get("retrieval_enabled", False)) + + value = str(value or "").strip() + if value: + parsed = self._parse_switch(value) + if parsed is None: + await event.send( + MessageChain().message("参数错误。用法: /ctxmem retrieval [on|off]") + ) + return + enabled = parsed + else: + enabled = not enabled + + cm_cfg["retrieval_enabled"] = enabled + self._save_config(cfg) + await event.send( + MessageChain().message( + "检索增强开关(开发中)已" + ("开启。" if enabled else "关闭。") + ) + ) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index d2d7fbc1b2..f2939f7540 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -1,10 +1,12 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, filter +from astrbot.core.star.filter.command import GreedyStr from .commands import ( AdminCommands, AlterCmdCommands, ContextCompactionCommands, + ContextMemoryCommands, ConversationCommands, HelpCommand, LLMCommands, @@ -28,6 +30,7 @@ def __init__(self, context: star.Context) -> None: self.admin_c = AdminCommands(self.context) self.conversation_c = ConversationCommands(self.context) self.ctxcompact_c = ContextCompactionCommands(self.context) + self.ctxmem_c = ContextMemoryCommands(self.context) self.provider_c = ProviderCommands(self.context) self.persona_c = PersonaCommands(self.context) self.alter_cmd_c = AlterCmdCommands(self.context) @@ -150,6 +153,53 @@ async def ctxcompact_run( """手动触发一次上下文压缩(可选 limit 覆盖本次压缩会话数)""" await self.ctxcompact_c.run(event, limit) + @filter.permission_type(filter.PermissionType.ADMIN) + @filter.command_group("ctxmem") + def ctxmem(self) -> None: + """上下文记忆管理(手动顶层记忆)""" + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("status") + async def ctxmem_status(self, event: AstrMessageEvent) -> None: + """查看上下文记忆状态""" + await self.ctxmem_c.status(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("ls") + async def ctxmem_ls(self, event: AstrMessageEvent) -> None: + """查看手动顶层记忆列表""" + await self.ctxmem_c.ls(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + 
@ctxmem.command("add") + async def ctxmem_add(self, event: AstrMessageEvent, text: GreedyStr) -> None: + """添加一条手动顶层记忆。ctxmem add """ + await self.ctxmem_c.add(event, text) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("rm") + async def ctxmem_rm(self, event: AstrMessageEvent, index: int) -> None: + """删除一条手动顶层记忆。ctxmem rm """ + await self.ctxmem_c.rm(event, index) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("clear") + async def ctxmem_clear(self, event: AstrMessageEvent) -> None: + """清空手动顶层记忆""" + await self.ctxmem_c.clear(event) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("enable") + async def ctxmem_enable(self, event: AstrMessageEvent, value: str = "") -> None: + """开关上下文记忆注入。ctxmem enable [on|off]""" + await self.ctxmem_c.enable(event, value) + + @filter.permission_type(filter.PermissionType.ADMIN) + @ctxmem.command("retrieval") + async def ctxmem_retrieval(self, event: AstrMessageEvent, value: str = "") -> None: + """开关检索增强预留开关。ctxmem retrieval [on|off]""" + await self.ctxmem_c.retrieval(event, value) + @filter.command("reset") async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 3685b2a285..c338846bc5 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -50,6 +50,10 @@ TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, retrieve_knowledge_base, ) +from astrbot.core.context_memory import ( + build_pinned_memory_system_block, + load_context_memory_config, +) from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Reply from astrbot.core.persona_error_reply import ( @@ -621,6 +625,28 @@ def _append_system_reminders( req.extra_user_content_parts.append(TextPart(text=system_content)) +def _inject_context_memory( + req: ProviderRequest, + cfg: dict, +) -> None: + """Inject manually pinned top-level memories into system prompt. + + Vector retrieval enhancement is intentionally deferred to a follow-up PR. + This function only handles manually configured pinned memories. + """ + if not isinstance(cfg, dict): + return + cm_cfg = load_context_memory_config(cfg) + memory_block = build_pinned_memory_system_block(cm_cfg) + if not memory_block: + return + + if req.system_prompt and req.system_prompt.strip(): + req.system_prompt = f"{req.system_prompt}\n\n{memory_block}" + else: + req.system_prompt = memory_block + + async def _decorate_llm_request( event: AstrMessageEvent, req: ProviderRequest, @@ -659,6 +685,7 @@ async def _decorate_llm_request( if tz is None: tz = plugin_context.get_config().get("timezone") _append_system_reminders(event, req, cfg, tz) + _inject_context_memory(req, cfg) def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 9d77e46864..b918d3bed7 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -38,6 +38,19 @@ "dry_run": False, } +CONTEXT_MEMORY_DEFAULTS = { + "enabled": False, + "inject_pinned_memory": True, + "pinned_memories": [], + "pinned_max_items": 8, + "pinned_max_chars_per_item": 400, + # Reserved switches for follow-up PRs (manual opt-in only). 
+ "retrieval_enabled": False, + "retrieval_backend": "", + "retrieval_provider_id": "", + "retrieval_top_k": 5, +} + # 默认配置 DEFAULT_CONFIG = { "config_version": 2, @@ -119,6 +132,7 @@ "context_token_counter_mode": "estimate", "compact_context_after_tool_call": False, "periodic_context_compaction": dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS), + "context_memory": dict(CONTEXT_MEMORY_DEFAULTS), "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, @@ -2594,6 +2608,39 @@ class ChatProviderTemplate(TypedDict): }, }, }, + "context_memory": { + "type": "object", + "properties": { + "enabled": { + "type": "bool", + }, + "inject_pinned_memory": { + "type": "bool", + }, + "pinned_memories": { + "type": "list", + "items": {"type": "string"}, + }, + "pinned_max_items": { + "type": "int", + }, + "pinned_max_chars_per_item": { + "type": "int", + }, + "retrieval_enabled": { + "type": "bool", + }, + "retrieval_backend": { + "type": "string", + }, + "retrieval_provider_id": { + "type": "string", + }, + "retrieval_top_k": { + "type": "int", + }, + }, + }, "max_context_length": { "type": "int", }, @@ -3452,6 +3499,78 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.context_memory.enabled": { + "description": "启用上下文记忆注入", + "type": "bool", + "hint": "启用后可将手动维护的顶层记忆注入到 system prompt,并预留向量记忆检索接口。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.inject_pinned_memory": { + "description": "注入手动顶层记忆", + "type": "bool", + "hint": "将 `pinned_memories` 作为高优先级记忆注入系统提示词。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.pinned_max_items": { + "description": "顶层记忆最大条数", + "type": "int", + "hint": "通过管理命令添加手动顶层记忆时允许保留的最大条目数。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.pinned_max_chars_per_item": { + "description": "单条顶层记忆最大字符数", + "type": "int", + "hint": "超出长度的条目会被截断,避免 system prompt 膨胀。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_enabled": { + "description": "启用检索增强(开发中)", + "type": "bool", + "hint": "预留开关,默认关闭;向量检索增强建议在后续 PR 中实现。", + "condition": { + "provider_settings.context_memory.enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_backend": { + "description": "检索后端标识(预留)", + "type": "string", + "hint": "例如 zep/mem0/custom,当前版本仅用于配置预留。", + "condition": { + "provider_settings.context_memory.retrieval_enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_provider_id": { + "description": "检索重排模型提供商 ID(预留)", + "type": "string", + "_special": "select_provider", + "hint": "当前版本仅保留配置,不会触发额外检索调用。", + "condition": { + "provider_settings.context_memory.retrieval_enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_memory.retrieval_top_k": { + "description": "检索 Top-K(预留)", + "type": "int", + "hint": "后续检索增强功能默认使用的召回条数。", + "condition": { + "provider_settings.context_memory.retrieval_enabled": True, + "provider_settings.agent_runner_type": "local", + }, + }, }, "condition": { "provider_settings.agent_runner_type": 
"local", diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py new file mode 100644 index 0000000000..613f35a6c8 --- /dev/null +++ b/astrbot/core/context_memory.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +DEFAULT_CONTEXT_MEMORY_SETTINGS: dict[str, Any] = { + # Global switch for context-memory related features. + "enabled": False, + # Manually maintained top-level memories injected into system prompt. + "inject_pinned_memory": True, + "pinned_memories": [], + "pinned_max_items": 8, + "pinned_max_chars_per_item": 400, + # Retrieval enhancement is intentionally reserved for future PRs. + "retrieval_enabled": False, + "retrieval_backend": "", + "retrieval_provider_id": "", + "retrieval_top_k": 5, +} + + +@runtime_checkable +class VectorLongTermMemoryRetriever(Protocol): + """Reserved protocol for future vector-DB long-term memory retrieval.""" + + async def retrieve( + self, + *, + unified_msg_origin: str, + query: str, + top_k: int, + ) -> list[str]: + """Return ranked memory snippets for prompt assembly.""" + ... + + +@dataclass(frozen=True) +class ContextMemoryConfig: + enabled: bool = False + inject_pinned_memory: bool = True + pinned_memories: list[str] = field(default_factory=list) + pinned_max_items: int = 8 + pinned_max_chars_per_item: int = 400 + retrieval_enabled: bool = False + retrieval_backend: str = "" + retrieval_provider_id: str = "" + retrieval_top_k: int = 5 + + +def _to_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + +def _to_int(value: Any, default: int, min_value: int) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + return max(parsed, min_value) + + +def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, Any]: + normalized = dict(DEFAULT_CONTEXT_MEMORY_SETTINGS) + if not isinstance(raw, dict): + return normalized + + normalized["enabled"] = _to_bool( + raw.get("enabled"), + bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["enabled"]), + ) + normalized["inject_pinned_memory"] = _to_bool( + raw.get("inject_pinned_memory"), + bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["inject_pinned_memory"]), + ) + normalized["pinned_max_items"] = _to_int( + raw.get("pinned_max_items"), + int(DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_items"]), + 1, + ) + normalized["pinned_max_chars_per_item"] = _to_int( + raw.get("pinned_max_chars_per_item"), + int(DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_chars_per_item"]), + 1, + ) + normalized["retrieval_enabled"] = _to_bool( + raw.get("retrieval_enabled"), + bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_enabled"]), + ) + normalized["retrieval_backend"] = str(raw.get("retrieval_backend", "") or "").strip() + normalized["retrieval_provider_id"] = str( + raw.get("retrieval_provider_id", "") or "" + ).strip() + normalized["retrieval_top_k"] = _to_int( + raw.get("retrieval_top_k"), + int(DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_top_k"]), + 1, + ) + + pinned_max_items = int(normalized["pinned_max_items"]) + pinned_max_chars = int(normalized["pinned_max_chars_per_item"]) + pinned_raw = raw.get("pinned_memories", []) + pinned_memories: list[str] = [] + if isinstance(pinned_raw, list): + for 
item in pinned_raw:
+            text = str(item or "").strip()
+            if not text:
+                continue
+            if len(text) > pinned_max_chars:
+                text = text[:pinned_max_chars]
+            pinned_memories.append(text)
+            if len(pinned_memories) >= pinned_max_items:
+                break
+    normalized["pinned_memories"] = pinned_memories
+
+    return normalized
+
+
+def load_context_memory_config(provider_settings: dict[str, Any] | None) -> ContextMemoryConfig:
+    raw = None
+    if isinstance(provider_settings, dict):
+        raw = provider_settings.get("context_memory")
+    normalized = normalize_context_memory_settings(raw if isinstance(raw, dict) else None)
+    return ContextMemoryConfig(
+        enabled=bool(normalized["enabled"]),
+        inject_pinned_memory=bool(normalized["inject_pinned_memory"]),
+        pinned_memories=list(normalized["pinned_memories"]),
+        pinned_max_items=int(normalized["pinned_max_items"]),
+        pinned_max_chars_per_item=int(normalized["pinned_max_chars_per_item"]),
+        retrieval_enabled=bool(normalized["retrieval_enabled"]),
+        retrieval_backend=str(normalized["retrieval_backend"]),
+        retrieval_provider_id=str(normalized["retrieval_provider_id"]),
+        retrieval_top_k=int(normalized["retrieval_top_k"]),
+    )
+
+
+def ensure_context_memory_settings(provider_settings: dict[str, Any]) -> dict[str, Any]:
+    """Normalize and persist context_memory subtree in provider_settings."""
+    normalized = normalize_context_memory_settings(provider_settings.get("context_memory"))
+    provider_settings["context_memory"] = normalized
+    return normalized
+
+
+def build_pinned_memory_system_block(config: ContextMemoryConfig) -> str:
+    """Build system-prompt block for manually pinned top-level memories."""
+    if not config.enabled or not config.inject_pinned_memory:
+        return ""
+    if not config.pinned_memories:
+        return ""
+
+    lines = [
+        "<top_level_memory>",
+        "The following high-priority memory is manually configured and should be respected when relevant:",
+    ]
+    for idx, memory in enumerate(config.pinned_memories, start=1):
+        lines.append(f"{idx}. {memory}")
+    lines.append("</top_level_memory>")
+    return "\n".join(lines)
diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py
index 9a42abd733..b453e34745 100644
--- a/tests/unit/test_astr_main_agent.py
+++ b/tests/unit/test_astr_main_agent.py
@@ -663,6 +663,59 @@ async def test_decorate_llm_request_no_conversation(self, mock_event, mock_conte
         assert req.prompt == "Hello"
 
+    @pytest.mark.asyncio
+    async def test_decorate_llm_request_injects_pinned_context_memory(
+        self, mock_event, mock_context
+    ):
+        module = ama
+        req = ProviderRequest(prompt="Hello", system_prompt="System")
+        config = module.MainAgentBuildConfig(
+            tool_call_timeout=60,
+            provider_settings={
+                "context_memory": {
+                    "enabled": True,
+                    "inject_pinned_memory": True,
+                    "pinned_memories": [
+                        "用户喜欢先结论后细节。",
+                        "优先使用中文回复。",
+                    ],
+                    "pinned_max_items": 8,
+                    "pinned_max_chars_per_item": 200,
+                }
+            },
+        )
+
+        with patch.object(mock_context, "get_config") as mock_get_config:
+            mock_get_config.return_value = {}
+            await module._decorate_llm_request(mock_event, req, mock_context, config)
+
+        assert "<top_level_memory>" in req.system_prompt
+        assert "1. 用户喜欢先结论后细节。" in req.system_prompt
+        assert "2. 
优先使用中文回复。" in req.system_prompt + + @pytest.mark.asyncio + async def test_decorate_llm_request_skips_pinned_memory_when_disabled( + self, mock_event, mock_context + ): + module = ama + req = ProviderRequest(prompt="Hello", system_prompt="System") + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={ + "context_memory": { + "enabled": True, + "inject_pinned_memory": False, + "pinned_memories": ["这条不应被注入。"], + } + }, + ) + + with patch.object(mock_context, "get_config") as mock_get_config: + mock_get_config.return_value = {} + await module._decorate_llm_request(mock_event, req, mock_context, config) + + assert "" not in req.system_prompt + class TestModalitiesFix: """Tests for _modalities_fix function.""" diff --git a/tests/unit/test_context_memory_command.py b/tests/unit/test_context_memory_command.py new file mode 100644 index 0000000000..955e3f6905 --- /dev/null +++ b/tests/unit/test_context_memory_command.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from astrbot.builtin_stars.builtin_commands.commands.context_memory import ( + ContextMemoryCommands, +) + + +class DummyConfig(dict): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.save_calls = 0 + + def save_config(self) -> None: + self.save_calls += 1 + + +def _build_command_with_cfg(cfg: DummyConfig) -> ContextMemoryCommands: + context = SimpleNamespace(get_config=lambda umo=None: cfg) + return ContextMemoryCommands(context=context) + + +def _build_event() -> SimpleNamespace: + return SimpleNamespace(unified_msg_origin="umo:test", send=AsyncMock()) + + +@pytest.mark.asyncio +async def test_status_shows_defaults() -> None: + cfg = DummyConfig({"provider_settings": {}}) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.status(event) + + event.send.assert_awaited_once() + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert "上下文记忆状态" in text + assert "启用=否" in text + assert "顶层记忆条数=0" in text + + +@pytest.mark.asyncio +async def test_add_list_and_remove_memory() -> None: + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "enabled": True, + "inject_pinned_memory": True, + "pinned_memories": [], + "pinned_max_items": 4, + "pinned_max_chars_per_item": 120, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.add(event, "用户喜欢先给结论再解释。") + assert cfg.save_calls == 1 + stored = cfg["provider_settings"]["context_memory"]["pinned_memories"] + assert stored == ["用户喜欢先给结论再解释。"] + + await command.ls(event) + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert "手动顶层记忆列表" in text + assert "用户喜欢先给结论再解释" in text + + await command.rm(event, 1) + assert cfg["provider_settings"]["context_memory"]["pinned_memories"] == [] + + +@pytest.mark.asyncio +async def test_add_rejects_when_reaching_max_items() -> None: + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "pinned_memories": ["A"], + "pinned_max_items": 1, + "pinned_max_chars_per_item": 100, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.add(event, "B") + + # only normalization happened in-memory, no successful add/save + assert cfg["provider_settings"]["context_memory"]["pinned_memories"] == ["A"] + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + 
assert "已达到顶层记忆最大条数" in text + + +@pytest.mark.asyncio +async def test_enable_and_retrieval_toggles() -> None: + cfg = DummyConfig({"provider_settings": {"context_memory": {"enabled": False}}}) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.enable(event) + assert cfg["provider_settings"]["context_memory"]["enabled"] is True + + await command.retrieval(event, "on") + assert cfg["provider_settings"]["context_memory"]["retrieval_enabled"] is True + + +@pytest.mark.asyncio +async def test_add_truncates_long_memory_item() -> None: + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "pinned_memories": [], + "pinned_max_items": 3, + "pinned_max_chars_per_item": 10, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.add(event, "1234567890abcdef") + + stored = cfg["provider_settings"]["context_memory"]["pinned_memories"] + assert stored == ["1234567890"] + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert "已截断到 10 字符" in text From ed91180b9378f6052c45bbf5c8b2fa501083ce69 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 23:13:41 -0500 Subject: [PATCH 15/29] fix(review): harden token counter fallback and align compaction command/scheduler behavior --- .../commands/context_compaction.py | 9 ++- astrbot/core/agent/context/token_counter.py | 7 +- astrbot/core/context_compaction_scheduler.py | 65 ++++++++++++++++--- tests/agent/test_token_counter.py | 18 +++++ tests/unit/test_context_compaction_command.py | 2 +- .../unit/test_context_compaction_scheduler.py | 33 +++++++++- 6 files changed, 117 insertions(+), 17 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py index afb3a6b12f..d1c6c23d55 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py +++ b/astrbot/builtin_stars/builtin_commands/commands/context_compaction.py @@ -1,5 +1,6 @@ from astrbot.api import star from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.core import logger from astrbot.core.context_compaction_scheduler import PeriodicContextCompactionScheduler @@ -27,7 +28,10 @@ async def status(self, event: AstrMessageEvent) -> None: trigger_tokens = cfg.get("trigger_tokens", "?") trigger_ratio = cfg.get("trigger_min_context_ratio", "?") if isinstance(trigger_tokens, int) and trigger_tokens <= 0: - trigger_text = f"自动({trigger_ratio}x模型上下文)" + if isinstance(trigger_ratio, (int, float)): + trigger_text = f"自动({trigger_ratio}x模型上下文或目标长度估算)" + else: + trigger_text = "自动(基于目标长度估算)" else: trigger_text = str(trigger_tokens) @@ -87,7 +91,8 @@ async def run(self, event: AstrMessageEvent, limit: int | None = None) -> None: max_conversations_override=limit, ) except Exception as exc: - await event.send(MessageChain().message(f"触发压缩失败: {exc}")) + logger.error("ctxcompact run failed: %s", exc, exc_info=True) + await event.send(MessageChain().message("触发压缩失败,请查看服务端日志。")) return msg = ( diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index 2f28d77854..2baca2775e 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -112,7 +112,12 @@ def _init_encoder(self, model: str | None) -> None: else: encoding = tiktoken.get_encoding("cl100k_base") except Exception: - encoding = tiktoken.get_encoding("cl100k_base") + try: + encoding = 
tiktoken.get_encoding("cl100k_base") + except Exception: + self._available = False + self._encode = None + return self._available = True self._encode = lambda text: len(encoding.encode(text)) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index cc60018632..58c51d7e90 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -10,7 +10,11 @@ from astrbot import logger from astrbot.core.agent.context.config import ContextConfig from astrbot.core.agent.context.manager import ContextManager -from astrbot.core.agent.context.token_counter import EstimateTokenCounter +from astrbot.core.agent.context.token_counter import ( + EstimateTokenCounter, + TokenCounter, + create_token_counter, +) from astrbot.core.agent.message import Message from astrbot.core.agent.message_history_parser import MessageHistoryParser from astrbot.core.astrbot_config_mgr import AstrBotConfigManager @@ -167,7 +171,10 @@ def __init__( self.provider_manager = provider_manager self._stop_event = asyncio.Event() self._running_lock = asyncio.Lock() + # Default fallback counter. Actual counter is resolved by provider_settings + # (context_token_counter_mode) and provider model when available. self._token_counter = EstimateTokenCounter() + self._token_counter_cache: dict[tuple[str, str], TokenCounter] = {} self._history_parser = MessageHistoryParser() self._bootstrapped = False self._last_status = _RunStatus() @@ -397,15 +404,16 @@ async def _compact_one_conversation( conv: ConversationV2, cfg: CompactionConfig, ) -> str: - eligibility = self._check_eligibility(conv, cfg) - if eligibility is None: - return "skipped" - messages, before_tokens = eligibility - provider = await self._resolve_provider(cfg, conv.user_id) if not provider: return "failed" + token_counter = self._resolve_token_counter(provider) + eligibility = self._check_eligibility(conv, cfg, token_counter) + if eligibility is None: + return "skipped" + messages, before_tokens = eligibility + trigger_tokens = self._resolve_trigger_tokens(cfg, provider) if before_tokens < trigger_tokens: return "skipped" @@ -414,11 +422,12 @@ async def _compact_one_conversation( messages=messages, provider=provider, cfg=cfg, + token_counter=token_counter, ) if not round_result.changed: return "skipped" - after_tokens = self._token_counter.count_tokens(round_result.messages) + after_tokens = token_counter.count_tokens(round_result.messages) if after_tokens >= before_tokens: return "skipped" @@ -446,6 +455,7 @@ def _check_eligibility( self, conv: ConversationV2, cfg: CompactionConfig, + token_counter: TokenCounter, ) -> EligibilityInfo | None: history = conv.content if not isinstance(history, list) or len(history) < cfg.min_messages: @@ -459,7 +469,7 @@ def _check_eligibility( return None trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 - before_tokens = self._token_counter.count_tokens(messages, trusted_usage) + before_tokens = token_counter.count_tokens(messages, trusted_usage) return messages, before_tokens def _resolve_trigger_tokens(self, cfg: CompactionConfig, provider: Provider) -> int: @@ -501,15 +511,16 @@ async def _run_compaction_rounds( messages: list[Message], provider: Provider, cfg: CompactionConfig, + token_counter: TokenCounter, ) -> _RoundResult: compressed = messages changed = False rounds = 0 instruction = self._resolve_instruction(cfg) - manager = self._build_context_manager(cfg, provider, instruction) + manager = 
self._build_context_manager(cfg, provider, instruction, token_counter) for _ in range(cfg.max_rounds): - current_tokens = self._token_counter.count_tokens(compressed) + current_tokens = token_counter.count_tokens(compressed) if current_tokens <= cfg.target_tokens: break @@ -528,6 +539,7 @@ def _build_context_manager( cfg: CompactionConfig, provider: Provider, instruction: str, + token_counter: TokenCounter, ) -> ContextManager: return ContextManager( ContextConfig( @@ -537,6 +549,7 @@ def _build_context_manager( llm_compress_keep_recent=cfg.keep_recent, llm_compress_instruction=instruction, llm_compress_provider=provider, + custom_token_counter=token_counter, ) ) @@ -635,6 +648,38 @@ def _resolve_instruction(self, cfg: CompactionConfig) -> str: return base_instruction.strip() return "" + def _resolve_token_counter(self, provider: Provider | None) -> TokenCounter: + provider_settings = self.config_manager.default_conf.get("provider_settings", {}) + mode = "estimate" + if isinstance(provider_settings, dict): + mode = str(provider_settings.get("context_token_counter_mode", "estimate")) + mode = mode.strip().lower() or "estimate" + + model = "" + if provider is not None: + try: + model = str(provider.get_model() or "") + except Exception: + model = "" + cache_key = (mode, model) + cached = self._token_counter_cache.get(cache_key) + if cached is not None: + return cached + + try: + resolved = create_token_counter(mode=mode, model=model or None) + except Exception as exc: + logger.warning( + "[ContextCompact] failed to create token counter(mode=%s, model=%s), fallback to estimate: %s", + mode, + model or "-", + exc, + ) + resolved = self._token_counter + + self._token_counter_cache[cache_key] = resolved + return resolved + @staticmethod def _is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: if min_idle_minutes <= 0: diff --git a/tests/agent/test_token_counter.py b/tests/agent/test_token_counter.py index 9f73a2519a..ae65b917ad 100644 --- a/tests/agent/test_token_counter.py +++ b/tests/agent/test_token_counter.py @@ -1,5 +1,8 @@ """Tests for EstimateTokenCounter multimodal support.""" +import sys +from types import SimpleNamespace + from astrbot.core.agent.context.token_counter import ( AUDIO_TOKEN_ESTIMATE, IMAGE_TOKEN_ESTIMATE, @@ -116,3 +119,18 @@ def test_create_unknown_mode_fallback(self): def test_create_tokenizer_mode_returns_valid_counter_type(self): created = create_token_counter("tokenizer", model="gpt-4") assert isinstance(created, (TokenizerTokenCounter, EstimateTokenCounter)) + + def test_tokenizer_counter_gracefully_handles_broken_fallback_encoder( + self, monkeypatch + ): + fake_tiktoken = SimpleNamespace( + encoding_for_model=lambda _model: (_ for _ in ()).throw(RuntimeError("boom")), + get_encoding=lambda _name: (_ for _ in ()).throw(RuntimeError("boom")), + ) + monkeypatch.setitem(sys.modules, "tiktoken", fake_tiktoken) + + counter = TokenizerTokenCounter(model="gpt-4") + assert counter.available is False + + created = create_token_counter("tokenizer", model="gpt-4") + assert isinstance(created, EstimateTokenCounter) diff --git a/tests/unit/test_context_compaction_command.py b/tests/unit/test_context_compaction_command.py index 35459dc99f..f4d3e39e93 100644 --- a/tests/unit/test_context_compaction_command.py +++ b/tests/unit/test_context_compaction_command.py @@ -193,4 +193,4 @@ async def test_run_reports_error_when_scheduler_raises() -> None: ) chain = event.send.await_args.args[0] text = chain.get_plain_text(with_other_comps_mark=True) - assert 
text.startswith("触发压缩失败:") + assert text == "触发压缩失败,请查看服务端日志。" diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index 27d7a758b2..0d10106e53 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -275,6 +275,29 @@ def test_resolve_trigger_tokens_falls_back_when_provider_context_unknown() -> No assert resolved == 1536 +def test_resolve_token_counter_uses_configured_mode_and_provider_model(monkeypatch) -> None: + scheduler = _build_scheduler({"context_token_counter_mode": "auto"}) + provider = SimpleNamespace(get_model=lambda: "gpt-4o") + called: dict[str, str | None] = {} + + fake_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 0) + + def _fake_create(mode: str | None = None, *, model: str | None = None): + called["mode"] = mode + called["model"] = model + return fake_counter + + monkeypatch.setattr( + "astrbot.core.context_compaction_scheduler.create_token_counter", + _fake_create, + ) + + resolved = scheduler._resolve_token_counter(provider) + assert resolved is fake_counter + assert called["mode"] == "auto" + assert called["model"] == "gpt-4o" + + @pytest.mark.asyncio async def test_compact_one_conversation_dry_run_reports_skipped() -> None: scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}}) @@ -287,11 +310,13 @@ async def test_compact_one_conversation_dry_run_reports_skipped() -> None: token_usage=0, updated_at=None, ) - scheduler._check_eligibility = lambda _conv, _cfg: ( # type: ignore[method-assign] + scheduler._check_eligibility = lambda _conv, _cfg, _counter: ( # type: ignore[method-assign] [Message(role="user", content="before")], 100, ) - scheduler._resolve_provider = AsyncMock(return_value=object()) # type: ignore[method-assign] + scheduler._resolve_provider = AsyncMock( # type: ignore[method-assign] + return_value=SimpleNamespace(get_model=lambda: "gpt-4o") + ) scheduler._run_compaction_rounds = AsyncMock( # type: ignore[method-assign] return_value=SimpleNamespace( messages=[Message(role="user", content="after")], @@ -300,7 +325,9 @@ async def test_compact_one_conversation_dry_run_reports_skipped() -> None: ) ) scheduler._resolve_trigger_tokens = lambda _cfg, _provider: 1 # type: ignore[method-assign] - scheduler._token_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 50) + scheduler._resolve_token_counter = lambda _provider: SimpleNamespace( # type: ignore[method-assign] + count_tokens=lambda *_args, **_kwargs: 50 + ) scheduler._persist_compacted_history = AsyncMock( # type: ignore[method-assign] return_value=True ) From 022f6f6523bd0405f2408649d567efc99cd60ff3 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 23:21:16 -0500 Subject: [PATCH 16/29] feat(context): add post-tool compaction policy and prompt assembly router --- .../agent/runners/tool_loop_agent_runner.py | 108 +++++++++++++++++- astrbot/core/astr_main_agent.py | 38 ++++-- astrbot/core/config/default.py | 65 +++++++++++ .../method/agent_sub_stages/internal.py | 32 ++++++ astrbot/core/prompt_assembly_router.py | 77 +++++++++++++ tests/test_tool_loop_agent_runner.py | 63 ++++++++++ tests/unit/test_astr_main_agent.py | 41 +++++++ tests/unit/test_prompt_assembly_router.py | 23 ++++ 8 files changed, 436 insertions(+), 11 deletions(-) create mode 100644 astrbot/core/prompt_assembly_router.py create mode 100644 tests/unit/test_prompt_assembly_router.py diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py 
b/astrbot/core/agent/runners/tool_loop_agent_runner.py index e6964c57a2..2fe3f89d17 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -108,6 +108,12 @@ async def reset( token_counter_mode: str = "estimate", # run context compression immediately after tool execution compact_context_after_tool_call: bool = False, + # post-tool-call compaction policy + compact_context_soft_ratio: float = 0.3, + compact_context_hard_ratio: float = 0.7, + compact_context_min_delta_tokens: int = 0, + compact_context_min_delta_turns: int = 0, + compact_context_debounce_seconds: int = 0, # customize custom_token_counter: TokenCounter | None = None, custom_compressor: ContextCompressor | None = None, @@ -124,6 +130,25 @@ async def reset( self.truncate_turns = truncate_turns self.token_counter_mode = token_counter_mode self.compact_context_after_tool_call = compact_context_after_tool_call + self.compact_context_soft_ratio = self._normalize_ratio( + compact_context_soft_ratio, default=0.3 + ) + self.compact_context_hard_ratio = max( + self.compact_context_soft_ratio, + self._normalize_ratio(compact_context_hard_ratio, default=0.7), + ) + self.compact_context_min_delta_tokens = self._to_non_negative_int( + compact_context_min_delta_tokens + ) + self.compact_context_min_delta_turns = self._to_non_negative_int( + compact_context_min_delta_turns + ) + self.compact_context_debounce_seconds = self._to_non_negative_int( + compact_context_debounce_seconds + ) + self._last_tool_compaction_check_at = 0.0 + self._tool_compaction_baseline_tokens = 0 + self._tool_compaction_baseline_messages = 0 self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor # we will do compress when: @@ -203,10 +228,73 @@ async def reset( Message(role="system", content=request.system_prompt), ) self.run_context.messages = messages + self._refresh_tool_compaction_baseline( + trusted_token_usage=request.conversation.token_usage if request.conversation else 0 + ) self.stats = AgentStats() self.stats.start_time = time.time() + @staticmethod + def _normalize_ratio(value: float, *, default: float) -> float: + try: + ratio = float(value) + except Exception: + ratio = default + if ratio > 1.0 and ratio <= 100.0: + ratio = ratio / 100.0 + return min(max(ratio, 0.0), 1.0) + + @staticmethod + def _to_non_negative_int(value: int) -> int: + try: + return max(0, int(value)) + except Exception: + return 0 + + def _refresh_tool_compaction_baseline(self, *, trusted_token_usage: int = 0) -> None: + try: + self._tool_compaction_baseline_tokens = ( + self.context_manager.token_counter.count_tokens( + self.run_context.messages, + trusted_token_usage, + ) + ) + except Exception: + self._tool_compaction_baseline_tokens = 0 + self._tool_compaction_baseline_messages = len(self.run_context.messages) + + def _should_run_post_tool_compaction(self) -> bool: + if not self.compact_context_after_tool_call: + return False + max_context_tokens = int(self.context_config.max_context_tokens or 0) + if max_context_tokens <= 0: + # No explicit token budget configured: preserve legacy behavior. + return True + + current_tokens = self.context_manager.token_counter.count_tokens( + self.run_context.messages + ) + current_messages = len(self.run_context.messages) + current_ratio = current_tokens / max(1, max_context_tokens) + + # Hard threshold: force compaction immediately. 
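+        # e.g. with max_context_tokens=100, soft_ratio=0.3, hard_ratio=0.7: a
+        # 95-token context is compacted unconditionally here, a 35-token context
+        # falls through to the soft-zone growth checks below, and a 20-token
+        # context is never compacted.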
+ if current_ratio >= self.compact_context_hard_ratio: + return True + + # Soft threshold: only compact when context has grown enough. + if current_ratio < self.compact_context_soft_ratio: + return False + + delta_tokens = max(0, current_tokens - self._tool_compaction_baseline_tokens) + delta_messages = max(0, current_messages - self._tool_compaction_baseline_messages) + if ( + delta_tokens < self.compact_context_min_delta_tokens + and delta_messages < self.compact_context_min_delta_turns + ): + return False + return True + async def _iter_llm_responses( self, *, include_model: bool = True ) -> T.AsyncGenerator[LLMResponse, None]: @@ -377,6 +465,7 @@ async def step(self): self.run_context.messages = await self.context_manager.process( self.run_context.messages, trusted_token_usage=token_usage ) + self._refresh_tool_compaction_baseline(trusted_token_usage=token_usage) self._simple_print_message_role("[AftCompact]") async for llm_response in self._iter_llm_responses_with_fallback(): @@ -627,9 +716,22 @@ async def step(self): self.req.append_tool_calls_result(tool_calls_result) if self.compact_context_after_tool_call: - self.run_context.messages = await self.context_manager.process( - self.run_context.messages, - ) + now = time.monotonic() + if ( + self.compact_context_debounce_seconds > 0 + and self._last_tool_compaction_check_at > 0 + and ( + now - self._last_tool_compaction_check_at + < self.compact_context_debounce_seconds + ) + ): + pass + elif self._should_run_post_tool_compaction(): + self.run_context.messages = await self.context_manager.process( + self.run_context.messages, + ) + self._refresh_tool_compaction_baseline() + self._last_tool_compaction_check_at = now async def step_until_done( self, max_step: int diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index c338846bc5..199f2a0328 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -61,6 +61,7 @@ set_persona_custom_error_message_on_event, ) from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.prompt_assembly_router import assemble_system_prompt from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt @@ -127,6 +128,16 @@ class MainAgentBuildConfig: """Token counting mode for context compaction: estimate, tokenizer, auto.""" compact_context_after_tool_call: bool = False """Whether to run context compaction check immediately after tool execution.""" + compact_context_soft_ratio: float = 0.3 + """Soft trigger threshold for post-tool-call context compaction.""" + compact_context_hard_ratio: float = 0.7 + """Hard trigger threshold for post-tool-call context compaction.""" + compact_context_min_delta_tokens: int = 0 + """Minimum token growth required before post-tool-call compaction runs in soft zone.""" + compact_context_min_delta_turns: int = 0 + """Minimum message growth required before post-tool-call compaction runs in soft zone.""" + compact_context_debounce_seconds: int = 0 + """Debounce window for post-tool-call compaction checks.""" max_context_length: int = -1 """The maximum number of turns to keep in context. -1 means no limit. 
This enforce max turns before compression""" @@ -626,6 +637,7 @@ def _append_system_reminders( def _inject_context_memory( + event: AstrMessageEvent, req: ProviderRequest, cfg: dict, ) -> None: @@ -638,13 +650,18 @@ def _inject_context_memory( return cm_cfg = load_context_memory_config(cfg) memory_block = build_pinned_memory_system_block(cm_cfg) - if not memory_block: - return - - if req.system_prompt and req.system_prompt.strip(): - req.system_prompt = f"{req.system_prompt}\n\n{memory_block}" - else: - req.system_prompt = memory_block + retrieved_facts = event.get_extra("retrieved_long_term_facts") + summarized_history = event.get_extra("compacted_history_summary") + req.system_prompt = assemble_system_prompt( + base_system_prompt=req.system_prompt or "", + retrieved_long_term_facts=retrieved_facts + if isinstance(retrieved_facts, list) + else None, + summarized_history=summarized_history + if isinstance(summarized_history, str) + else "", + pinned_memory_block=memory_block, + ) async def _decorate_llm_request( @@ -685,7 +702,7 @@ async def _decorate_llm_request( if tz is None: tz = plugin_context.get_config().get("timezone") _append_system_reminders(event, req, cfg, tz) - _inject_context_memory(req, cfg) + _inject_context_memory(event, req, cfg) def _modalities_fix(provider: Provider, req: ProviderRequest) -> None: @@ -1236,6 +1253,11 @@ async def build_main_agent( llm_compress_provider=_get_compress_provider(config, plugin_context), token_counter_mode=config.context_token_counter_mode, compact_context_after_tool_call=config.compact_context_after_tool_call, + compact_context_soft_ratio=config.compact_context_soft_ratio, + compact_context_hard_ratio=config.compact_context_hard_ratio, + compact_context_min_delta_tokens=config.compact_context_min_delta_tokens, + compact_context_min_delta_turns=config.compact_context_min_delta_turns, + compact_context_debounce_seconds=config.compact_context_debounce_seconds, truncate_turns=config.dequeue_context_length, enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index b918d3bed7..3af2ddfc15 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -131,6 +131,11 @@ "llm_compress_provider_id": "", "context_token_counter_mode": "estimate", "compact_context_after_tool_call": False, + "compact_context_soft_ratio": 0.3, + "compact_context_hard_ratio": 0.7, + "compact_context_min_delta_tokens": 0, + "compact_context_min_delta_turns": 0, + "compact_context_debounce_seconds": 0, "periodic_context_compaction": dict(PERIODIC_CONTEXT_COMPACTION_DEFAULTS), "context_memory": dict(CONTEXT_MEMORY_DEFAULTS), "max_context_length": -1, @@ -2552,6 +2557,21 @@ class ChatProviderTemplate(TypedDict): "compact_context_after_tool_call": { "type": "bool", }, + "compact_context_soft_ratio": { + "type": "float", + }, + "compact_context_hard_ratio": { + "type": "float", + }, + "compact_context_min_delta_tokens": { + "type": "int", + }, + "compact_context_min_delta_turns": { + "type": "int", + }, + "compact_context_debounce_seconds": { + "type": "int", + }, "periodic_context_compaction": { "type": "object", "properties": { @@ -3346,6 +3366,51 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, + "provider_settings.compact_context_soft_ratio": { + "description": "工具后压缩软阈值", + "type": "float", + "hint": "当上下文占比达到该阈值时,按“最小增长量”规则决定是否压缩。支持填写 0~1 或 0~100(百分比)。", + "condition": { + 
"provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_hard_ratio": { + "description": "工具后压缩硬阈值", + "type": "float", + "hint": "当上下文占比达到该阈值时,强制执行一次压缩。支持填写 0~1 或 0~100(百分比)。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_min_delta_tokens": { + "description": "工具后最小 Token 增长", + "type": "int", + "hint": "在软阈值区间内,Token 增长低于该值时不触发压缩。0 表示不限制。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_min_delta_turns": { + "description": "工具后最小消息增长", + "type": "int", + "hint": "在软阈值区间内,消息增长低于该值时不触发压缩。0 表示不限制。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.compact_context_debounce_seconds": { + "description": "工具后压缩防抖(秒)", + "type": "int", + "hint": "两次工具后压缩检查的最小间隔秒数。0 表示关闭防抖。", + "condition": { + "provider_settings.compact_context_after_tool_call": True, + "provider_settings.agent_runner_type": "local", + }, + }, "provider_settings.periodic_context_compaction.enabled": { "description": "启用定时历史压缩", "type": "bool", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 614c39a06e..ca988ff01e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -98,6 +98,33 @@ async def initialize(self, ctx: PipelineContext) -> None: "compact_context_after_tool_call", False, ) + def _safe_float(value, default: float) -> float: + try: + return float(value) + except Exception: + return default + + def _safe_int(value, default: int) -> int: + try: + return int(value) + except Exception: + return default + + self.compact_context_soft_ratio: float = _safe_float( + settings.get("compact_context_soft_ratio", 0.3), 0.3 + ) + self.compact_context_hard_ratio: float = _safe_float( + settings.get("compact_context_hard_ratio", 0.7), 0.7 + ) + self.compact_context_min_delta_tokens: int = _safe_int( + settings.get("compact_context_min_delta_tokens", 0), 0 + ) + self.compact_context_min_delta_turns: int = _safe_int( + settings.get("compact_context_min_delta_turns", 0), 0 + ) + self.compact_context_debounce_seconds: int = _safe_int( + settings.get("compact_context_debounce_seconds", 0), 0 + ) self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -134,6 +161,11 @@ async def initialize(self, ctx: PipelineContext) -> None: llm_compress_provider_id=self.llm_compress_provider_id, context_token_counter_mode=self.context_token_counter_mode, compact_context_after_tool_call=self.compact_context_after_tool_call, + compact_context_soft_ratio=self.compact_context_soft_ratio, + compact_context_hard_ratio=self.compact_context_hard_ratio, + compact_context_min_delta_tokens=self.compact_context_min_delta_tokens, + compact_context_min_delta_turns=self.compact_context_min_delta_turns, + compact_context_debounce_seconds=self.compact_context_debounce_seconds, max_context_length=self.max_context_length, dequeue_context_length=self.dequeue_context_length, 
llm_safety_mode=self.llm_safety_mode,
diff --git a/astrbot/core/prompt_assembly_router.py b/astrbot/core/prompt_assembly_router.py
new file mode 100644
index 0000000000..5936996fb0
--- /dev/null
+++ b/astrbot/core/prompt_assembly_router.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+
+def _normalize_text_items(items: Iterable[object] | None) -> list[str]:
+    normalized: list[str] = []
+    if not items:
+        return normalized
+    for item in items:
+        text = str(item or "").strip()
+        if text:
+            normalized.append(text)
+    return normalized
+
+
+def _render_long_term_facts_block(facts: list[str]) -> str:
+    if not facts:
+        return ""
+    lines = [
+        "<long_term_facts>",
+        "Use these retrieved long-term facts when relevant:",
+    ]
+    for idx, fact in enumerate(facts, start=1):
+        lines.append(f"{idx}. {fact}")
+    lines.append("</long_term_facts>")
+    return "\n".join(lines)
+
+
+def _render_summarized_history_block(summary: str) -> str:
+    text = str(summary or "").strip()
+    if not text:
+        return ""
+    return "\n".join(
+        [
+            "<summarized_history>",
+            text,
+            "</summarized_history>",
+        ]
+    )
+
+
+def assemble_system_prompt(
+    *,
+    base_system_prompt: str,
+    retrieved_long_term_facts: Iterable[object] | None = None,
+    summarized_history: str = "",
+    pinned_memory_block: str = "",
+) -> str:
+    """Assemble final system prompt with stable section ordering.
+
+    Section order:
+    1) Base system prompt
+    2) Retrieved long-term facts
+    3) Summarized history
+    4) Pinned top-level memory
+    """
+    sections: list[str] = []
+    base = str(base_system_prompt or "").strip()
+    if base:
+        sections.append(base)
+
+    facts_block = _render_long_term_facts_block(
+        _normalize_text_items(retrieved_long_term_facts)
+    )
+    if facts_block:
+        sections.append(facts_block)
+
+    summary_block = _render_summarized_history_block(summarized_history)
+    if summary_block:
+        sections.append(summary_block)
+
+    pinned = str(pinned_memory_block or "").strip()
+    if pinned:
+        sections.append(pinned)
+
+    return "\n\n".join(sections).strip()
diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py
index ff6069b1d7..41bb0906a6 100644
--- a/tests/test_tool_loop_agent_runner.py
+++ b/tests/test_tool_loop_agent_runner.py
@@ -1,5 +1,6 @@
 import os
 import sys
+from types import SimpleNamespace
 from unittest.mock import AsyncMock
 
 import pytest
@@ -585,6 +586,68 @@ async def test_compact_context_after_tool_call_disabled_by_default(
     assert runner.context_manager.process.await_count == 1
 
 
+@pytest.mark.asyncio
+async def test_compact_context_after_tool_call_honors_debounce(
+    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
+):
+    mock_provider.should_call_tools = True
+    mock_provider.provider_config["max_context_tokens"] = 100
+    await runner.reset(
+        provider=mock_provider,
+        request=provider_request,
+        run_context=ContextWrapper(context=None),
+        tool_executor=mock_tool_executor,
+        agent_hooks=mock_hooks,
+        streaming=False,
+        compact_context_after_tool_call=True,
+        compact_context_soft_ratio=0.3,
+        compact_context_hard_ratio=0.4,
+        compact_context_debounce_seconds=3600,
+    )
+
+    runner.context_manager.token_counter = SimpleNamespace(
+        count_tokens=lambda *_args, **_kwargs: 90
+    )
+    runner.context_manager.process = AsyncMock(  # type: ignore[method-assign]
+        side_effect=lambda messages, trusted_token_usage=0: messages,
+    )
+
+    # step 1: pre-LLM compact + post-tool compact
+    async for _ in runner.step():
+        pass
+    # step 2: pre-LLM compact + post-tool compact skipped by debounce
+    async for _ in runner.step():
+        pass
+
+    assert 
runner.context_manager.process.await_count == 3 + + +def test_post_tool_compaction_soft_zone_respects_min_delta(runner): + # initialize attributes by calling reset in event loop would be overkill for this pure helper; + # emulate the required fields directly. + runner.compact_context_after_tool_call = True + runner.compact_context_soft_ratio = 0.3 + runner.compact_context_hard_ratio = 0.9 + runner.compact_context_min_delta_tokens = 10 + runner.compact_context_min_delta_turns = 10 + runner.context_config = SimpleNamespace(max_context_tokens=100) + runner.run_context = SimpleNamespace(messages=[object(), object()]) + runner.context_manager = SimpleNamespace( + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 35) + ) + runner._tool_compaction_baseline_tokens = 30 + runner._tool_compaction_baseline_messages = 2 + + # ratio=0.35 in soft zone, token delta=5 and message delta=0 -> should skip + assert runner._should_run_post_tool_compaction() is False + + runner.context_manager = SimpleNamespace( + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 95) + ) + # ratio=0.95 in hard zone -> force compaction + assert runner._should_run_post_tool_compaction() is True + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index b453e34745..40c393ad96 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -716,6 +716,47 @@ async def test_decorate_llm_request_skips_pinned_memory_when_disabled( assert "" not in req.system_prompt + @pytest.mark.asyncio + async def test_decorate_llm_request_assembles_prompt_layers_in_order( + self, mock_event, mock_context + ): + module = ama + req = ProviderRequest(prompt="Hello", system_prompt="System") + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={ + "context_memory": { + "enabled": True, + "inject_pinned_memory": True, + "pinned_memories": ["固定偏好A"], + } + }, + ) + + extras = { + "retrieved_long_term_facts": ["事实1", "事实2"], + "compacted_history_summary": "历史摘要S", + } + mock_event.get_extra.side_effect = lambda key: extras.get(key) + + with patch.object(mock_context, "get_config") as mock_get_config: + mock_get_config.return_value = {} + await module._decorate_llm_request(mock_event, req, mock_context, config) + + system_prompt = req.system_prompt + assert "System" in system_prompt + assert "" in system_prompt + assert "" in system_prompt + assert "" in system_prompt + # ordering: base -> long_term_facts -> summarized_history -> top_level_memory + assert system_prompt.index("System") < system_prompt.index("") + assert system_prompt.index("") < system_prompt.index( + "" + ) + assert system_prompt.index("") < system_prompt.index( + "" + ) + class TestModalitiesFix: """Tests for _modalities_fix function.""" diff --git a/tests/unit/test_prompt_assembly_router.py b/tests/unit/test_prompt_assembly_router.py new file mode 100644 index 0000000000..eaa46403d2 --- /dev/null +++ b/tests/unit/test_prompt_assembly_router.py @@ -0,0 +1,23 @@ +from astrbot.core.prompt_assembly_router import assemble_system_prompt + + +def test_assemble_system_prompt_omits_empty_sections() -> None: + prompt = assemble_system_prompt(base_system_prompt="Base") + assert prompt == "Base" + + +def test_assemble_system_prompt_orders_sections() -> None: + prompt = assemble_system_prompt( + base_system_prompt="Base", + retrieved_long_term_facts=["Fact A", "Fact B"], + summarized_history="Summary C", + 
pinned_memory_block="\nPinned D\n", + ) + + assert "Base" in prompt + assert "" in prompt + assert "" in prompt + assert "" in prompt + assert prompt.index("Base") < prompt.index("") + assert prompt.index("") < prompt.index("") + assert prompt.index("") < prompt.index("") From 1219993528d5238cd5715b8d6aba035d8f473ec7 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 23:27:23 -0500 Subject: [PATCH 17/29] fix latest review issues for context compaction --- .../commands/context_memory.py | 16 ++++++- astrbot/core/agent/context/token_counter.py | 7 ++- .../agent/runners/tool_loop_agent_runner.py | 9 ++-- tests/test_tool_loop_agent_runner.py | 19 ++++++++ tests/unit/test_context_memory_command.py | 46 +++++++++++++++++++ 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/context_memory.py b/astrbot/builtin_stars/builtin_commands/commands/context_memory.py index f200fe65f1..8bb923fdc2 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/context_memory.py +++ b/astrbot/builtin_stars/builtin_commands/commands/context_memory.py @@ -6,6 +6,8 @@ from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.core.context_memory import ensure_context_memory_settings +PINNED_PREVIEW_MAX_CHARS = 180 + class ContextMemoryCommands: def __init__(self, context: star.Context) -> None: @@ -69,11 +71,21 @@ async def ls(self, event: AstrMessageEvent) -> None: await event.send(MessageChain().message("当前没有手动顶层记忆。")) return + configured_max_chars = cm_cfg.get("pinned_max_chars_per_item", 400) + try: + configured_max_chars = int(configured_max_chars) + except Exception: + configured_max_chars = 400 + preview_max_chars = min( + max(1, configured_max_chars), + PINNED_PREVIEW_MAX_CHARS, + ) + lines = ["手动顶层记忆列表:"] for idx, text in enumerate(pinned, start=1): text_str = str(text) - if len(text_str) > 180: - text_str = text_str[:180] + "..." + if len(text_str) > preview_max_chars: + text_str = text_str[:preview_max_chars] + "..." lines.append(f"{idx}. {text_str}") await event.send(MessageChain().message("\n".join(lines))) diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index 2baca2775e..ca9f886a85 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -80,6 +80,9 @@ def _estimate_tokens(self, text: str) -> int: other_count = len(text) - chinese_count return int(chinese_count * 0.6 + other_count * 0.3) + def estimate_text_tokens(self, text: str) -> int: + return self._estimate_tokens(text) + class TokenizerTokenCounter: """Tokenizer-based token counter. @@ -159,11 +162,11 @@ def count_tokens( def _encode_len(self, text: str) -> int: if not self._encode: - return self._estimate._estimate_tokens(text) + return self._estimate.estimate_text_tokens(text) try: return self._encode(text) except Exception: - return self._estimate._estimate_tokens(text) + return self._estimate.estimate_text_tokens(text) def create_token_counter( diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 2fe3f89d17..b8ce1bf149 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -272,9 +272,12 @@ def _should_run_post_tool_compaction(self) -> bool: # No explicit token budget configured: preserve legacy behavior. 
return True - current_tokens = self.context_manager.token_counter.count_tokens( - self.run_context.messages - ) + try: + current_tokens = self.context_manager.token_counter.count_tokens( + self.run_context.messages + ) + except Exception: + return False current_messages = len(self.run_context.messages) current_ratio = current_tokens / max(1, max_context_tokens) diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 41bb0906a6..9085d7c56e 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -648,6 +648,25 @@ def test_post_tool_compaction_soft_zone_respects_min_delta(runner): assert runner._should_run_post_tool_compaction() is True +def test_post_tool_compaction_handles_token_counter_errors(runner): + runner.compact_context_after_tool_call = True + runner.compact_context_soft_ratio = 0.3 + runner.compact_context_hard_ratio = 0.9 + runner.compact_context_min_delta_tokens = 10 + runner.compact_context_min_delta_turns = 10 + runner.context_config = SimpleNamespace(max_context_tokens=100) + runner.run_context = SimpleNamespace(messages=[object(), object()]) + + def _raise(*_args, **_kwargs): + raise RuntimeError("counter broken") + + runner.context_manager = SimpleNamespace( + token_counter=SimpleNamespace(count_tokens=_raise) + ) + + assert runner._should_run_post_tool_compaction() is False + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_context_memory_command.py b/tests/unit/test_context_memory_command.py index 955e3f6905..8dfe85aa51 100644 --- a/tests/unit/test_context_memory_command.py +++ b/tests/unit/test_context_memory_command.py @@ -134,3 +134,49 @@ async def test_add_truncates_long_memory_item() -> None: assert stored == ["1234567890"] text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) assert "已截断到 10 字符" in text + + +@pytest.mark.asyncio +async def test_ls_preview_length_uses_config_limit() -> None: + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "pinned_memories": ["abcdefghijklmno"], + "pinned_max_items": 3, + "pinned_max_chars_per_item": 10, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.ls(event) + + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert "1. abcdefghij" in text + assert "1. abcdefghijk" not in text + + +@pytest.mark.asyncio +async def test_ls_preview_length_is_capped_by_display_limit() -> None: + long_text = "x" * 250 + cfg = DummyConfig( + { + "provider_settings": { + "context_memory": { + "pinned_memories": [long_text], + "pinned_max_items": 3, + "pinned_max_chars_per_item": 999, + } + } + } + ) + command = _build_command_with_cfg(cfg) + event = _build_event() + + await command.ls(event) + + text = event.send.await_args.args[0].get_plain_text(with_other_comps_mark=True) + assert ("1. 
" + ("x" * 180) + "...") in text From 667ead22b40aea64502cfe624f63db05db540bb2 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Thu, 19 Mar 2026 23:36:41 -0500 Subject: [PATCH 18/29] refactor post-tool compaction policy and centralize parsing --- .../agent/runners/tool_loop_agent_runner.py | 202 ++++++++++-------- astrbot/core/context_compaction_scheduler.py | 63 ++---- astrbot/core/context_memory.py | 37 +--- .../method/agent_sub_stages/internal.py | 38 ++-- astrbot/core/utils/config_normalization.py | 41 ++++ tests/test_tool_loop_agent_runner.py | 46 ++-- 6 files changed, 221 insertions(+), 206 deletions(-) create mode 100644 astrbot/core/utils/config_normalization.py diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index b8ce1bf149..9879d2b878 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -32,6 +32,7 @@ ToolCallsResult, ) from astrbot.core.provider.provider import Provider +from astrbot.core.utils.config_normalization import to_non_negative_int, to_ratio from ..context.compressor import ContextCompressor from ..context.config import ContextConfig @@ -80,6 +81,86 @@ class FollowUpTicket: resolved: asyncio.Event = field(default_factory=asyncio.Event) +@dataclass(slots=True, frozen=True) +class PostToolCompactionConfig: + enabled: bool = False + soft_ratio: float = 0.3 + hard_ratio: float = 0.7 + min_delta_tokens: int = 0 + min_delta_turns: int = 0 + debounce_seconds: int = 0 + + +class PostToolCompactionController: + def __init__(self, config: PostToolCompactionConfig) -> None: + self.config = config + self._baseline_tokens = 0 + self._baseline_messages = 0 + self._last_check_at = 0.0 + + def refresh_baseline( + self, + *, + messages: list[Message], + token_counter: TokenCounter, + trusted_token_usage: int = 0, + ) -> None: + try: + self._baseline_tokens = token_counter.count_tokens( + messages, + trusted_token_usage, + ) + except Exception: + self._baseline_tokens = 0 + self._baseline_messages = len(messages) + + def should_compact( + self, + *, + messages: list[Message], + token_counter: TokenCounter, + max_context_tokens: int, + ) -> bool: + if not self.config.enabled: + return False + + now = time.monotonic() + if ( + self.config.debounce_seconds > 0 + and self._last_check_at > 0 + and (now - self._last_check_at) < self.config.debounce_seconds + ): + self._last_check_at = now + return False + self._last_check_at = now + + if max_context_tokens <= 0: + # No explicit token budget configured: preserve legacy behavior. 
+ return True + + try: + current_tokens = token_counter.count_tokens(messages) + except Exception: + return False + + current_messages = len(messages) + current_ratio = current_tokens / max(1, max_context_tokens) + + if current_ratio >= self.config.hard_ratio: + return True + if current_ratio < self.config.soft_ratio: + return False + + delta_tokens = max(0, current_tokens - self._baseline_tokens) + delta_messages = max(0, current_messages - self._baseline_messages) + if ( + delta_tokens < self.config.min_delta_tokens + and delta_messages < self.config.min_delta_turns + ): + return False + return True + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): def _get_persona_custom_error_message(self) -> str | None: """Read persona-level custom error message from event extras when available.""" @@ -129,26 +210,20 @@ async def reset( self.llm_compress_provider = llm_compress_provider self.truncate_turns = truncate_turns self.token_counter_mode = token_counter_mode - self.compact_context_after_tool_call = compact_context_after_tool_call - self.compact_context_soft_ratio = self._normalize_ratio( - compact_context_soft_ratio, default=0.3 - ) - self.compact_context_hard_ratio = max( - self.compact_context_soft_ratio, - self._normalize_ratio(compact_context_hard_ratio, default=0.7), - ) - self.compact_context_min_delta_tokens = self._to_non_negative_int( - compact_context_min_delta_tokens + post_tool_soft_ratio = to_ratio(compact_context_soft_ratio, 0.3) + self.post_tool_compaction = PostToolCompactionConfig( + enabled=bool(compact_context_after_tool_call), + soft_ratio=post_tool_soft_ratio, + hard_ratio=max(post_tool_soft_ratio, to_ratio(compact_context_hard_ratio, 0.7)), + min_delta_tokens=to_non_negative_int(compact_context_min_delta_tokens), + min_delta_turns=to_non_negative_int(compact_context_min_delta_turns), + debounce_seconds=to_non_negative_int( + compact_context_debounce_seconds + ), ) - self.compact_context_min_delta_turns = self._to_non_negative_int( - compact_context_min_delta_turns + self.post_tool_compaction_controller = PostToolCompactionController( + self.post_tool_compaction ) - self.compact_context_debounce_seconds = self._to_non_negative_int( - compact_context_debounce_seconds - ) - self._last_tool_compaction_check_at = 0.0 - self._tool_compaction_baseline_tokens = 0 - self._tool_compaction_baseline_messages = 0 self.custom_token_counter = custom_token_counter self.custom_compressor = custom_compressor # we will do compress when: @@ -235,68 +310,21 @@ async def reset( self.stats = AgentStats() self.stats.start_time = time.time() - @staticmethod - def _normalize_ratio(value: float, *, default: float) -> float: - try: - ratio = float(value) - except Exception: - ratio = default - if ratio > 1.0 and ratio <= 100.0: - ratio = ratio / 100.0 - return min(max(ratio, 0.0), 1.0) - - @staticmethod - def _to_non_negative_int(value: int) -> int: - try: - return max(0, int(value)) - except Exception: - return 0 - def _refresh_tool_compaction_baseline(self, *, trusted_token_usage: int = 0) -> None: - try: - self._tool_compaction_baseline_tokens = ( - self.context_manager.token_counter.count_tokens( - self.run_context.messages, - trusted_token_usage, - ) - ) - except Exception: - self._tool_compaction_baseline_tokens = 0 - self._tool_compaction_baseline_messages = len(self.run_context.messages) + self.post_tool_compaction_controller.refresh_baseline( + messages=self.run_context.messages, + token_counter=self.context_manager.token_counter, + trusted_token_usage=trusted_token_usage, + ) def 
_should_run_post_tool_compaction(self) -> bool: - if not self.compact_context_after_tool_call: - return False - max_context_tokens = int(self.context_config.max_context_tokens or 0) - if max_context_tokens <= 0: - # No explicit token budget configured: preserve legacy behavior. - return True - - try: - current_tokens = self.context_manager.token_counter.count_tokens( - self.run_context.messages - ) - except Exception: - return False - current_messages = len(self.run_context.messages) - current_ratio = current_tokens / max(1, max_context_tokens) - - # Hard threshold: force compaction immediately. - if current_ratio >= self.compact_context_hard_ratio: - return True - - # Soft threshold: only compact when context has grown enough. - if current_ratio < self.compact_context_soft_ratio: + if not hasattr(self, "post_tool_compaction_controller"): return False - - delta_tokens = max(0, current_tokens - self._tool_compaction_baseline_tokens) - delta_messages = max(0, current_messages - self._tool_compaction_baseline_messages) - if ( - delta_tokens < self.compact_context_min_delta_tokens - and delta_messages < self.compact_context_min_delta_turns - ): - return False - return True + return self.post_tool_compaction_controller.should_compact( + messages=self.run_context.messages, + token_counter=self.context_manager.token_counter, + max_context_tokens=int(self.context_config.max_context_tokens or 0), + ) async def _iter_llm_responses( self, *, include_model: bool = True @@ -718,23 +746,11 @@ async def step(self): self.req.append_tool_calls_result(tool_calls_result) - if self.compact_context_after_tool_call: - now = time.monotonic() - if ( - self.compact_context_debounce_seconds > 0 - and self._last_tool_compaction_check_at > 0 - and ( - now - self._last_tool_compaction_check_at - < self.compact_context_debounce_seconds - ) - ): - pass - elif self._should_run_post_tool_compaction(): - self.run_context.messages = await self.context_manager.process( - self.run_context.messages, - ) - self._refresh_tool_compaction_baseline() - self._last_tool_compaction_check_at = now + if self._should_run_post_tool_compaction(): + self.run_context.messages = await self.context_manager.process( + self.run_context.messages, + ) + self._refresh_tool_compaction_baseline() async def step_until_done( self, max_step: int diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index 58c51d7e90..d1cbc6fe3c 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -23,6 +23,7 @@ from astrbot.core.db.po import ConversationV2 from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.provider import Provider +from astrbot.core.utils.config_normalization import to_bool, to_int, to_ratio from astrbot.core.utils.llm_metadata import LLM_METADATAS if TYPE_CHECKING: @@ -89,69 +90,37 @@ def from_default_conf( cfg = dict(defaults) cfg.update(raw_cfg) - target_tokens = cls._to_int(cfg.get("target_tokens"), 4096, 512) - trigger_tokens = cls._to_int(cfg.get("trigger_tokens"), 0, 0) - trigger_min_context_ratio = cls._to_ratio( + target_tokens = to_int(cfg.get("target_tokens"), 4096, 512) + trigger_tokens = to_int(cfg.get("trigger_tokens"), 0, 0) + trigger_min_context_ratio = to_ratio( cfg.get("trigger_min_context_ratio"), 0.3, ) return cls( - enabled=cls._to_bool(cfg.get("enabled"), False), - interval_minutes=cls._to_int(cfg.get("interval_minutes"), 30, 1), - 
startup_delay_seconds=cls._to_int(cfg.get("startup_delay_seconds"), 120, 0), - max_conversations_per_run=cls._to_int( + enabled=to_bool(cfg.get("enabled"), False), + interval_minutes=to_int(cfg.get("interval_minutes"), 30, 1), + startup_delay_seconds=to_int(cfg.get("startup_delay_seconds"), 120, 0), + max_conversations_per_run=to_int( cfg.get("max_conversations_per_run"), 8, 1, ), - max_scan_per_run=cls._to_int(cfg.get("max_scan_per_run"), 120, 1), - scan_page_size=cls._to_int(cfg.get("scan_page_size"), 40, 10), - min_idle_minutes=cls._to_int(cfg.get("min_idle_minutes"), 15, 0), - min_messages=cls._to_int(cfg.get("min_messages"), 14, 2), + max_scan_per_run=to_int(cfg.get("max_scan_per_run"), 120, 1), + scan_page_size=to_int(cfg.get("scan_page_size"), 40, 10), + min_idle_minutes=to_int(cfg.get("min_idle_minutes"), 15, 0), + min_messages=to_int(cfg.get("min_messages"), 14, 2), target_tokens=target_tokens, trigger_tokens=trigger_tokens, trigger_min_context_ratio=trigger_min_context_ratio, - max_rounds=cls._to_int(cfg.get("max_rounds"), 3, 1), - truncate_turns=cls._to_int(cfg.get("truncate_turns"), 1, 1), - keep_recent=cls._to_int(cfg.get("keep_recent"), 6, 0), + max_rounds=to_int(cfg.get("max_rounds"), 3, 1), + truncate_turns=to_int(cfg.get("truncate_turns"), 1, 1), + keep_recent=to_int(cfg.get("keep_recent"), 6, 0), provider_id=str(cfg.get("provider_id", "") or "").strip(), instruction=str(cfg.get("instruction", "") or "").strip(), - dry_run=cls._to_bool(cfg.get("dry_run"), False), + dry_run=to_bool(cfg.get("dry_run"), False), ) - @staticmethod - def _to_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {"1", "true", "yes", "on"}: - return True - if lowered in {"0", "false", "no", "off"}: - return False - return default - - @staticmethod - def _to_int(value: Any, default: int, min_value: int) -> int: - try: - parsed = int(value) - except Exception: - parsed = default - return max(parsed, min_value) - - @staticmethod - def _to_ratio(value: Any, default: float) -> float: - try: - parsed = float(value) - except Exception: - parsed = default - if parsed > 1.0 and parsed <= 100.0: - parsed = parsed / 100.0 - return min(max(parsed, 0.0), 1.0) - class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py index 613f35a6c8..1a390d84c4 100644 --- a/astrbot/core/context_memory.py +++ b/astrbot/core/context_memory.py @@ -3,6 +3,8 @@ from dataclasses import dataclass, field from typing import Any, Protocol, runtime_checkable +from astrbot.core.utils.config_normalization import to_bool, to_int + DEFAULT_CONTEXT_MEMORY_SETTINGS: dict[str, Any] = { # Global switch for context-memory related features. 
"enabled": False, @@ -46,53 +48,30 @@ class ContextMemoryConfig: retrieval_provider_id: str = "" retrieval_top_k: int = 5 - -def _to_bool(value: Any, default: bool) -> bool: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {"1", "true", "yes", "on"}: - return True - if lowered in {"0", "false", "no", "off"}: - return False - return default - - -def _to_int(value: Any, default: int, min_value: int) -> int: - try: - parsed = int(value) - except Exception: - parsed = default - return max(parsed, min_value) - - def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, Any]: normalized = dict(DEFAULT_CONTEXT_MEMORY_SETTINGS) if not isinstance(raw, dict): return normalized - normalized["enabled"] = _to_bool( + normalized["enabled"] = to_bool( raw.get("enabled"), bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["enabled"]), ) - normalized["inject_pinned_memory"] = _to_bool( + normalized["inject_pinned_memory"] = to_bool( raw.get("inject_pinned_memory"), bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["inject_pinned_memory"]), ) - normalized["pinned_max_items"] = _to_int( + normalized["pinned_max_items"] = to_int( raw.get("pinned_max_items"), int(DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_items"]), 1, ) - normalized["pinned_max_chars_per_item"] = _to_int( + normalized["pinned_max_chars_per_item"] = to_int( raw.get("pinned_max_chars_per_item"), int(DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_chars_per_item"]), 1, ) - normalized["retrieval_enabled"] = _to_bool( + normalized["retrieval_enabled"] = to_bool( raw.get("retrieval_enabled"), bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_enabled"]), ) @@ -100,7 +79,7 @@ def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, A normalized["retrieval_provider_id"] = str( raw.get("retrieval_provider_id", "") or "" ).strip() - normalized["retrieval_top_k"] = _to_int( + normalized["retrieval_top_k"] = to_int( raw.get("retrieval_top_k"), int(DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_top_k"]), 1, diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index ca988ff01e..7b1d41eef6 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -29,6 +29,7 @@ ProviderRequest, ) from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.config_normalization import to_non_negative_int, to_ratio from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager @@ -98,32 +99,25 @@ async def initialize(self, ctx: PipelineContext) -> None: "compact_context_after_tool_call", False, ) - def _safe_float(value, default: float) -> float: - try: - return float(value) - except Exception: - return default - - def _safe_int(value, default: int) -> int: - try: - return int(value) - except Exception: - return default - - self.compact_context_soft_ratio: float = _safe_float( - settings.get("compact_context_soft_ratio", 0.3), 0.3 + self.compact_context_soft_ratio: float = to_ratio( + settings.get("compact_context_soft_ratio", 0.3), + 0.3, ) - self.compact_context_hard_ratio: float = _safe_float( - settings.get("compact_context_hard_ratio", 0.7), 0.7 + self.compact_context_hard_ratio: float = to_ratio( + settings.get("compact_context_hard_ratio", 0.7), + 0.7, ) - 
self.compact_context_min_delta_tokens: int = _safe_int( - settings.get("compact_context_min_delta_tokens", 0), 0 + self.compact_context_min_delta_tokens: int = to_non_negative_int( + settings.get("compact_context_min_delta_tokens", 0), + 0, ) - self.compact_context_min_delta_turns: int = _safe_int( - settings.get("compact_context_min_delta_turns", 0), 0 + self.compact_context_min_delta_turns: int = to_non_negative_int( + settings.get("compact_context_min_delta_turns", 0), + 0, ) - self.compact_context_debounce_seconds: int = _safe_int( - settings.get("compact_context_debounce_seconds", 0), 0 + self.compact_context_debounce_seconds: int = to_non_negative_int( + settings.get("compact_context_debounce_seconds", 0), + 0, ) self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( diff --git a/astrbot/core/utils/config_normalization.py b/astrbot/core/utils/config_normalization.py new file mode 100644 index 0000000000..84faa19719 --- /dev/null +++ b/astrbot/core/utils/config_normalization.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Any + + +def to_bool(value: Any, default: bool) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + return True + if lowered in {"0", "false", "no", "off"}: + return False + return default + + +def to_int(value: Any, default: int, min_value: int | None = None) -> int: + try: + parsed = int(value) + except Exception: + parsed = default + if min_value is not None: + parsed = max(parsed, min_value) + return parsed + + +def to_non_negative_int(value: Any, default: int = 0) -> int: + return max(0, to_int(value, default)) + + +def to_ratio(value: Any, default: float) -> float: + try: + parsed = float(value) + except Exception: + parsed = default + if parsed > 1.0 and parsed <= 100.0: + parsed = parsed / 100.0 + return min(max(parsed, 0.0), 1.0) diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index 9085d7c56e..ad5c0c0c69 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -10,7 +10,11 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner +from astrbot.core.agent.runners.tool_loop_agent_runner import ( + PostToolCompactionConfig, + PostToolCompactionController, + ToolLoopAgentRunner, +) from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage from astrbot.core.provider.provider import Provider @@ -623,20 +627,26 @@ async def test_compact_context_after_tool_call_honors_debounce( def test_post_tool_compaction_soft_zone_respects_min_delta(runner): - # initialize attributes by calling reset in event loop would be overkill for this pure helper; - # emulate the required fields directly. 
- runner.compact_context_after_tool_call = True - runner.compact_context_soft_ratio = 0.3 - runner.compact_context_hard_ratio = 0.9 - runner.compact_context_min_delta_tokens = 10 - runner.compact_context_min_delta_turns = 10 + runner.post_tool_compaction = PostToolCompactionConfig( + enabled=True, + soft_ratio=0.3, + hard_ratio=0.9, + min_delta_tokens=10, + min_delta_turns=10, + debounce_seconds=0, + ) + runner.post_tool_compaction_controller = PostToolCompactionController( + runner.post_tool_compaction + ) runner.context_config = SimpleNamespace(max_context_tokens=100) runner.run_context = SimpleNamespace(messages=[object(), object()]) runner.context_manager = SimpleNamespace( token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 35) ) - runner._tool_compaction_baseline_tokens = 30 - runner._tool_compaction_baseline_messages = 2 + runner.post_tool_compaction_controller.refresh_baseline( + messages=runner.run_context.messages, + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 30), + ) # ratio=0.35 in soft zone, token delta=5 and message delta=0 -> should skip assert runner._should_run_post_tool_compaction() is False @@ -649,11 +659,17 @@ def test_post_tool_compaction_soft_zone_respects_min_delta(runner): def test_post_tool_compaction_handles_token_counter_errors(runner): - runner.compact_context_after_tool_call = True - runner.compact_context_soft_ratio = 0.3 - runner.compact_context_hard_ratio = 0.9 - runner.compact_context_min_delta_tokens = 10 - runner.compact_context_min_delta_turns = 10 + runner.post_tool_compaction = PostToolCompactionConfig( + enabled=True, + soft_ratio=0.3, + hard_ratio=0.9, + min_delta_tokens=10, + min_delta_turns=10, + debounce_seconds=0, + ) + runner.post_tool_compaction_controller = PostToolCompactionController( + runner.post_tool_compaction + ) runner.context_config = SimpleNamespace(max_context_tokens=100) runner.run_context = SimpleNamespace(messages=[object(), object()]) From f6175f26ab2f45a92bcf5616356e768d818ed6fe Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Fri, 20 Mar 2026 00:05:53 -0500 Subject: [PATCH 19/29] refactor scheduler compaction policy extraction --- astrbot/core/context_compaction_scheduler.py | 144 +++++++++--------- .../unit/test_context_compaction_scheduler.py | 33 ++-- 2 files changed, 96 insertions(+), 81 deletions(-) diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py index d1cbc6fe3c..5b8666dab8 100644 --- a/astrbot/core/context_compaction_scheduler.py +++ b/astrbot/core/context_compaction_scheduler.py @@ -122,6 +122,78 @@ def from_default_conf( ) +@dataclass(frozen=True) +class CompactionPolicy: + cfg: CompactionConfig + token_counter: TokenCounter + + def check_eligibility( + self, + conv: ConversationV2, + history_parser: MessageHistoryParser, + ) -> EligibilityInfo | None: + history = conv.content + if not isinstance(history, list) or len(history) < self.cfg.min_messages: + return None + + if not self.is_idle_enough(conv.updated_at, self.cfg.min_idle_minutes): + return None + + messages = history_parser.parse(history) + if len(messages) < self.cfg.min_messages: + return None + + trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 + before_tokens = self.token_counter.count_tokens(messages, trusted_usage) + return messages, before_tokens + + def resolve_trigger_tokens(self, provider: Provider) -> int: + if self.cfg.trigger_tokens > 0: + return self.cfg.trigger_tokens + + max_context_tokens = 
self.resolve_provider_max_context(provider) + if max_context_tokens > 0: + return max(1, int(max_context_tokens * self.cfg.trigger_min_context_ratio)) + + return max(int(self.cfg.target_tokens * 1.5), self.cfg.target_tokens + 1) + + @staticmethod + def resolve_provider_max_context(provider: Provider) -> int: + configured = provider.provider_config.get("max_context_tokens", 0) + try: + configured_tokens = int(configured) + except Exception: + configured_tokens = 0 + if configured_tokens > 0: + return configured_tokens + + model = provider.get_model() + model_info = LLM_METADATAS.get(model) + if not isinstance(model_info, dict): + return 0 + limit = model_info.get("limit") + if not isinstance(limit, dict): + return 0 + context = limit.get("context") + try: + context_tokens = int(context) + except Exception: + context_tokens = 0 + return max(context_tokens, 0) + + @staticmethod + def is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: + if min_idle_minutes <= 0: + return True + if updated_at is None: + return True + now = datetime.now(timezone.utc) + at = updated_at + if at.tzinfo is None: + at = at.replace(tzinfo=timezone.utc) + return (now - at).total_seconds() >= (min_idle_minutes * 60) + + class PeriodicContextCompactionScheduler: """Periodically compact conversation history and persist summarized history back to DB. @@ -378,12 +450,13 @@ async def _compact_one_conversation( return "failed" token_counter = self._resolve_token_counter(provider) - eligibility = self._check_eligibility(conv, cfg, token_counter) + policy = CompactionPolicy(cfg=cfg, token_counter=token_counter) + eligibility = policy.check_eligibility(conv, self._history_parser) if eligibility is None: return "skipped" messages, before_tokens = eligibility - trigger_tokens = self._resolve_trigger_tokens(cfg, provider) + trigger_tokens = policy.resolve_trigger_tokens(provider) if before_tokens < trigger_tokens: return "skipped" @@ -420,61 +493,6 @@ async def _compact_one_conversation( ) return "compacted" - def _check_eligibility( - self, - conv: ConversationV2, - cfg: CompactionConfig, - token_counter: TokenCounter, - ) -> EligibilityInfo | None: - history = conv.content - if not isinstance(history, list) or len(history) < cfg.min_messages: - return None - - if not self._is_idle_enough(conv.updated_at, cfg.min_idle_minutes): - return None - - messages = self._history_parser.parse(history) - if len(messages) < cfg.min_messages: - return None - - trusted_usage = conv.token_usage if isinstance(conv.token_usage, int) else 0 - before_tokens = token_counter.count_tokens(messages, trusted_usage) - return messages, before_tokens - - def _resolve_trigger_tokens(self, cfg: CompactionConfig, provider: Provider) -> int: - if cfg.trigger_tokens > 0: - return cfg.trigger_tokens - - max_context_tokens = self._resolve_provider_max_context(provider) - if max_context_tokens > 0: - return max(1, int(max_context_tokens * cfg.trigger_min_context_ratio)) - - return max(int(cfg.target_tokens * 1.5), cfg.target_tokens + 1) - - @staticmethod - def _resolve_provider_max_context(provider: Provider) -> int: - configured = provider.provider_config.get("max_context_tokens", 0) - try: - configured_tokens = int(configured) - except Exception: - configured_tokens = 0 - if configured_tokens > 0: - return configured_tokens - - model = provider.get_model() - model_info = LLM_METADATAS.get(model) - if not isinstance(model_info, dict): - return 0 - limit = model_info.get("limit") - if not isinstance(limit, dict): - return 0 - context = 
limit.get("context") - try: - context_tokens = int(context) - except Exception: - context_tokens = 0 - return max(context_tokens, 0) - async def _run_compaction_rounds( self, messages: list[Message], @@ -649,18 +667,6 @@ def _resolve_token_counter(self, provider: Provider | None) -> TokenCounter: self._token_counter_cache[cache_key] = resolved return resolved - @staticmethod - def _is_idle_enough(updated_at: datetime | None, min_idle_minutes: int) -> bool: - if min_idle_minutes <= 0: - return True - if updated_at is None: - return True - now = datetime.now(timezone.utc) - at = updated_at - if at.tzinfo is None: - at = at.replace(tzinfo=timezone.utc) - return (now - at).total_seconds() >= (min_idle_minutes * 60) - @staticmethod def _messages_equal(a: list[Message], b: list[Message]) -> bool: if len(a) != len(b): diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py index 0d10106e53..ecd8e2f521 100644 --- a/tests/unit/test_context_compaction_scheduler.py +++ b/tests/unit/test_context_compaction_scheduler.py @@ -11,6 +11,7 @@ from astrbot.core.config.default import PERIODIC_CONTEXT_COMPACTION_DEFAULTS from astrbot.core.context_compaction_scheduler import ( CompactionConfig, + CompactionPolicy, PeriodicContextCompactionScheduler, ) @@ -200,9 +201,9 @@ def test_is_idle_enough_respects_threshold() -> None: old = now - timedelta(minutes=30) recent = now - timedelta(minutes=2) - assert PeriodicContextCompactionScheduler._is_idle_enough(old, 10) is True - assert PeriodicContextCompactionScheduler._is_idle_enough(recent, 10) is False - assert PeriodicContextCompactionScheduler._is_idle_enough(None, 10) is True + assert CompactionPolicy.is_idle_enough(old, 10) is True + assert CompactionPolicy.is_idle_enough(recent, 10) is False + assert CompactionPolicy.is_idle_enough(None, 10) is True def test_resolve_run_limits_treats_max_scan_as_upper_bound() -> None: @@ -232,12 +233,13 @@ def test_resolve_trigger_tokens_prefers_manual_value() -> None: trigger_tokens=1500, trigger_min_context_ratio=0.3, ) + policy = CompactionPolicy(cfg=cfg, token_counter=SimpleNamespace()) provider = SimpleNamespace( provider_config={"max_context_tokens": 32768}, get_model=lambda: "unknown-model", ) - resolved = scheduler._resolve_trigger_tokens(cfg, provider) + resolved = policy.resolve_trigger_tokens(provider) assert resolved == 1500 @@ -249,12 +251,13 @@ def test_resolve_trigger_tokens_uses_ratio_when_auto_mode() -> None: trigger_tokens=0, trigger_min_context_ratio=0.3, ) + policy = CompactionPolicy(cfg=cfg, token_counter=SimpleNamespace()) provider = SimpleNamespace( provider_config={"max_context_tokens": 32768}, get_model=lambda: "unknown-model", ) - resolved = scheduler._resolve_trigger_tokens(cfg, provider) + resolved = policy.resolve_trigger_tokens(provider) assert resolved == 9830 @@ -266,12 +269,13 @@ def test_resolve_trigger_tokens_falls_back_when_provider_context_unknown() -> No trigger_tokens=0, trigger_min_context_ratio=0.3, ) + policy = CompactionPolicy(cfg=cfg, token_counter=SimpleNamespace()) provider = SimpleNamespace( provider_config={"max_context_tokens": 0}, get_model=lambda: "unknown-model", ) - resolved = scheduler._resolve_trigger_tokens(cfg, provider) + resolved = policy.resolve_trigger_tokens(provider) assert resolved == 1536 @@ -299,7 +303,7 @@ def _fake_create(mode: str | None = None, *, model: str | None = None): @pytest.mark.asyncio -async def test_compact_one_conversation_dry_run_reports_skipped() -> None: +async def 
test_compact_one_conversation_dry_run_reports_skipped(monkeypatch) -> None: scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}}) cfg = replace(scheduler._load_config(), dry_run=True) @@ -310,10 +314,6 @@ async def test_compact_one_conversation_dry_run_reports_skipped() -> None: token_usage=0, updated_at=None, ) - scheduler._check_eligibility = lambda _conv, _cfg, _counter: ( # type: ignore[method-assign] - [Message(role="user", content="before")], - 100, - ) scheduler._resolve_provider = AsyncMock( # type: ignore[method-assign] return_value=SimpleNamespace(get_model=lambda: "gpt-4o") ) @@ -324,13 +324,22 @@ async def test_compact_one_conversation_dry_run_reports_skipped() -> None: rounds=1, ) ) - scheduler._resolve_trigger_tokens = lambda _cfg, _provider: 1 # type: ignore[method-assign] scheduler._resolve_token_counter = lambda _provider: SimpleNamespace( # type: ignore[method-assign] count_tokens=lambda *_args, **_kwargs: 50 ) scheduler._persist_compacted_history = AsyncMock( # type: ignore[method-assign] return_value=True ) + monkeypatch.setattr( + CompactionPolicy, + "check_eligibility", + lambda self, _conv, _parser: ([Message(role="user", content="before")], 100), + ) + monkeypatch.setattr( + CompactionPolicy, + "resolve_trigger_tokens", + lambda self, _provider: 1, + ) outcome = await scheduler._compact_one_conversation(conv, cfg) From 464506778c40b41e01bc084646a46ebc54788cbb Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Fri, 20 Mar 2026 00:16:46 -0500 Subject: [PATCH 20/29] fix context memory defaults and reserve migration interfaces --- astrbot/core/context_memory.py | 77 ++++++++++++++++++++++++++++++- tests/unit/test_context_memory.py | 36 +++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_context_memory.py diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py index 1a390d84c4..9f6aaafd20 100644 --- a/astrbot/core/context_memory.py +++ b/astrbot/core/context_memory.py @@ -10,7 +10,6 @@ "enabled": False, # Manually maintained top-level memories injected into system prompt. "inject_pinned_memory": True, - "pinned_memories": [], "pinned_max_items": 8, "pinned_max_chars_per_item": 400, # Retrieval enhancement is intentionally reserved for future PRs. @@ -36,6 +35,79 @@ async def retrieve( ... +@runtime_checkable +class ContextMemoryEvolutionBackend(Protocol): + """Reserved protocol for MemEvolve-style memory evolution backend integration.""" + + async def evolve( + self, + *, + unified_msg_origin: str, + turns: list[str], + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Evolve short-term conversation turns into durable memory artifacts.""" + ... + + async def retrieve( + self, + *, + unified_msg_origin: str, + query: str, + top_k: int, + ) -> list[str]: + """Retrieve evolved memory snippets for prompt assembly.""" + ... + + +@runtime_checkable +class ContextMemoryMigrationAdapter(Protocol): + """Reserved protocol for future context-memory schema/store migration.""" + + async def export_session( + self, + *, + unified_msg_origin: str, + ) -> dict[str, Any]: + """Export memory payload for migration or backup.""" + ... + + async def import_session( + self, + *, + unified_msg_origin: str, + payload: dict[str, Any], + ) -> None: + """Import migrated memory payload into target backend.""" + ... 
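+
+
+# Illustrative sketch, not part of this patch: a minimal in-memory backend
+# satisfying ContextMemoryEvolutionBackend (e.g. for unit tests). The class
+# name and storage layout are hypothetical.
+class _InMemoryEvolutionBackend:
+    def __init__(self) -> None:
+        self._store: dict[str, list[str]] = {}
+
+    async def evolve(
+        self,
+        *,
+        unified_msg_origin: str,
+        turns: list[str],
+        metadata: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        # Append raw turns; a real backend would distill them into durable facts.
+        self._store.setdefault(unified_msg_origin, []).extend(turns)
+        return {"stored": len(turns)}
+
+    async def retrieve(
+        self,
+        *,
+        unified_msg_origin: str,
+        query: str,
+        top_k: int,
+    ) -> list[str]:
+        # No ranking here; a real backend would score snippets against `query`.
+        return self._store.get(unified_msg_origin, [])[:top_k]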
+ + +_context_memory_evolution_backend: ContextMemoryEvolutionBackend | None = None +_context_memory_migration_adapter: ContextMemoryMigrationAdapter | None = None + + +def set_context_memory_evolution_backend( + backend: ContextMemoryEvolutionBackend | None, +) -> None: + global _context_memory_evolution_backend + _context_memory_evolution_backend = backend + + +def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None: + return _context_memory_evolution_backend + + +def set_context_memory_migration_adapter( + adapter: ContextMemoryMigrationAdapter | None, +) -> None: + global _context_memory_migration_adapter + _context_memory_migration_adapter = adapter + + +def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None: + return _context_memory_migration_adapter + + @dataclass(frozen=True) class ContextMemoryConfig: enabled: bool = False @@ -48,8 +120,11 @@ class ContextMemoryConfig: retrieval_provider_id: str = "" retrieval_top_k: int = 5 + def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, Any]: normalized = dict(DEFAULT_CONTEXT_MEMORY_SETTINGS) + # Always initialize pinned_memories explicitly to avoid sharing mutable defaults. + normalized["pinned_memories"] = [] if not isinstance(raw, dict): return normalized diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py new file mode 100644 index 0000000000..87f20e2e4a --- /dev/null +++ b/tests/unit/test_context_memory.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from astrbot.core.context_memory import ( + get_context_memory_evolution_backend, + get_context_memory_migration_adapter, + normalize_context_memory_settings, + set_context_memory_evolution_backend, + set_context_memory_migration_adapter, +) + + +def test_normalize_context_memory_settings_initializes_fresh_pinned_memories() -> None: + first = normalize_context_memory_settings(None) + second = normalize_context_memory_settings(None) + + first["pinned_memories"].append("A") + + assert first["pinned_memories"] == ["A"] + assert second["pinned_memories"] == [] + + +def test_context_memory_reserved_backend_registration() -> None: + backend = object() + adapter = object() + + set_context_memory_evolution_backend(backend) # type: ignore[arg-type] + set_context_memory_migration_adapter(adapter) # type: ignore[arg-type] + + assert get_context_memory_evolution_backend() is backend + assert get_context_memory_migration_adapter() is adapter + + set_context_memory_evolution_backend(None) + set_context_memory_migration_adapter(None) + + assert get_context_memory_evolution_backend() is None + assert get_context_memory_migration_adapter() is None From 8e54277761eeedb54e75a61c4d0a445f3f51adb8 Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Fri, 20 Mar 2026 00:26:48 -0500 Subject: [PATCH 21/29] fix review issues on context defaults and debounce behavior --- .../agent/runners/tool_loop_agent_runner.py | 1 - astrbot/core/context_memory.py | 140 +++++------------- astrbot/core/context_memory_backends.py | 92 ++++++++++++ tests/test_tool_loop_agent_runner.py | 62 ++++++++ tests/unit/test_context_memory.py | 31 ++++ 5 files changed, 222 insertions(+), 104 deletions(-) create mode 100644 astrbot/core/context_memory_backends.py diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 9879d2b878..6a50974a78 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ 
b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -130,7 +130,6 @@ def should_compact( and self._last_check_at > 0 and (now - self._last_check_at) < self.config.debounce_seconds ): - self._last_check_at = now return False self._last_check_at = now diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py index 9f6aaafd20..4e4c0c3276 100644 --- a/astrbot/core/context_memory.py +++ b/astrbot/core/context_memory.py @@ -1,111 +1,45 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Protocol, runtime_checkable - +from typing import Any + +from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS +from astrbot.core.context_memory_backends import ( + ContextMemoryEvolutionBackend, + ContextMemoryMigrationAdapter, + VectorLongTermMemoryRetriever, + get_context_memory_evolution_backend, + get_context_memory_migration_adapter, + set_context_memory_evolution_backend, + set_context_memory_migration_adapter, +) from astrbot.core.utils.config_normalization import to_bool, to_int -DEFAULT_CONTEXT_MEMORY_SETTINGS: dict[str, Any] = { - # Global switch for context-memory related features. - "enabled": False, - # Manually maintained top-level memories injected into system prompt. - "inject_pinned_memory": True, - "pinned_max_items": 8, - "pinned_max_chars_per_item": 400, - # Retrieval enhancement is intentionally reserved for future PRs. - "retrieval_enabled": False, - "retrieval_backend": "", - "retrieval_provider_id": "", - "retrieval_top_k": 5, -} - - -@runtime_checkable -class VectorLongTermMemoryRetriever(Protocol): - """Reserved protocol for future vector-DB long-term memory retrieval.""" - - async def retrieve( - self, - *, - unified_msg_origin: str, - query: str, - top_k: int, - ) -> list[str]: - """Return ranked memory snippets for prompt assembly.""" - ... - - -@runtime_checkable -class ContextMemoryEvolutionBackend(Protocol): - """Reserved protocol for MemEvolve-style memory evolution backend integration.""" - - async def evolve( - self, - *, - unified_msg_origin: str, - turns: list[str], - metadata: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Evolve short-term conversation turns into durable memory artifacts.""" - ... - - async def retrieve( - self, - *, - unified_msg_origin: str, - query: str, - top_k: int, - ) -> list[str]: - """Retrieve evolved memory snippets for prompt assembly.""" - ... - - -@runtime_checkable -class ContextMemoryMigrationAdapter(Protocol): - """Reserved protocol for future context-memory schema/store migration.""" - - async def export_session( - self, - *, - unified_msg_origin: str, - ) -> dict[str, Any]: - """Export memory payload for migration or backup.""" - ... - - async def import_session( - self, - *, - unified_msg_origin: str, - payload: dict[str, Any], - ) -> None: - """Import migrated memory payload into target backend.""" - ... 
- - -_context_memory_evolution_backend: ContextMemoryEvolutionBackend | None = None -_context_memory_migration_adapter: ContextMemoryMigrationAdapter | None = None - - -def set_context_memory_evolution_backend( - backend: ContextMemoryEvolutionBackend | None, -) -> None: - global _context_memory_evolution_backend - _context_memory_evolution_backend = backend - - -def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None: - return _context_memory_evolution_backend - - -def set_context_memory_migration_adapter( - adapter: ContextMemoryMigrationAdapter | None, -) -> None: - global _context_memory_migration_adapter - _context_memory_migration_adapter = adapter - - -def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None: - return _context_memory_migration_adapter + +def _clone_context_memory_defaults() -> dict[str, Any]: + defaults = dict(CONTEXT_MEMORY_DEFAULTS) + pinned = defaults.get("pinned_memories") + defaults["pinned_memories"] = list(pinned) if isinstance(pinned, list) else [] + return defaults + + +DEFAULT_CONTEXT_MEMORY_SETTINGS: dict[str, Any] = _clone_context_memory_defaults() + +__all__ = [ + "ContextMemoryConfig", + "DEFAULT_CONTEXT_MEMORY_SETTINGS", + "normalize_context_memory_settings", + "load_context_memory_config", + "ensure_context_memory_settings", + "build_pinned_memory_system_block", + "VectorLongTermMemoryRetriever", + "ContextMemoryEvolutionBackend", + "ContextMemoryMigrationAdapter", + "set_context_memory_evolution_backend", + "get_context_memory_evolution_backend", + "set_context_memory_migration_adapter", + "get_context_memory_migration_adapter", +] @dataclass(frozen=True) diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py new file mode 100644 index 0000000000..225bfd720b --- /dev/null +++ b/astrbot/core/context_memory_backends.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class VectorLongTermMemoryRetriever(Protocol): + """Reserved protocol for future vector-DB long-term memory retrieval.""" + + async def retrieve( + self, + *, + unified_msg_origin: str, + query: str, + top_k: int, + ) -> list[str]: + """Return ranked memory snippets for prompt assembly.""" + ... + + +@runtime_checkable +class ContextMemoryEvolutionBackend(Protocol): + """Reserved protocol for MemEvolve-style memory evolution backend integration.""" + + async def evolve( + self, + *, + unified_msg_origin: str, + turns: list[str], + metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Evolve short-term conversation turns into durable memory artifacts.""" + ... + + async def retrieve( + self, + *, + unified_msg_origin: str, + query: str, + top_k: int, + ) -> list[str]: + """Retrieve evolved memory snippets for prompt assembly.""" + ... + + +@runtime_checkable +class ContextMemoryMigrationAdapter(Protocol): + """Reserved protocol for future context-memory schema/store migration.""" + + async def export_session( + self, + *, + unified_msg_origin: str, + ) -> dict[str, Any]: + """Export memory payload for migration or backup.""" + ... + + async def import_session( + self, + *, + unified_msg_origin: str, + payload: dict[str, Any], + ) -> None: + """Import migrated memory payload into target backend.""" + ... 
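+
+
+# Illustrative sketch, not part of this patch: a no-op adapter satisfying
+# ContextMemoryMigrationAdapter; the payload shape shown is hypothetical.
+class _NoopMigrationAdapter:
+    async def export_session(self, *, unified_msg_origin: str) -> dict[str, Any]:
+        # Export an empty payload keyed by the session origin.
+        return {"unified_msg_origin": unified_msg_origin, "items": []}
+
+    async def import_session(
+        self,
+        *,
+        unified_msg_origin: str,
+        payload: dict[str, Any],
+    ) -> None:
+        # A real adapter would write payload["items"] into the target store.
+        return None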
+ + +_context_memory_evolution_backend: ContextMemoryEvolutionBackend | None = None +_context_memory_migration_adapter: ContextMemoryMigrationAdapter | None = None + + +def set_context_memory_evolution_backend( + backend: ContextMemoryEvolutionBackend | None, +) -> None: + global _context_memory_evolution_backend + _context_memory_evolution_backend = backend + + +def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None: + return _context_memory_evolution_backend + + +def set_context_memory_migration_adapter( + adapter: ContextMemoryMigrationAdapter | None, +) -> None: + global _context_memory_migration_adapter + _context_memory_migration_adapter = adapter + + +def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None: + return _context_memory_migration_adapter + diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index ad5c0c0c69..e124941a94 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -683,6 +683,68 @@ def _raise(*_args, **_kwargs): assert runner._should_run_post_tool_compaction() is False +def test_post_tool_compaction_debounce_is_not_extended(monkeypatch): + config = PostToolCompactionConfig( + enabled=True, + soft_ratio=0.3, + hard_ratio=0.9, + min_delta_tokens=0, + min_delta_turns=0, + debounce_seconds=100, + ) + controller = PostToolCompactionController(config) + messages = [object()] + token_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 95) + + # refresh baseline before checks + controller.refresh_baseline( + messages=messages, + token_counter=SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 30), + ) + + ts = iter([1.0, 10.0, 20.0, 105.0]) + monkeypatch.setattr( + "astrbot.core.agent.runners.tool_loop_agent_runner.time.monotonic", + lambda: next(ts), + ) + + # first check performs decision and sets baseline timestamp + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is True + ) + # next two checks are debounced + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is False + ) + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is False + ) + # should become eligible at t=105 if debounce anchor remains at first real check (t=1) + assert ( + controller.should_compact( + messages=messages, + token_counter=token_counter, + max_context_tokens=100, + ) + is True + ) + + if __name__ == "__main__": # 运行测试 pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py index 87f20e2e4a..377b03778e 100644 --- a/tests/unit/test_context_memory.py +++ b/tests/unit/test_context_memory.py @@ -1,6 +1,8 @@ from __future__ import annotations +from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS from astrbot.core.context_memory import ( + DEFAULT_CONTEXT_MEMORY_SETTINGS, get_context_memory_evolution_backend, get_context_memory_migration_adapter, normalize_context_memory_settings, @@ -34,3 +36,32 @@ def test_context_memory_reserved_backend_registration() -> None: assert get_context_memory_evolution_backend() is None assert get_context_memory_migration_adapter() is None + + +def test_context_memory_defaults_follow_single_source() -> None: + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["enabled"] == CONTEXT_MEMORY_DEFAULTS["enabled"] + assert 
DEFAULT_CONTEXT_MEMORY_SETTINGS["inject_pinned_memory"] == CONTEXT_MEMORY_DEFAULTS[ + "inject_pinned_memory" + ] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_items"] == CONTEXT_MEMORY_DEFAULTS[ + "pinned_max_items" + ] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS[ + "pinned_max_chars_per_item" + ] == CONTEXT_MEMORY_DEFAULTS["pinned_max_chars_per_item"] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_enabled"] == CONTEXT_MEMORY_DEFAULTS[ + "retrieval_enabled" + ] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_backend"] == CONTEXT_MEMORY_DEFAULTS[ + "retrieval_backend" + ] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS[ + "retrieval_provider_id" + ] == CONTEXT_MEMORY_DEFAULTS["retrieval_provider_id"] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_top_k"] == CONTEXT_MEMORY_DEFAULTS[ + "retrieval_top_k" + ] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_memories"] == [] + assert DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_memories"] is not CONTEXT_MEMORY_DEFAULTS[ + "pinned_memories" + ] From 2bb335a8f1ce0c714ce7f88849f442e42bb32b4d Mon Sep 17 00:00:00 2001 From: Jacobinwwey Date: Fri, 20 Mar 2026 00:36:08 -0500 Subject: [PATCH 22/29] refactor context compaction metadata and expand token counter tests --- astrbot/core/config/default.py | 346 +++++++----------- astrbot/core/context_memory.py | 2 +- astrbot/core/context_memory_backends.py | 118 ++---- .../context_memory_experimental_backends.py | 92 +++++ tests/agent/test_token_counter.py | 93 +++++ 5 files changed, 352 insertions(+), 299 deletions(-) create mode 100644 astrbot/core/context_memory_experimental_backends.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 3af2ddfc15..5df43e3967 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -38,6 +38,144 @@ "dry_run": False, } +PERIODIC_CONTEXT_COMPACTION_FIELD_META: dict[str, dict[str, Any]] = { + "enabled": { + "schema_type": "bool", + "ui_type": "bool", + "description": "启用定时历史压缩", + "hint": "后台定时扫描会话历史,使用 LLM 摘要旧消息并回写对话历史,实现多轮 compact context。", + }, + "interval_minutes": { + "schema_type": "int", + "ui_type": "int", + "description": "定时间隔(分钟)", + "hint": "每隔多少分钟执行一次压缩扫描。", + }, + "startup_delay_seconds": { + "schema_type": "int", + "ui_type": "int", + "description": "启动延迟(秒)", + "hint": "AstrBot 启动后,等待指定秒数再执行首次压缩任务。", + }, + "max_conversations_per_run": { + "schema_type": "int", + "ui_type": "int", + "description": "单次最多压缩会话数", + "hint": "每次任务最多实际压缩多少个会话。", + }, + "max_scan_per_run": { + "schema_type": "int", + "ui_type": "int", + "description": "单次最多扫描会话数", + "hint": "每次任务最多扫描多少会话(包括被跳过的会话)。", + }, + "scan_page_size": { + "schema_type": "int", + "ui_type": "int", + "description": "分页扫描大小", + "hint": "扫描 conversations 表时每页读取条数。", + }, + "min_idle_minutes": { + "schema_type": "int", + "ui_type": "int", + "description": "最小静默时长(分钟)", + "hint": "会话最近更新时间小于该值时跳过,避免压缩活跃会话。", + }, + "min_messages": { + "schema_type": "int", + "ui_type": "int", + "description": "最小消息条数", + "hint": "少于该消息条数的会话不参与压缩。", + }, + "target_tokens": { + "schema_type": "int", + "ui_type": "int", + "description": "目标 Token 阈值", + "hint": "压缩目标上下文大小(token 估算值)。", + }, + "trigger_tokens": { + "schema_type": "int", + "ui_type": "int", + "description": "触发 Token 阈值", + "hint": "会话估算 token 超过此值才触发压缩。<=0 表示自动按模型最大上下文比例计算。", + }, + "trigger_min_context_ratio": { + "schema_type": "float", + "ui_type": "float", + "description": "自动触发比例", + "hint": "当触发 Token 阈值 <= 0 时生效。默认 0.3(即模型最大上下文的 30%)。支持填写 0~1 或 0~100(百分比)。", + }, + "max_rounds": { + "schema_type": 
"int", + "ui_type": "int", + "description": "每会话最大压缩轮数", + "hint": "单个会话一次任务内最多执行几轮摘要压缩(实现 multiple compact context)。", + }, + "truncate_turns": { + "schema_type": "int", + "ui_type": "int", + "description": "截断轮数(后备)", + "hint": "LLM 压缩后仍超限时,按轮截断的每次丢弃轮数。", + }, + "keep_recent": { + "schema_type": "int", + "ui_type": "int", + "description": "保留最近轮数", + "hint": "压缩时始终保留最近 N 轮消息。", + }, + "provider_id": { + "schema_type": "string", + "ui_type": "string", + "description": "压缩模型提供商 ID", + "hint": "可自定义指定任意可用对话模型;留空时按会话当前模型执行压缩。建议优先选择成本较低、响应较快的模型。", + "_special": "select_provider", + }, + "instruction": { + "schema_type": "string", + "ui_type": "text", + "description": "定时压缩提示词", + "hint": "留空时复用 provider_settings.llm_compress_instruction。", + }, + "dry_run": { + "schema_type": "bool", + "ui_type": "bool", + "description": "演练模式(不回写)", + "hint": "开启后只记录日志,不实际写回数据库。", + }, +} + + +def _build_periodic_context_compaction_schema_properties() -> dict[str, dict[str, str]]: + return { + key: {"type": str(meta["schema_type"])} + for key, meta in PERIODIC_CONTEXT_COMPACTION_FIELD_META.items() + } + + +def _build_periodic_context_compaction_dashboard_items() -> dict[str, dict[str, Any]]: + items: dict[str, dict[str, Any]] = {} + base_enabled_condition = { + "provider_settings.periodic_context_compaction.enabled": True, + "provider_settings.agent_runner_type": "local", + } + for key, meta in PERIODIC_CONTEXT_COMPACTION_FIELD_META.items(): + condition = ( + {"provider_settings.agent_runner_type": "local"} + if key == "enabled" + else dict(base_enabled_condition) + ) + field: dict[str, Any] = { + "description": meta["description"], + "type": meta["ui_type"], + "hint": meta["hint"], + "condition": condition, + } + if "_special" in meta: + field["_special"] = meta["_special"] + items[f"provider_settings.periodic_context_compaction.{key}"] = field + return items + + CONTEXT_MEMORY_DEFAULTS = { "enabled": False, "inject_pinned_memory": True, @@ -2574,59 +2712,7 @@ class ChatProviderTemplate(TypedDict): }, "periodic_context_compaction": { "type": "object", - "properties": { - "enabled": { - "type": "bool", - }, - "interval_minutes": { - "type": "int", - }, - "startup_delay_seconds": { - "type": "int", - }, - "max_conversations_per_run": { - "type": "int", - }, - "max_scan_per_run": { - "type": "int", - }, - "scan_page_size": { - "type": "int", - }, - "min_idle_minutes": { - "type": "int", - }, - "min_messages": { - "type": "int", - }, - "target_tokens": { - "type": "int", - }, - "trigger_tokens": { - "type": "int", - }, - "trigger_min_context_ratio": { - "type": "float", - }, - "max_rounds": { - "type": "int", - }, - "truncate_turns": { - "type": "int", - }, - "keep_recent": { - "type": "int", - }, - "provider_id": { - "type": "string", - }, - "instruction": { - "type": "string", - }, - "dry_run": { - "type": "bool", - }, - }, + "properties": _build_periodic_context_compaction_schema_properties(), }, "context_memory": { "type": "object", @@ -3411,159 +3497,7 @@ class ChatProviderTemplate(TypedDict): "provider_settings.agent_runner_type": "local", }, }, - "provider_settings.periodic_context_compaction.enabled": { - "description": "启用定时历史压缩", - "type": "bool", - "hint": "后台定时扫描会话历史,使用 LLM 摘要旧消息并回写对话历史,实现多轮 compact context。", - "condition": { - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.interval_minutes": { - "description": "定时间隔(分钟)", - "type": "int", - "hint": "每隔多少分钟执行一次压缩扫描。", - "condition": { - 
"provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.startup_delay_seconds": { - "description": "启动延迟(秒)", - "type": "int", - "hint": "AstrBot 启动后,等待指定秒数再执行首次压缩任务。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.max_conversations_per_run": { - "description": "单次最多压缩会话数", - "type": "int", - "hint": "每次任务最多实际压缩多少个会话。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.max_scan_per_run": { - "description": "单次最多扫描会话数", - "type": "int", - "hint": "每次任务最多扫描多少会话(包括被跳过的会话)。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.scan_page_size": { - "description": "分页扫描大小", - "type": "int", - "hint": "扫描 conversations 表时每页读取条数。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.min_idle_minutes": { - "description": "最小静默时长(分钟)", - "type": "int", - "hint": "会话最近更新时间小于该值时跳过,避免压缩活跃会话。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.min_messages": { - "description": "最小消息条数", - "type": "int", - "hint": "少于该消息条数的会话不参与压缩。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.target_tokens": { - "description": "目标 Token 阈值", - "type": "int", - "hint": "压缩目标上下文大小(token 估算值)。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.trigger_tokens": { - "description": "触发 Token 阈值", - "type": "int", - "hint": "会话估算 token 超过此值才触发压缩。<=0 表示自动按模型最大上下文比例计算。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.trigger_min_context_ratio": { - "description": "自动触发比例", - "type": "float", - "hint": "当触发 Token 阈值 <= 0 时生效。默认 0.3(即模型最大上下文的 30%)。支持填写 0~1 或 0~100(百分比)。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.max_rounds": { - "description": "每会话最大压缩轮数", - "type": "int", - "hint": "单个会话一次任务内最多执行几轮摘要压缩(实现 multiple compact context)。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.truncate_turns": { - "description": "截断轮数(后备)", - "type": "int", - "hint": "LLM 压缩后仍超限时,按轮截断的每次丢弃轮数。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.keep_recent": { - "description": "保留最近轮数", - 
"type": "int", - "hint": "压缩时始终保留最近 N 轮消息。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.provider_id": { - "description": "压缩模型提供商 ID", - "type": "string", - "_special": "select_provider", - "hint": "留空时按会话当前模型执行压缩。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.instruction": { - "description": "定时压缩提示词", - "type": "text", - "hint": "留空时复用 provider_settings.llm_compress_instruction。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.periodic_context_compaction.dry_run": { - "description": "演练模式(不回写)", - "type": "bool", - "hint": "开启后只记录日志,不实际写回数据库。", - "condition": { - "provider_settings.periodic_context_compaction.enabled": True, - "provider_settings.agent_runner_type": "local", - }, - }, + **_build_periodic_context_compaction_dashboard_items(), "provider_settings.context_memory.enabled": { "description": "启用上下文记忆注入", "type": "bool", diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py index 4e4c0c3276..a9dd2dcad2 100644 --- a/astrbot/core/context_memory.py +++ b/astrbot/core/context_memory.py @@ -4,7 +4,7 @@ from typing import Any from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS -from astrbot.core.context_memory_backends import ( +from astrbot.core.context_memory_experimental_backends import ( ContextMemoryEvolutionBackend, ContextMemoryMigrationAdapter, VectorLongTermMemoryRetriever, diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py index 225bfd720b..79d91f1c34 100644 --- a/astrbot/core/context_memory_backends.py +++ b/astrbot/core/context_memory_backends.py @@ -1,92 +1,26 @@ -from __future__ import annotations - -from typing import Any, Protocol, runtime_checkable - - -@runtime_checkable -class VectorLongTermMemoryRetriever(Protocol): - """Reserved protocol for future vector-DB long-term memory retrieval.""" - - async def retrieve( - self, - *, - unified_msg_origin: str, - query: str, - top_k: int, - ) -> list[str]: - """Return ranked memory snippets for prompt assembly.""" - ... - - -@runtime_checkable -class ContextMemoryEvolutionBackend(Protocol): - """Reserved protocol for MemEvolve-style memory evolution backend integration.""" - - async def evolve( - self, - *, - unified_msg_origin: str, - turns: list[str], - metadata: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Evolve short-term conversation turns into durable memory artifacts.""" - ... - - async def retrieve( - self, - *, - unified_msg_origin: str, - query: str, - top_k: int, - ) -> list[str]: - """Retrieve evolved memory snippets for prompt assembly.""" - ... - - -@runtime_checkable -class ContextMemoryMigrationAdapter(Protocol): - """Reserved protocol for future context-memory schema/store migration.""" - - async def export_session( - self, - *, - unified_msg_origin: str, - ) -> dict[str, Any]: - """Export memory payload for migration or backup.""" - ... - - async def import_session( - self, - *, - unified_msg_origin: str, - payload: dict[str, Any], - ) -> None: - """Import migrated memory payload into target backend.""" - ... 
-
-
-_context_memory_evolution_backend: ContextMemoryEvolutionBackend | None = None
-_context_memory_migration_adapter: ContextMemoryMigrationAdapter | None = None
-
-
-def set_context_memory_evolution_backend(
-    backend: ContextMemoryEvolutionBackend | None,
-) -> None:
-    global _context_memory_evolution_backend
-    _context_memory_evolution_backend = backend
-
-
-def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None:
-    return _context_memory_evolution_backend
-
-
-def set_context_memory_migration_adapter(
-    adapter: ContextMemoryMigrationAdapter | None,
-) -> None:
-    global _context_memory_migration_adapter
-    _context_memory_migration_adapter = adapter
-
-
-def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None:
-    return _context_memory_migration_adapter
-
+"""Compatibility re-exports for experimental context-memory backend hooks.
+
+The protocol definitions and global hook state live in
+`context_memory_experimental_backends.py` to keep experimental extension points
+explicitly isolated from stable context-memory config logic.
+"""
+
+from astrbot.core.context_memory_experimental_backends import (
+    ContextMemoryEvolutionBackend,
+    ContextMemoryMigrationAdapter,
+    VectorLongTermMemoryRetriever,
+    get_context_memory_evolution_backend,
+    get_context_memory_migration_adapter,
+    set_context_memory_evolution_backend,
+    set_context_memory_migration_adapter,
+)
+
+__all__ = [
+    "VectorLongTermMemoryRetriever",
+    "ContextMemoryEvolutionBackend",
+    "ContextMemoryMigrationAdapter",
+    "set_context_memory_evolution_backend",
+    "get_context_memory_evolution_backend",
+    "set_context_memory_migration_adapter",
+    "get_context_memory_migration_adapter",
+]
diff --git a/astrbot/core/context_memory_experimental_backends.py b/astrbot/core/context_memory_experimental_backends.py
new file mode 100644
index 0000000000..b461440097
--- /dev/null
+++ b/astrbot/core/context_memory_experimental_backends.py
@@ -0,0 +1,92 @@
+from __future__ import annotations
+
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class VectorLongTermMemoryRetriever(Protocol):
+    """Experimental protocol for future vector-DB long-term memory retrieval."""
+
+    async def retrieve(
+        self,
+        *,
+        unified_msg_origin: str,
+        query: str,
+        top_k: int,
+    ) -> list[str]:
+        """Return ranked memory snippets for prompt assembly."""
+        ...
+
+
+@runtime_checkable
+class ContextMemoryEvolutionBackend(Protocol):
+    """Experimental protocol for MemEvolve-style memory evolution integration."""
+
+    async def evolve(
+        self,
+        *,
+        unified_msg_origin: str,
+        turns: list[str],
+        metadata: dict[str, Any] | None = None,
+    ) -> dict[str, Any]:
+        """Evolve short-term conversation turns into durable memory artifacts."""
+        ...
+
+    async def retrieve(
+        self,
+        *,
+        unified_msg_origin: str,
+        query: str,
+        top_k: int,
+    ) -> list[str]:
+        """Retrieve evolved memory snippets for prompt assembly."""
+        ...
+
+
+@runtime_checkable
+class ContextMemoryMigrationAdapter(Protocol):
+    """Experimental protocol for future context-memory schema/store migration."""
+
+    async def export_session(
+        self,
+        *,
+        unified_msg_origin: str,
+    ) -> dict[str, Any]:
+        """Export memory payload for migration or backup."""
+        ...
+
+    async def import_session(
+        self,
+        *,
+        unified_msg_origin: str,
+        payload: dict[str, Any],
+    ) -> None:
+        """Import migrated memory payload into target backend."""
+        ...
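+
+
+# A minimal sketch (illustrative only, not part of this patch) of an adapter
+# that would satisfy ContextMemoryMigrationAdapter, assuming a plain in-memory
+# dict as the store. Because the protocols above are @runtime_checkable,
+# isinstance() verifies the async method surface at runtime (signatures are
+# not checked):
+#
+#     class InMemoryMigrationAdapter:
+#         def __init__(self) -> None:
+#             self._store: dict[str, dict[str, Any]] = {}
+#
+#         async def export_session(self, *, unified_msg_origin: str) -> dict[str, Any]:
+#             return dict(self._store.get(unified_msg_origin, {}))
+#
+#         async def import_session(
+#             self, *, unified_msg_origin: str, payload: dict[str, Any]
+#         ) -> None:
+#             self._store[unified_msg_origin] = dict(payload)
+#
+#     assert isinstance(InMemoryMigrationAdapter(), ContextMemoryMigrationAdapter)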
+
+
+_context_memory_evolution_backend: ContextMemoryEvolutionBackend | None = None
+_context_memory_migration_adapter: ContextMemoryMigrationAdapter | None = None
+
+
+def set_context_memory_evolution_backend(
+    backend: ContextMemoryEvolutionBackend | None,
+) -> None:
+    global _context_memory_evolution_backend
+    _context_memory_evolution_backend = backend
+
+
+def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None:
+    return _context_memory_evolution_backend
+
+
+def set_context_memory_migration_adapter(
+    adapter: ContextMemoryMigrationAdapter | None,
+) -> None:
+    global _context_memory_migration_adapter
+    _context_memory_migration_adapter = adapter
+
+
+def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None:
+    return _context_memory_migration_adapter
+
diff --git a/tests/agent/test_token_counter.py b/tests/agent/test_token_counter.py
index ae65b917ad..fad8dfda1f 100644
--- a/tests/agent/test_token_counter.py
+++ b/tests/agent/test_token_counter.py
@@ -134,3 +134,96 @@ def test_tokenizer_counter_gracefully_handles_broken_fallback_encoder(
 
     created = create_token_counter("tokenizer", model="gpt-4")
     assert isinstance(created, EstimateTokenCounter)
+
+
+class TestTokenizerTokenCounterBehavior:
+    def test_count_tokens_uses_trusted_usage_and_skips_encode(self, monkeypatch):
+        counter = TokenizerTokenCounter(model="gpt-4")
+
+        def _encode_should_not_be_called(_text):  # pragma: no cover - safety guard
+            raise AssertionError(
+                "_encode should not be called when trusted_token_usage is provided"
+            )
+
+        monkeypatch.setattr(counter, "_encode", _encode_should_not_be_called)
+
+        messages = [
+            _msg(
+                "user",
+                [
+                    TextPart(text="hello"),
+                    ThinkPart(think="internal thoughts"),
+                    ImageURLPart(
+                        image_url=ImageURLPart.ImageURL(
+                            url="data:image/png;base64,abc"
+                        )
+                    ),
+                    AudioURLPart(
+                        audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3")
+                    ),
+                ],
+            )
+        ]
+
+        trusted_usage = 123
+        result = counter.count_tokens(messages, trusted_token_usage=trusted_usage)
+        assert result == trusted_usage
+
+    def test_encode_error_falls_back_to_estimate_text_tokens(self, monkeypatch):
+        counter = TokenizerTokenCounter(model="gpt-4")
+
+        def broken_encode(_text):
+            raise RuntimeError("tiktoken failure")
+
+        captured: dict[str, str] = {}
+
+        def fake_estimate(text: str) -> int:
+            captured["text"] = text
+            return 42
+
+        monkeypatch.setattr(counter, "_encode", broken_encode)
+        monkeypatch.setattr(counter._estimate, "estimate_text_tokens", fake_estimate)
+
+        result = counter._encode_len("fallback text")
+        assert result == 42
+        assert captured["text"] == "fallback text"
+
+    def test_tokenizer_mode_mixed_modalities_use_fixed_estimates(self, monkeypatch):
+        counter = TokenizerTokenCounter(model="gpt-4")
+        counter._available = True
+
+        def fake_encode_len(text: str) -> int:
+            return len(text.split())
+
+        monkeypatch.setattr(counter, "_encode_len", fake_encode_len)
+
+        messages = [
+            _msg(
+                "user",
+                [
+                    TextPart(text="hello world"),
+                    ImageURLPart(
+                        image_url=ImageURLPart.ImageURL(
+                            url="data:image/png;base64,image"
+                        )
+                    ),
+                ],
+            ),
+            _msg(
+                "assistant",
+                [
+                    ThinkPart(think="thinking hard"),
+                    AudioURLPart(
+                        audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3")
+                    ),
+                ],
+            ),
+        ]
+
+        expected_text_tokens = fake_encode_len("hello world") + fake_encode_len(
+            "thinking hard"
+        )
+        expected_non_text_tokens = IMAGE_TOKEN_ESTIMATE + AUDIO_TOKEN_ESTIMATE
+
+        tokens = counter.count_tokens(messages)
+        assert tokens == expected_text_tokens + expected_non_text_tokens

From cb5685cca9c283d47a14823485c77e8152925525 Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 00:59:50 -0500
Subject: [PATCH 23/29] refactor context memory config and experimental backend facade

---
 astrbot/core/context_memory.py              | 181 ++++++++----------
 astrbot/core/context_memory_backends.py     |   6 +
 .../context_memory_experimental_backends.py |  50 ++++-
 tests/unit/test_context_memory.py           |  57 +++---
 4 files changed, 157 insertions(+), 137 deletions(-)

diff --git a/astrbot/core/context_memory.py b/astrbot/core/context_memory.py
index a9dd2dcad2..008a8c1c6c 100644
--- a/astrbot/core/context_memory.py
+++ b/astrbot/core/context_memory.py
@@ -4,41 +4,14 @@
 from typing import Any
 
 from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS
-from astrbot.core.context_memory_experimental_backends import (
-    ContextMemoryEvolutionBackend,
-    ContextMemoryMigrationAdapter,
-    VectorLongTermMemoryRetriever,
-    get_context_memory_evolution_backend,
-    get_context_memory_migration_adapter,
-    set_context_memory_evolution_backend,
-    set_context_memory_migration_adapter,
-)
 from astrbot.core.utils.config_normalization import to_bool, to_int
 
-
-def _clone_context_memory_defaults() -> dict[str, Any]:
-    defaults = dict(CONTEXT_MEMORY_DEFAULTS)
-    pinned = defaults.get("pinned_memories")
-    defaults["pinned_memories"] = list(pinned) if isinstance(pinned, list) else []
-    return defaults
-
-
-DEFAULT_CONTEXT_MEMORY_SETTINGS: dict[str, Any] = _clone_context_memory_defaults()
-
 __all__ = [
     "ContextMemoryConfig",
-    "DEFAULT_CONTEXT_MEMORY_SETTINGS",
     "normalize_context_memory_settings",
     "load_context_memory_config",
     "ensure_context_memory_settings",
     "build_pinned_memory_system_block",
-    "VectorLongTermMemoryRetriever",
-    "ContextMemoryEvolutionBackend",
-    "ContextMemoryMigrationAdapter",
-    "set_context_memory_evolution_backend",
-    "get_context_memory_evolution_backend",
-    "set_context_memory_migration_adapter",
-    "get_context_memory_migration_adapter",
 ]
 
 
@@ -54,86 +27,98 @@ class ContextMemoryConfig:
     retrieval_provider_id: str = ""
     retrieval_top_k: int = 5
 
+    @classmethod
+    def from_settings(
+        cls,
+        provider_settings: dict[str, Any] | None,
+    ) -> ContextMemoryConfig:
+        raw = None
+        if isinstance(provider_settings, dict):
+            raw = provider_settings.get("context_memory")
+        return cls.from_raw(raw if isinstance(raw, dict) else None)
+
+    @classmethod
+    def from_raw(cls, raw: dict[str, Any] | None) -> ContextMemoryConfig:
+        defaults = CONTEXT_MEMORY_DEFAULTS
+        data = raw if isinstance(raw, dict) else {}
+
+        enabled = to_bool(data.get("enabled"), bool(defaults["enabled"]))
+        inject_pinned_memory = to_bool(
+            data.get("inject_pinned_memory"),
+            bool(defaults["inject_pinned_memory"]),
+        )
+        pinned_max_items = to_int(
+            data.get("pinned_max_items"),
+            int(defaults["pinned_max_items"]),
+            1,
+        )
+        pinned_max_chars_per_item = to_int(
+            data.get("pinned_max_chars_per_item"),
+            int(defaults["pinned_max_chars_per_item"]),
+            1,
+        )
+        retrieval_enabled = to_bool(
+            data.get("retrieval_enabled"),
+            bool(defaults["retrieval_enabled"]),
+        )
+        retrieval_backend = str(data.get("retrieval_backend", "") or "").strip()
+        retrieval_provider_id = str(data.get("retrieval_provider_id", "") or "").strip()
+        retrieval_top_k = to_int(
+            data.get("retrieval_top_k"),
+            int(defaults["retrieval_top_k"]),
+            1,
+        )
+
+        pinned_raw = data.get("pinned_memories", [])
+        pinned_memories: list[str] = []
+        if isinstance(pinned_raw, list):
+            for item in pinned_raw:
+                text = str(item or "").strip()
+                if not text:
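+                    # Skip entries that are empty after trimming whitespace;
+                    # the truncation below only applies to non-empty text.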
+                    continue
+                if len(text) > pinned_max_chars_per_item:
+                    text = text[:pinned_max_chars_per_item]
+                pinned_memories.append(text)
+                if len(pinned_memories) >= pinned_max_items:
+                    break
+
+        return cls(
+            enabled=enabled,
+            inject_pinned_memory=inject_pinned_memory,
+            pinned_memories=pinned_memories,
+            pinned_max_items=pinned_max_items,
+            pinned_max_chars_per_item=pinned_max_chars_per_item,
+            retrieval_enabled=retrieval_enabled,
+            retrieval_backend=retrieval_backend,
+            retrieval_provider_id=retrieval_provider_id,
+            retrieval_top_k=retrieval_top_k,
+        )
+
+    def to_settings_dict(self) -> dict[str, Any]:
+        return {
+            "enabled": self.enabled,
+            "inject_pinned_memory": self.inject_pinned_memory,
+            "pinned_memories": list(self.pinned_memories),
+            "pinned_max_items": self.pinned_max_items,
+            "pinned_max_chars_per_item": self.pinned_max_chars_per_item,
+            "retrieval_enabled": self.retrieval_enabled,
+            "retrieval_backend": self.retrieval_backend,
+            "retrieval_provider_id": self.retrieval_provider_id,
+            "retrieval_top_k": self.retrieval_top_k,
+        }
 
-def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, Any]:
-    normalized = dict(DEFAULT_CONTEXT_MEMORY_SETTINGS)
-    # Always initialize pinned_memories explicitly to avoid sharing mutable defaults.
-    normalized["pinned_memories"] = []
-    if not isinstance(raw, dict):
-        return normalized
-
-    normalized["enabled"] = to_bool(
-        raw.get("enabled"),
-        bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["enabled"]),
-    )
-    normalized["inject_pinned_memory"] = to_bool(
-        raw.get("inject_pinned_memory"),
-        bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["inject_pinned_memory"]),
-    )
-    normalized["pinned_max_items"] = to_int(
-        raw.get("pinned_max_items"),
-        int(DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_items"]),
-        1,
-    )
-    normalized["pinned_max_chars_per_item"] = to_int(
-        raw.get("pinned_max_chars_per_item"),
-        int(DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_chars_per_item"]),
-        1,
-    )
-    normalized["retrieval_enabled"] = to_bool(
-        raw.get("retrieval_enabled"),
-        bool(DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_enabled"]),
-    )
-    normalized["retrieval_backend"] = str(raw.get("retrieval_backend", "") or "").strip()
-    normalized["retrieval_provider_id"] = str(
-        raw.get("retrieval_provider_id", "") or ""
-    ).strip()
-    normalized["retrieval_top_k"] = to_int(
-        raw.get("retrieval_top_k"),
-        int(DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_top_k"]),
-        1,
-    )
-
-    pinned_max_items = int(normalized["pinned_max_items"])
-    pinned_max_chars = int(normalized["pinned_max_chars_per_item"])
-    pinned_raw = raw.get("pinned_memories", [])
-    pinned_memories: list[str] = []
-    if isinstance(pinned_raw, list):
-        for item in pinned_raw:
-            text = str(item or "").strip()
-            if not text:
-                continue
-            if len(text) > pinned_max_chars:
-                text = text[:pinned_max_chars]
-            pinned_memories.append(text)
-            if len(pinned_memories) >= pinned_max_items:
-                break
-    normalized["pinned_memories"] = pinned_memories
-    return normalized
+def normalize_context_memory_settings(raw: dict[str, Any] | None) -> dict[str, Any]:
+    return ContextMemoryConfig.from_raw(raw if isinstance(raw, dict) else None).to_settings_dict()
 
 
 def load_context_memory_config(provider_settings: dict[str, Any] | None) -> ContextMemoryConfig:
-    raw = None
-    if isinstance(provider_settings, dict):
-        raw = provider_settings.get("context_memory")
-    normalized = normalize_context_memory_settings(raw if isinstance(raw, dict) else None)
-    return ContextMemoryConfig(
-        enabled=bool(normalized["enabled"]),
-        inject_pinned_memory=bool(normalized["inject_pinned_memory"]),
-        pinned_memories=list(normalized["pinned_memories"]),
-        pinned_max_items=int(normalized["pinned_max_items"]),
-        pinned_max_chars_per_item=int(normalized["pinned_max_chars_per_item"]),
-        retrieval_enabled=bool(normalized["retrieval_enabled"]),
-        retrieval_backend=str(normalized["retrieval_backend"]),
-        retrieval_provider_id=str(normalized["retrieval_provider_id"]),
-        retrieval_top_k=int(normalized["retrieval_top_k"]),
-    )
+    return ContextMemoryConfig.from_settings(provider_settings)
 
 
 def ensure_context_memory_settings(provider_settings: dict[str, Any]) -> dict[str, Any]:
     """Normalize and persist the context_memory subtree in provider_settings."""
-    normalized = normalize_context_memory_settings(provider_settings.get("context_memory"))
+    normalized = ContextMemoryConfig.from_settings(provider_settings).to_settings_dict()
     provider_settings["context_memory"] = normalized
     return normalized
diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py
index 79d91f1c34..ca02240997 100644
--- a/astrbot/core/context_memory_backends.py
+++ b/astrbot/core/context_memory_backends.py
@@ -8,9 +8,12 @@
 from astrbot.core.context_memory_experimental_backends import (
     ContextMemoryEvolutionBackend,
     ContextMemoryMigrationAdapter,
+    ExperimentalContextMemoryBackends,
     VectorLongTermMemoryRetriever,
+    configure_context_memory_backends,
     get_context_memory_evolution_backend,
     get_context_memory_migration_adapter,
+    get_experimental_context_memory_backends,
     set_context_memory_evolution_backend,
     set_context_memory_migration_adapter,
 )
@@ -19,6 +22,9 @@
     "VectorLongTermMemoryRetriever",
     "ContextMemoryEvolutionBackend",
     "ContextMemoryMigrationAdapter",
+    "ExperimentalContextMemoryBackends",
+    "configure_context_memory_backends",
+    "get_experimental_context_memory_backends",
     "set_context_memory_evolution_backend",
     "get_context_memory_evolution_backend",
     "set_context_memory_migration_adapter",
diff --git a/astrbot/core/context_memory_experimental_backends.py b/astrbot/core/context_memory_experimental_backends.py
index b461440097..981badcad5 100644
--- a/astrbot/core/context_memory_experimental_backends.py
+++ b/astrbot/core/context_memory_experimental_backends.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import Any, Protocol, runtime_checkable
 
 
@@ -65,28 +66,59 @@ async def import_session(
         ...
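+# Usage sketch (illustrative; `my_backend` is a hypothetical object satisfying
+# ContextMemoryEvolutionBackend): the bundle introduced below replaces the two
+# module-level globals, so wiring becomes a single call:
+#
+#     configure_context_memory_backends(
+#         evolution_backend=my_backend,
+#         migration_adapter=None,
+#     )
+#     hooks = get_experimental_context_memory_backends()
+#     assert hooks.evolution_backend is my_backend
+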
-_context_memory_evolution_backend: ContextMemoryEvolutionBackend | None = None
-_context_memory_migration_adapter: ContextMemoryMigrationAdapter | None = None
+@dataclass
+class ExperimentalContextMemoryBackends:
+    """Container for optional experimental backends."""
 
+    evolution_backend: ContextMemoryEvolutionBackend | None = None
+    migration_adapter: ContextMemoryMigrationAdapter | None = None
+
+
+_backends = ExperimentalContextMemoryBackends()
+
+
+def configure_context_memory_backends(
+    *,
+    evolution_backend: ContextMemoryEvolutionBackend | None = None,
+    migration_adapter: ContextMemoryMigrationAdapter | None = None,
+) -> None:
+    """Configure optional experimental backends in one cohesive entry point."""
+    _backends.evolution_backend = evolution_backend
+    _backends.migration_adapter = migration_adapter
+
+
+def get_experimental_context_memory_backends() -> ExperimentalContextMemoryBackends:
+    return _backends
 
 def set_context_memory_evolution_backend(
     backend: ContextMemoryEvolutionBackend | None,
 ) -> None:
-    global _context_memory_evolution_backend
-    _context_memory_evolution_backend = backend
+    _backends.evolution_backend = backend
 
 
 def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None:
-    return _context_memory_evolution_backend
+    return _backends.evolution_backend
 
 
 def set_context_memory_migration_adapter(
     adapter: ContextMemoryMigrationAdapter | None,
 ) -> None:
-    global _context_memory_migration_adapter
-    _context_memory_migration_adapter = adapter
+    _backends.migration_adapter = adapter
 
 
 def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None:
-    return _context_memory_migration_adapter
-
+    return _backends.migration_adapter
+
+
+__all__ = [
+    "VectorLongTermMemoryRetriever",
+    "ContextMemoryEvolutionBackend",
+    "ContextMemoryMigrationAdapter",
+    "ExperimentalContextMemoryBackends",
+    "configure_context_memory_backends",
+    "get_experimental_context_memory_backends",
+    "set_context_memory_evolution_backend",
+    "get_context_memory_evolution_backend",
+    "set_context_memory_migration_adapter",
+    "get_context_memory_migration_adapter",
+]
diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py
index 377b03778e..12cb659015 100644
--- a/tests/unit/test_context_memory.py
+++ b/tests/unit/test_context_memory.py
@@ -2,10 +2,14 @@
 
 from astrbot.core.config.default import CONTEXT_MEMORY_DEFAULTS
 from astrbot.core.context_memory import (
-    DEFAULT_CONTEXT_MEMORY_SETTINGS,
+    ContextMemoryConfig,
+    normalize_context_memory_settings,
+)
+from astrbot.core.context_memory_experimental_backends import (
+    configure_context_memory_backends,
     get_context_memory_evolution_backend,
     get_context_memory_migration_adapter,
-    normalize_context_memory_settings,
+    get_experimental_context_memory_backends,
     set_context_memory_evolution_backend,
     set_context_memory_migration_adapter,
 )
@@ -25,11 +29,16 @@ def test_context_memory_reserved_backend_registration() -> None:
     backend = object()
     adapter = object()
 
-    set_context_memory_evolution_backend(backend)  # type: ignore[arg-type]
-    set_context_memory_migration_adapter(adapter)  # type: ignore[arg-type]
+    configure_context_memory_backends(
+        evolution_backend=backend,  # type: ignore[arg-type]
+        migration_adapter=adapter,  # type: ignore[arg-type]
+    )
 
     assert get_context_memory_evolution_backend() is backend
     assert get_context_memory_migration_adapter() is adapter
+    backends = get_experimental_context_memory_backends()
+    assert backends.evolution_backend is backend
+    assert backends.migration_adapter is adapter
 
     set_context_memory_evolution_backend(None)
     set_context_memory_migration_adapter(None)
@@ -39,29 +48,17 @@ def test_context_memory_reserved_backend_registration() -> None:
 
 
 def test_context_memory_defaults_follow_single_source() -> None:
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["enabled"] == CONTEXT_MEMORY_DEFAULTS["enabled"]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["inject_pinned_memory"] == CONTEXT_MEMORY_DEFAULTS[
-        "inject_pinned_memory"
-    ]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_max_items"] == CONTEXT_MEMORY_DEFAULTS[
-        "pinned_max_items"
-    ]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS[
-        "pinned_max_chars_per_item"
-    ] == CONTEXT_MEMORY_DEFAULTS["pinned_max_chars_per_item"]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_enabled"] == CONTEXT_MEMORY_DEFAULTS[
-        "retrieval_enabled"
-    ]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_backend"] == CONTEXT_MEMORY_DEFAULTS[
-        "retrieval_backend"
-    ]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS[
-        "retrieval_provider_id"
-    ] == CONTEXT_MEMORY_DEFAULTS["retrieval_provider_id"]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["retrieval_top_k"] == CONTEXT_MEMORY_DEFAULTS[
-        "retrieval_top_k"
-    ]
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_memories"] == []
-    assert DEFAULT_CONTEXT_MEMORY_SETTINGS["pinned_memories"] is not CONTEXT_MEMORY_DEFAULTS[
-        "pinned_memories"
-    ]
+    cfg = ContextMemoryConfig.from_raw(None)
+
+    assert cfg.enabled == CONTEXT_MEMORY_DEFAULTS["enabled"]
+    assert cfg.inject_pinned_memory == CONTEXT_MEMORY_DEFAULTS["inject_pinned_memory"]
+    assert cfg.pinned_max_items == CONTEXT_MEMORY_DEFAULTS["pinned_max_items"]
+    assert (
+        cfg.pinned_max_chars_per_item
+        == CONTEXT_MEMORY_DEFAULTS["pinned_max_chars_per_item"]
+    )
+    assert cfg.retrieval_enabled == CONTEXT_MEMORY_DEFAULTS["retrieval_enabled"]
+    assert cfg.retrieval_backend == CONTEXT_MEMORY_DEFAULTS["retrieval_backend"]
+    assert cfg.retrieval_provider_id == CONTEXT_MEMORY_DEFAULTS["retrieval_provider_id"]
+    assert cfg.retrieval_top_k == CONTEXT_MEMORY_DEFAULTS["retrieval_top_k"]
+    assert cfg.pinned_memories == []

From 994f5d4bdd235e169d3c9556a7c0262382f7a3d1 Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 01:39:50 -0500
Subject: [PATCH 24/29] shrink experimental context memory backend api surface

---
 astrbot/core/context_memory_backends.py     | 27 +++----------
 .../context_memory_experimental_backends.py | 24 -----------------
 tests/unit/test_context_memory.py           | 14 +++------
 3 files changed, 6 insertions(+), 59 deletions(-)

diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py
index ca02240997..9046cf096a 100644
--- a/astrbot/core/context_memory_backends.py
+++ b/astrbot/core/context_memory_backends.py
@@ -5,28 +5,7 @@
 explicitly isolated from stable context-memory config logic.
""" -from astrbot.core.context_memory_experimental_backends import ( - ContextMemoryEvolutionBackend, - ContextMemoryMigrationAdapter, - ExperimentalContextMemoryBackends, - VectorLongTermMemoryRetriever, - configure_context_memory_backends, - get_context_memory_evolution_backend, - get_context_memory_migration_adapter, - get_experimental_context_memory_backends, - set_context_memory_evolution_backend, - set_context_memory_migration_adapter, -) +from astrbot.core import context_memory_experimental_backends as _exp -__all__ = [ - "VectorLongTermMemoryRetriever", - "ContextMemoryEvolutionBackend", - "ContextMemoryMigrationAdapter", - "ExperimentalContextMemoryBackends", - "configure_context_memory_backends", - "get_experimental_context_memory_backends", - "set_context_memory_evolution_backend", - "get_context_memory_evolution_backend", - "set_context_memory_migration_adapter", - "get_context_memory_migration_adapter", -] +__all__ = list(_exp.__all__) +globals().update({name: getattr(_exp, name) for name in __all__}) diff --git a/astrbot/core/context_memory_experimental_backends.py b/astrbot/core/context_memory_experimental_backends.py index 981badcad5..3c89d786f5 100644 --- a/astrbot/core/context_memory_experimental_backends.py +++ b/astrbot/core/context_memory_experimental_backends.py @@ -90,26 +90,6 @@ def configure_context_memory_backends( def get_experimental_context_memory_backends() -> ExperimentalContextMemoryBackends: return _backends -def set_context_memory_evolution_backend( - backend: ContextMemoryEvolutionBackend | None, -) -> None: - _backends.evolution_backend = backend - - -def get_context_memory_evolution_backend() -> ContextMemoryEvolutionBackend | None: - return _backends.evolution_backend - - -def set_context_memory_migration_adapter( - adapter: ContextMemoryMigrationAdapter | None, -) -> None: - _backends.migration_adapter = adapter - - -def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | None: - return _backends.migration_adapter - - __all__ = [ "VectorLongTermMemoryRetriever", "ContextMemoryEvolutionBackend", @@ -117,8 +97,4 @@ def get_context_memory_migration_adapter() -> ContextMemoryMigrationAdapter | No "ExperimentalContextMemoryBackends", "configure_context_memory_backends", "get_experimental_context_memory_backends", - "set_context_memory_evolution_backend", - "get_context_memory_evolution_backend", - "set_context_memory_migration_adapter", - "get_context_memory_migration_adapter", ] diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py index 12cb659015..02b3cca0d0 100644 --- a/tests/unit/test_context_memory.py +++ b/tests/unit/test_context_memory.py @@ -7,11 +7,7 @@ ) from astrbot.core.context_memory_experimental_backends import ( configure_context_memory_backends, - get_context_memory_evolution_backend, - get_context_memory_migration_adapter, get_experimental_context_memory_backends, - set_context_memory_evolution_backend, - set_context_memory_migration_adapter, ) @@ -34,17 +30,13 @@ def test_context_memory_reserved_backend_registration() -> None: migration_adapter=adapter, # type: ignore[arg-type] ) - assert get_context_memory_evolution_backend() is backend - assert get_context_memory_migration_adapter() is adapter backends = get_experimental_context_memory_backends() assert backends.evolution_backend is backend assert backends.migration_adapter is adapter - set_context_memory_evolution_backend(None) - set_context_memory_migration_adapter(None) - - assert get_context_memory_evolution_backend() is None - 
-    assert get_context_memory_migration_adapter() is None
+    configure_context_memory_backends(evolution_backend=None, migration_adapter=None)
+    assert get_experimental_context_memory_backends().evolution_backend is None
+    assert get_experimental_context_memory_backends().migration_adapter is None
 
 
 def test_context_memory_defaults_follow_single_source() -> None:

From 3210ed6e53ea00cf9474c3c5946c0d61f6765d6c Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 01:53:33 -0500
Subject: [PATCH 25/29] remove global state from experimental context memory backends

---
 astrbot/core/context_memory_backends.py     |  6 ++---
 .../context_memory_experimental_backends.py | 22 +++++++------------
 tests/unit/test_context_memory.py           | 13 +++++------
 3 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py
index 9046cf096a..6f1fe8d00a 100644
--- a/astrbot/core/context_memory_backends.py
+++ b/astrbot/core/context_memory_backends.py
@@ -1,8 +1,8 @@
 """Compatibility re-exports for experimental context-memory backend hooks.
 
-The protocol definitions and global hook state live in
-`context_memory_experimental_backends.py` to keep experimental extension points
-explicitly isolated from stable context-memory config logic.
+Experimental protocol and bundle factory definitions live in
+`context_memory_experimental_backends.py` to keep extension points isolated from
+stable context-memory config logic.
 """
 
 from astrbot.core import context_memory_experimental_backends as _exp
diff --git a/astrbot/core/context_memory_experimental_backends.py b/astrbot/core/context_memory_experimental_backends.py
index 981badcad5..690196c44d 100644
--- a/astrbot/core/context_memory_experimental_backends.py
+++ b/astrbot/core/context_memory_experimental_backends.py
@@ -74,27 +74,21 @@ class ExperimentalContextMemoryBackends:
     migration_adapter: ContextMemoryMigrationAdapter | None = None
 
 
-_backends = ExperimentalContextMemoryBackends()
-
-
-def configure_context_memory_backends(
+def make_experimental_context_memory_backends(
     *,
     evolution_backend: ContextMemoryEvolutionBackend | None = None,
     migration_adapter: ContextMemoryMigrationAdapter | None = None,
-) -> None:
-    """Configure optional experimental backends in one cohesive entry point."""
-    _backends.evolution_backend = evolution_backend
-    _backends.migration_adapter = migration_adapter
-
-
-def get_experimental_context_memory_backends() -> ExperimentalContextMemoryBackends:
-    return _backends
+) -> ExperimentalContextMemoryBackends:
+    """Create an experimental backend bundle without module-level mutable state."""
+    return ExperimentalContextMemoryBackends(
+        evolution_backend=evolution_backend,
+        migration_adapter=migration_adapter,
+    )
 
 
 __all__ = [
     "VectorLongTermMemoryRetriever",
     "ContextMemoryEvolutionBackend",
     "ContextMemoryMigrationAdapter",
     "ExperimentalContextMemoryBackends",
-    "configure_context_memory_backends",
-    "get_experimental_context_memory_backends",
+    "make_experimental_context_memory_backends",
 ]
diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py
index 02b3cca0d0..fa5f3611b1 100644
--- a/tests/unit/test_context_memory.py
+++ b/tests/unit/test_context_memory.py
@@ -6,8 +6,7 @@
     normalize_context_memory_settings,
 )
 from astrbot.core.context_memory_experimental_backends import (
-    configure_context_memory_backends,
-    get_experimental_context_memory_backends,
+    make_experimental_context_memory_backends,
 )
 
 
@@ -25,18 +24,16 @@ def test_context_memory_reserved_backend_registration() -> None:
     backend = object()
     adapter = object()
 
-    configure_context_memory_backends(
+    backends = make_experimental_context_memory_backends(
         evolution_backend=backend,  # type: ignore[arg-type]
         migration_adapter=adapter,  # type: ignore[arg-type]
     )
-
-    backends = get_experimental_context_memory_backends()
     assert backends.evolution_backend is backend
     assert backends.migration_adapter is adapter
 
-    configure_context_memory_backends(evolution_backend=None, migration_adapter=None)
-    assert get_experimental_context_memory_backends().evolution_backend is None
-    assert get_experimental_context_memory_backends().migration_adapter is None
+    empty_backends = make_experimental_context_memory_backends()
+    assert empty_backends.evolution_backend is None
+    assert empty_backends.migration_adapter is None
 
 
 def test_context_memory_defaults_follow_single_source() -> None:

From c4ea8ca737526f638f02f6c5925de48aa2e75515 Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 02:03:28 -0500
Subject: [PATCH 26/29] align compaction scheduler token mode and idle filtering logic

---
 astrbot/core/context_compaction_scheduler.py    | 31 +++++----
 .../unit/test_context_compaction_scheduler.py   | 63 +++++++++++++++++++
 2 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/astrbot/core/context_compaction_scheduler.py b/astrbot/core/context_compaction_scheduler.py
index 5b8666dab8..e2045e8f0b 100644
--- a/astrbot/core/context_compaction_scheduler.py
+++ b/astrbot/core/context_compaction_scheduler.py
@@ -4,7 +4,7 @@
 import time
 from collections.abc import AsyncIterator
 from dataclasses import asdict, dataclass
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any
 
 from astrbot import logger
@@ -387,18 +387,12 @@ async def _iter_candidate_conversations(
         scan_page_size: int,
         cfg: CompactionConfig,
     ) -> AsyncIterator[ConversationV2]:
-        updated_before: datetime | None = None
-        if cfg.min_idle_minutes > 0:
-            updated_before = datetime.now(timezone.utc) - timedelta(
-                minutes=int(cfg.min_idle_minutes),
-            )
-
         page = 1
         while not self._stop_event.is_set():
             conversations, total = await self.conversation_manager.db.get_filtered_conversations(
                 page=page,
                 page_size=scan_page_size,
-                updated_before=updated_before,
+                updated_before=None,
                 min_messages=cfg.min_messages,
             )
             if not conversations:
@@ -636,11 +630,7 @@ def _resolve_instruction(self, cfg: CompactionConfig) -> str:
         return ""
 
     def _resolve_token_counter(self, provider: Provider | None) -> TokenCounter:
-        provider_settings = self.config_manager.default_conf.get("provider_settings", {})
-        mode = "estimate"
-        if isinstance(provider_settings, dict):
-            mode = str(provider_settings.get("context_token_counter_mode", "estimate"))
-        mode = mode.strip().lower() or "estimate"
+        mode = self._resolve_token_counter_mode(provider)
 
         model = ""
         if provider is not None:
@@ -667,6 +657,21 @@ def _resolve_token_counter(self, provider: Provider | None) -> TokenCounter:
         self._token_counter_cache[cache_key] = resolved
         return resolved
 
+    def _resolve_token_counter_mode(self, provider: Provider | None) -> str:
+        if provider is not None:
+            provider_settings = getattr(provider, "provider_settings", None)
+            if isinstance(provider_settings, dict):
+                mode = str(provider_settings.get("context_token_counter_mode", "") or "")
+                normalized = mode.strip().lower()
+                if normalized:
+                    return normalized
+
+        provider_settings = self.config_manager.default_conf.get("provider_settings", {})
+        mode = "estimate"
+        if isinstance(provider_settings, dict):
+            mode = str(provider_settings.get("context_token_counter_mode", "estimate"))
+        return mode.strip().lower() or "estimate"
+
     @staticmethod
     def _messages_equal(a: list[Message], b: list[Message]) -> bool:
         if len(a) != len(b):
diff --git a/tests/unit/test_context_compaction_scheduler.py b/tests/unit/test_context_compaction_scheduler.py
index ecd8e2f521..8d3f8ef6d2 100644
--- a/tests/unit/test_context_compaction_scheduler.py
+++ b/tests/unit/test_context_compaction_scheduler.py
@@ -302,6 +302,69 @@ def _fake_create(mode: str | None = None, *, model: str | None = None):
     assert called["model"] == "gpt-4o"
 
 
+def test_resolve_token_counter_prefers_provider_level_mode(monkeypatch) -> None:
+    scheduler = _build_scheduler({"context_token_counter_mode": "estimate"})
+    provider = SimpleNamespace(
+        get_model=lambda: "gpt-4o",
+        provider_settings={"context_token_counter_mode": "tokenizer"},
+    )
+    called: dict[str, str | None] = {}
+
+    fake_counter = SimpleNamespace(count_tokens=lambda *_args, **_kwargs: 0)
+
+    def _fake_create(mode: str | None = None, *, model: str | None = None):
+        called["mode"] = mode
+        called["model"] = model
+        return fake_counter
+
+    monkeypatch.setattr(
+        "astrbot.core.context_compaction_scheduler.create_token_counter",
+        _fake_create,
+    )
+
+    resolved = scheduler._resolve_token_counter(provider)
+    assert resolved is fake_counter
+    assert called["mode"] == "tokenizer"
+    assert called["model"] == "gpt-4o"
+
+
+@pytest.mark.asyncio
+async def test_iter_candidate_conversations_does_not_apply_idle_filter_in_db_query() -> None:
+    scheduler = _build_scheduler(
+        {"periodic_context_compaction": {"enabled": True, "min_idle_minutes": 30}}
+    )
+
+    class _FakeDB:
+        def __init__(self) -> None:
+            self.updated_before_calls: list[datetime | None] = []
+
+        async def get_filtered_conversations(
+            self,
+            *,
+            page: int,
+            page_size: int,
+            updated_before: datetime | None,
+            min_messages: int,
+        ):
+            self.updated_before_calls.append(updated_before)
+            return [], 0
+
+    fake_db = _FakeDB()
+    scheduler.conversation_manager = SimpleNamespace(db=fake_db)  # type: ignore[assignment]
+
+    cfg = scheduler._load_config()
+    result = [
+        conv
+        async for conv in scheduler._iter_candidate_conversations(
+            scan_page_size=40,
+            cfg=cfg,
+        )
+    ]
+
+    assert result == []
+    assert fake_db.updated_before_calls == [None]
+
+
 @pytest.mark.asyncio
 async def test_compact_one_conversation_dry_run_reports_skipped(monkeypatch) -> None:
     scheduler = _build_scheduler({"periodic_context_compaction": {"enabled": True}})

From 58e4258dd3660f7fc983ca9698b689e55a09ea40 Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 02:20:53 -0500
Subject: [PATCH 27/29] simplify experimental context memory backend to unified protocol

---
 astrbot/core/context_memory_backends.py     |  2 +-
 .../context_memory_experimental_backends.py | 50 ++-----------
 tests/unit/test_context_memory.py           | 44 +++++++-----
 3 files changed, 35 insertions(+), 61 deletions(-)

diff --git a/astrbot/core/context_memory_backends.py b/astrbot/core/context_memory_backends.py
index 6f1fe8d00a..a508a3dca3 100644
--- a/astrbot/core/context_memory_backends.py
+++ b/astrbot/core/context_memory_backends.py
@@ -1,6 +1,6 @@
 """Compatibility re-exports for experimental context-memory backend hooks.
-Experimental protocol and bundle factory definitions live in
+Experimental protocol definitions live in
 `context_memory_experimental_backends.py` to keep extension points isolated from
 stable context-memory config logic.
 """
diff --git a/astrbot/core/context_memory_experimental_backends.py b/astrbot/core/context_memory_experimental_backends.py
index 690196c44d..cf8ec7689c 100644
--- a/astrbot/core/context_memory_experimental_backends.py
+++ b/astrbot/core/context_memory_experimental_backends.py
@@ -1,27 +1,11 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
 from typing import Any, Protocol, runtime_checkable
 
 
 @runtime_checkable
-class VectorLongTermMemoryRetriever(Protocol):
-    """Experimental protocol for future vector-DB long-term memory retrieval."""
-
-    async def retrieve(
-        self,
-        *,
-        unified_msg_origin: str,
-        query: str,
-        top_k: int,
-    ) -> list[str]:
-        """Return ranked memory snippets for prompt assembly."""
-        ...
-
-
-@runtime_checkable
-class ContextMemoryEvolutionBackend(Protocol):
-    """Experimental protocol for MemEvolve-style memory evolution integration."""
+class ContextMemoryBackend(Protocol):
+    """Experimental unified protocol for context-memory evolution + migration."""
 
     async def evolve(
         self,
@@ -43,11 +27,6 @@ async def retrieve(
         """Retrieve evolved memory snippets for prompt assembly."""
         ...
 
-
-@runtime_checkable
-class ContextMemoryMigrationAdapter(Protocol):
-    """Experimental protocol for future context-memory schema/store migration."""
-
     async def export_session(
         self,
         *,
@@ -66,29 +45,6 @@ async def import_session(
         ...
 
-@dataclass
-class ExperimentalContextMemoryBackends:
-    """Container for optional experimental backends."""
-
-    evolution_backend: ContextMemoryEvolutionBackend | None = None
-    migration_adapter: ContextMemoryMigrationAdapter | None = None
-
-
-def make_experimental_context_memory_backends(
-    *,
-    evolution_backend: ContextMemoryEvolutionBackend | None = None,
-    migration_adapter: ContextMemoryMigrationAdapter | None = None,
-) -> ExperimentalContextMemoryBackends:
-    """Create an experimental backend bundle without module-level mutable state."""
-    return ExperimentalContextMemoryBackends(
-        evolution_backend=evolution_backend,
-        migration_adapter=migration_adapter,
-    )
-
 __all__ = [
-    "VectorLongTermMemoryRetriever",
-    "ContextMemoryEvolutionBackend",
-    "ContextMemoryMigrationAdapter",
-    "ExperimentalContextMemoryBackends",
-    "make_experimental_context_memory_backends",
+    "ContextMemoryBackend",
 ]
diff --git a/tests/unit/test_context_memory.py b/tests/unit/test_context_memory.py
index fa5f3611b1..830ac7a1b7 100644
--- a/tests/unit/test_context_memory.py
+++ b/tests/unit/test_context_memory.py
@@ -6,7 +6,7 @@
     normalize_context_memory_settings,
 )
 from astrbot.core.context_memory_experimental_backends import (
-    make_experimental_context_memory_backends,
+    ContextMemoryBackend,
 )
 
 
@@ -20,20 +20,38 @@ def test_normalize_context_memory_settings_initializes_fresh_pinned_memories() -
     assert second["pinned_memories"] == []
 
 
-def test_context_memory_reserved_backend_registration() -> None:
-    backend = object()
-    adapter = object()
+def test_context_memory_unified_backend_protocol_shape() -> None:
+    class _Backend:
+        async def evolve(
+            self,
+            *,
+            unified_msg_origin: str,
+            turns: list[str],
+            metadata=None,
+        ) -> dict:
+            return {"ok": True}
 
-    backends = make_experimental_context_memory_backends(
-        evolution_backend=backend,  # type: ignore[arg-type]
-        migration_adapter=adapter,  # type: ignore[arg-type]
-    )
+        async def retrieve(
+            self,
+            *,
+            unified_msg_origin: str,
+            query: str,
+            top_k: int,
+        ) -> list[str]:
+            return []
+
+        async def export_session(self, *, unified_msg_origin: str) -> dict:
+            return {}
+
+        async def import_session(
+            self,
+            *,
+            unified_msg_origin: str,
+            payload: dict,
+        ) -> None:
+            return None
 
-    assert backends.evolution_backend is backend
-    assert backends.migration_adapter is adapter
+    assert isinstance(_Backend(), ContextMemoryBackend)
 
 
 def test_context_memory_defaults_follow_single_source() -> None:

From 85f2f23ec0d2f638f02f6c5925de48aa2e75515 Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 02:52:34 -0500
Subject: [PATCH 28/29] fix(builtin_commands): apply command_group before
 permission_type for ctxcompact/ctxmem

---
 astrbot/builtin_stars/builtin_commands/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py
index f2939f7540..14455a2bb7 100644
--- a/astrbot/builtin_stars/builtin_commands/main.py
+++ b/astrbot/builtin_stars/builtin_commands/main.py
@@ -132,8 +132,8 @@ async def provider(
         """查看或者切换 LLM Provider"""
         await self.provider_c.provider(event, idx, idx2)
 
-    @filter.permission_type(filter.PermissionType.ADMIN)
     @filter.command_group("ctxcompact")
+    @filter.permission_type(filter.PermissionType.ADMIN)
     def ctxcompact(self) -> None:
         """上下文定时压缩管理"""
 
@@ -153,8 +153,8 @@ async def ctxcompact_run(
         """手动触发一次上下文压缩(可选 limit 覆盖本次压缩会话数)"""
         await self.ctxcompact_c.run(event, limit)
 
-    @filter.permission_type(filter.PermissionType.ADMIN)
     @filter.command_group("ctxmem")
+    @filter.permission_type(filter.PermissionType.ADMIN)
     def ctxmem(self) -> None:
         """上下文记忆管理(手动顶层记忆)"""

From ef3e019c4be1de1c122ccc1e21c88194e91ac202 Mon Sep 17 00:00:00 2001
From: Jacobinwwey
Date: Fri, 20 Mar 2026 20:07:52 -0500
Subject: [PATCH 29/29] fix(context): honor post-tool compaction soft threshold

---
 astrbot/core/agent/context/manager.py         | 10 ++++++--
 .../agent/runners/tool_loop_agent_runner.py   |  1 +
 tests/agent/test_context_manager.py           | 23 +++++++++++++++++++
 3 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py
index 927b78a1fe..72d6ddfbbe 100644
--- a/astrbot/core/agent/context/manager.py
+++ b/astrbot/core/agent/context/manager.py
@@ -45,12 +45,18 @@ def __init__(
         )
 
     async def process(
-        self, messages: list[Message], trusted_token_usage: int = 0
+        self,
+        messages: list[Message],
+        trusted_token_usage: int = 0,
+        force_compaction: bool = False,
     ) -> list[Message]:
         """Process the messages.
 
         Args:
             messages: The original message list.
+            trusted_token_usage: Optional trusted token usage hint.
+            force_compaction: Force one compaction pass when token-based compaction
+                is enabled, regardless of the compressor threshold.
 
         Returns:
             The processed message list.
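A minimal usage sketch for the new flag (illustrative; `manager` is assumed to
be a ContextManager configured with token-based compaction and is not part of
this patch):

    # Request one unconditional compaction pass, e.g. right after a tool call:
    messages = await manager.process(
        run_context.messages,
        force_compaction=True,  # bypasses the should_compress() threshold gate
    )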
@@ -72,7 +78,7 @@ async def process(
                 result, trusted_token_usage
             )
 
-        if self.compressor.should_compress(
+        if force_compaction or self.compressor.should_compress(
             result, total_tokens, self.config.max_context_tokens
         ):
             result = await self._run_compression(result, total_tokens)
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py
index 6a50974a78..2b4700f31c 100644
--- a/astrbot/core/agent/runners/tool_loop_agent_runner.py
+++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py
@@ -748,6 +748,7 @@ async def step(self):
             if self._should_run_post_tool_compaction():
                 self.run_context.messages = await self.context_manager.process(
                     self.run_context.messages,
+                    force_compaction=True,
                 )
                 self._refresh_tool_compaction_baseline()
 
diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py
index d5c6c8c1f1..5f78028111 100644
--- a/tests/agent/test_context_manager.py
+++ b/tests/agent/test_context_manager.py
@@ -222,6 +222,29 @@ async def test_token_compression_not_triggered_below_threshold(self):
         mock_compress.assert_not_called()
         assert result == messages
 
+    @pytest.mark.asyncio
+    async def test_force_compaction_bypasses_threshold_gate(self):
+        """Test that force_compaction bypasses the compressor threshold gate."""
+        config = ContextConfig(max_context_tokens=1000)
+        manager = ContextManager(config)
+
+        messages = [self.create_message("user", "Hello")]
+
+        with patch.object(
+            manager.compressor, "should_compress", return_value=False
+        ) as mock_should_compress:
+            with patch.object(
+                manager,
+                "_run_compression",
+                new_callable=AsyncMock,
+                return_value=messages,
+            ) as mock_run_compression:
+                result = await manager.process(messages, force_compaction=True)
+
+        mock_should_compress.assert_not_called()
+        mock_run_compression.assert_called_once()
+        assert result == messages
+
     @pytest.mark.asyncio
     async def test_token_compression_triggered_above_threshold(self):
         """Test that compression is triggered above threshold."""