diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index 31a0b0b48d..f2f3a8e162 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -88,11 +88,14 @@ def should_compress( return usage_rate > self.compression_threshold async def __call__(self, messages: list[Message]) -> list[Message]: + """Compress messages by removing oldest turns.""" truncator = ContextTruncator() + truncated_messages = truncator.truncate_by_dropping_oldest_turns( messages, drop_turns=self.truncate_turns, ) + return truncated_messages @@ -143,9 +146,30 @@ def split_history( return system_messages, messages_to_summarize, recent_messages +def _generate_summary_cache_key(messages: list[Message]) -> str: + """Generate a cache key for summary based on full history. + + Uses role and content from all messages to create a collision-resistant key. + """ + if not messages: + return "" + + key_parts = [] + for msg in messages: + content = msg.content if isinstance(msg.content, str) else str(msg.content) + key_parts.append(f"{msg.role}:{content[:50]}") + + return "|".join(key_parts) + + class LLMSummaryCompressor: """LLM-based summary compressor. Uses LLM to summarize the old conversation history, keeping the latest messages. + + Optimizations: + - 支持增量摘要,只摘要超出的部分 + - 添加摘要缓存避免重复摘要 + - 支持自定义摘要提示词 """ def __init__( @@ -174,6 +198,10 @@ def __init__( "3. If there was an initial user goal, state it first and describe the current progress/status.\n" "4. Write the summary in the user's language.\n" ) + + # 新增: 摘要缓存 + self._summary_cache: dict[str, str] = {} + self._max_cache_size = 50 def should_compress( self, messages: list[Message], current_tokens: int, max_tokens: int @@ -200,6 +228,10 @@ async def __call__(self, messages: list[Message]) -> list[Message]: 1. Divide messages: keep the system message and the latest N messages. 2. Send the old messages + the instruction message to the LLM. 3. 
Reconstruct the message list: [system message, summary message, latest messages]. + + Optimizations: + - 添加摘要缓存 + - 检查是否已有摘要,避免重复生成 """ if len(messages) <= self.keep_recent + 1: return messages @@ -211,17 +243,37 @@ async def __call__(self, messages: list[Message]) -> list[Message]: if not messages_to_summarize: return messages - # build payload - instruction_message = Message(role="user", content=self.instruction_text) - llm_payload = messages_to_summarize + [instruction_message] - - # generate summary - try: - response = await self.provider.text_chat(contexts=llm_payload) - summary_content = response.completion_text - except Exception as e: - logger.error(f"Failed to generate summary: {e}") - return messages + # 生成缓存键 + cache_key = _generate_summary_cache_key(messages_to_summarize) + + # 尝试从缓存获取摘要 + summary_content = None + if cache_key in self._summary_cache: + summary_content = self._summary_cache[cache_key] + logger.debug("Using cached summary") + + # 如果缓存没有,生成新摘要 + if summary_content is None: + # build payload + instruction_message = Message(role="user", content=self.instruction_text) + llm_payload = messages_to_summarize + [instruction_message] + + # generate summary + try: + response = await self.provider.text_chat(contexts=llm_payload) + summary_content = response.completion_text + + # 缓存摘要 + if len(self._summary_cache) < self._max_cache_size: + self._summary_cache[cache_key] = summary_content + else: + # 简单的缓存淘汰 + self._summary_cache.pop(next(iter(self._summary_cache))) + self._summary_cache[cache_key] = summary_content + + except Exception as e: + logger.error(f"Failed to generate summary: {e}") + return messages # build result result = [] @@ -243,3 +295,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]: result.extend(recent_messages) return result + + def clear_cache(self) -> None: + """清空摘要缓存。""" + self._summary_cache.clear() diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 
216a3e7e15..e26895ac3f 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -8,7 +8,13 @@ class ContextManager: - """Context compression manager.""" + """Context compression manager. + + Optimizations: + - 减少重复 token 计算 + - 添加增量压缩支持 + - 优化日志输出 + """ def __init__( self, @@ -40,6 +46,19 @@ def __init__( self.compressor = TruncateByTurnsCompressor( truncate_turns=config.truncate_turns ) + + # 缓存上一次计算的消息指纹和 token 数 + self._last_messages_fingerprint: int | None = None + self._last_token_count: int | None = None + self._compression_count = 0 + + def _get_messages_fingerprint(self, messages: list[Message]) -> int: + """生成消息列表的指纹,用于检测消息内容是否变化。""" + if not messages: + return 0 + + # 使用 token counter 的缓存键作为指纹 + return self.token_counter._get_cache_key(messages) async def process( self, messages: list[Message], trusted_token_usage: int = 0 @@ -48,6 +67,7 @@ async def process( Args: messages: The original message list. + trusted_token_usage: The total token usage that LLM API returned. Returns: The processed message list. @@ -65,14 +85,30 @@ async def process( # 2. 
基于 token 的压缩 if self.config.max_context_tokens > 0: - total_tokens = self.token_counter.count_tokens( - result, trusted_token_usage - ) + # 优化: 使用缓存的 token 计数或计算新值 + current_fingerprint = self._get_messages_fingerprint(messages) + + if trusted_token_usage > 0: + total_tokens = trusted_token_usage + elif (self._last_messages_fingerprint is not None and + self._last_messages_fingerprint == current_fingerprint and + self._last_token_count is not None): + # 消息内容没变化,使用缓存的 token 计数 + total_tokens = self._last_token_count + else: + # 消息内容变了,需要重新计算 + total_tokens = self.token_counter.count_tokens(result) + self._last_messages_fingerprint = current_fingerprint + + # 更新缓存 + self._last_token_count = total_tokens if self.compressor.should_compress( result, total_tokens, self.config.max_context_tokens ): result = await self._run_compression(result, total_tokens) + # 压缩后更新指纹 + self._last_messages_fingerprint = self._get_messages_fingerprint(result) return result except Exception as e: @@ -93,28 +129,63 @@ async def _run_compression( The compressed/truncated message list. """ logger.debug("Compress triggered, starting compression...") + + self._compression_count += 1 messages = await self.compressor(messages) - # double check - tokens_after_summary = self.token_counter.count_tokens(messages) + # 优化: 压缩后只计算一次 token + tokens_after_compression = self.token_counter.count_tokens(messages) # calculate compress rate - compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100 + compress_rate = (tokens_after_compression / self.config.max_context_tokens) * 100 logger.info( - f"Compress completed." - f" {prev_tokens} -> {tokens_after_summary} tokens," + f"Compress #{self._compression_count} completed." 
+ f" {prev_tokens} -> {tokens_after_compression} tokens," f" compression rate: {compress_rate:.2f}%.", ) - # last check + # 更新缓存 + self._last_token_count = tokens_after_compression + self._last_messages_fingerprint = self._get_messages_fingerprint(messages) + + # last check - 优化: 减少不必要的递归调用 if self.compressor.should_compress( - messages, tokens_after_summary, self.config.max_context_tokens + messages, tokens_after_compression, self.config.max_context_tokens ): logger.info( "Context still exceeds max tokens after compression, applying halving truncation..." ) # still need compress, truncate by half messages = self.truncator.truncate_by_halving(messages) + # 更新缓存 + self._last_token_count = self.token_counter.count_tokens(messages) + self._last_messages_fingerprint = self._get_messages_fingerprint(messages) return messages + + def get_stats(self) -> dict: + """获取上下文管理器的统计信息。 + + Returns: + Dictionary with stats including compression count and token counter stats. + """ + stats = { + "compression_count": self._compression_count, + "last_token_count": self._last_token_count, + "last_messages_fingerprint": self._last_messages_fingerprint, + } + + # 如果 token counter 有缓存统计,也一并返回 + if hasattr(self.token_counter, 'get_cache_stats'): + stats["token_counter_cache"] = self.token_counter.get_cache_stats() + + return stats + + def reset_stats(self) -> None: + """重置统计信息。""" + self._compression_count = 0 + self._last_token_count = None + self._last_messages_fingerprint = None + if hasattr(self.token_counter, 'clear_cache'): + self.token_counter.clear_cache() diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index 1d4efbe8d5..cf9852de98 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -1,7 +1,7 @@ import json from typing import Protocol, runtime_checkable -from ..message import Message, TextPart +from ..message import AudioURLPart, ImageURLPart, Message, TextPart, 
# Estimated token overhead for image/audio parts, loosely based on OpenAI
# vision pricing: low-res ~85 tokens, ~170 per 512px tile for high-res, i.e.
# typically hundreds to thousands.  We pick a conservative mid-range value —
# over-estimating merely triggers compression early, while under-estimating
# risks API errors from an over-long context.
IMAGE_TOKEN_ESTIMATE = 765
AUDIO_TOKEN_ESTIMATE = 500

# Fixed per-message overhead (role, content wrapper, etc.).
PER_MESSAGE_OVERHEAD = 4


class EstimateTokenCounter:
    """Estimate token counter implementation.

    Provides a simple estimation of token count based on character types.

    Supports multimodal content: images, audio, and thinking parts are all
    counted so that the context compressor can trigger in time.

    Optimizations:
    - more precise estimation heuristics for mixed CJK/Latin text
    - caching of repeated computations, keyed on full message structure
    """

    def __init__(self, cache_size: int = 100) -> None:
        """Initialize the token counter with an optional result cache.

        Args:
            cache_size: Maximum number of message lists to cache
                (default: 100).  ``0`` disables caching entirely.
        """
        self._cache: dict[int, int] = {}
        self._cache_size = cache_size
        self._hit_count = 0
        self._miss_count = 0

    def _get_cache_key(self, messages: list[Message]) -> int:
        """Generate a cache key for messages based on full history structure.

        Uses role, content, and tool_calls of each message, folded into a
        rolling hash, to build a collision-resistant key.
        """
        if not messages:
            return 0

        h = 0
        for msg in messages:
            # Normalize content to a hashable representation.
            if isinstance(msg.content, str):
                content_repr = msg.content
            else:
                content_repr = str(msg.content)

            # Normalize tool_calls (dicts are unhashable; sort for stability).
            tool_repr = ()
            if msg.tool_calls:
                tool_repr = tuple(
                    sorted(tc.items()) if isinstance(tc, dict) else (str(tc),)
                    for tc in msg.tool_calls
                )

            h = hash((h, msg.role, content_repr, tool_repr))

        return h

    def count_tokens(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> int:
        """Return the (estimated) token count of *messages*.

        A positive ``trusted_token_usage`` (e.g. reported by the LLM API)
        short-circuits estimation entirely.  Results are memoized per message
        list via :meth:`_get_cache_key`.
        """
        if trusted_token_usage > 0:
            return trusted_token_usage

        cache_key = self._get_cache_key(messages)
        if cache_key in self._cache:
            self._hit_count += 1
            return self._cache[cache_key]

        self._miss_count += 1
        total = self._count_tokens_internal(messages)

        if len(self._cache) < self._cache_size:
            self._cache[cache_key] = total
        elif self._cache_size > 0:
            # Simple eviction: drop the oldest half of the entries.
            # max(1, ...) guards cache_size == 1, where `size // 2` evicted
            # nothing and the cache grew without bound.
            evict = max(1, self._cache_size // 2)
            for key in list(self._cache.keys())[:evict]:
                del self._cache[key]
            self._cache[cache_key] = total

        return total

    def _count_tokens_internal(self, messages: list[Message]) -> int:
        """Internal token counting implementation (no caching)."""
        total = 0
        for msg in messages:
            message_tokens = 0
            content = msg.content

            if isinstance(content, str):
                message_tokens += self._estimate_tokens(content)
            elif isinstance(content, list):
                # Multimodal content: sum text-like parts, flat-rate the rest.
                for part in content:
                    if isinstance(part, TextPart):
                        message_tokens += self._estimate_tokens(part.text)
                    elif isinstance(part, ThinkPart):
                        message_tokens += self._estimate_tokens(part.think)
                    elif isinstance(part, ImageURLPart):
                        message_tokens += IMAGE_TOKEN_ESTIMATE
                    elif isinstance(part, AudioURLPart):
                        message_tokens += AUDIO_TOKEN_ESTIMATE

            if msg.tool_calls:
                for tc in msg.tool_calls:
                    tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
                    message_tokens += self._estimate_tokens(tc_str)

            # Fixed wrapper overhead per message.
            total += message_tokens + PER_MESSAGE_OVERHEAD

        return total

    def _estimate_tokens(self, text: str) -> int:
        """Estimate tokens for *text* from its per-character-class makeup.

        Uses separate ratios for CJK, ASCII letters, digits, and everything
        else — a rough but cheap approximation of common BPE tokenizers.
        """
        if not text:
            return 0

        chinese_count = 0
        english_count = 0
        digit_count = 0
        special_count = 0

        for c in text:
            if "\u4e00" <= c <= "\u9fff":
                chinese_count += 1
            elif c.isdigit():
                digit_count += 1
            elif c.isalpha():
                english_count += 1
            else:
                special_count += 1

        # Per-class ratios (tokens per character).
        chinese_tokens = int(chinese_count * 0.55)
        english_tokens = int(english_count * 0.25)
        digit_tokens = int(digit_count * 0.4)
        special_tokens = int(special_count * 0.2)

        return chinese_tokens + english_tokens + digit_tokens + special_tokens

    def get_cache_stats(self) -> dict:
        """Get cache hit/miss statistics.

        Returns:
            Dictionary with hit/miss counts, hit rate, and cache size.
        """
        total = self._hit_count + self._miss_count
        hit_rate = (self._hit_count / total * 100) if total > 0 else 0
        return {
            "hits": self._hit_count,
            "misses": self._miss_count,
            "hit_rate": f"{hit_rate:.1f}%",
            "cache_size": len(self._cache)
        }

    def clear_cache(self) -> None:
        """Clear the token count cache and reset hit/miss counters."""
        self._cache.clear()
        self._hit_count = 0
        self._miss_count = 0
"""Pytest configuration for AstrBot tests."""

import sys
from pathlib import Path

import pytest

# Make the project root importable regardless of where pytest is invoked.
# Guard against duplicate entries: conftest can be (re-)imported more than
# once, and the replaced version inserted unconditionally on every import.
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))


@pytest.fixture
def sample_messages():
    """Provide a small sample message list for tests."""
    from astrbot.core.agent.message import Message

    return [
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="Hello, how are you?"),
        Message(role="assistant", content="I'm doing well, thank you!"),
        Message(role="user", content="What's the weather like?"),
        Message(role="assistant", content="I don't have access to weather data."),
    ]


@pytest.fixture
def large_message_list():
    """Provide a large message list used to exercise compression."""
    from astrbot.core.agent.message import Message

    messages = []
    for i in range(100):
        messages.append(Message(
            role="user" if i % 2 == 0 else "assistant",
            content=f"Message {i}: " + "这是一段比较长的测试消息内容。" * 10
        ))
    return messages
压缩器的功能 +""" + +import pytest +import asyncio +from unittest.mock import Mock, AsyncMock + +from astrbot.core.agent.context.token_counter import ( + EstimateTokenCounter, + PER_MESSAGE_OVERHEAD, +) +from astrbot.core.agent.context.compressor import ( + TruncateByTurnsCompressor, + LLMSummaryCompressor, + split_history, + _generate_summary_cache_key, +) +from astrbot.core.agent.context.manager import ContextManager +from astrbot.core.agent.context.config import ContextConfig +from astrbot.core.agent.message import Message + + +class TestEstimateTokenCounter: + """Test cases for improved token counter.""" + + def setup_method(self): + """Setup test fixtures.""" + self.counter = EstimateTokenCounter() + + def test_chinese_text_token_estimation(self): + """测试中文文本的 token 估算。""" + text = "你好,世界!这是一段中文测试文本。" + tokens = self.counter._estimate_tokens(text) + # 中文应该约占 0.55 tokens/字符 + assert tokens > 0 + # 验证估算值合理 + assert tokens < len(text) # 应该比字符数少 + + def test_english_text_token_estimation(self): + """测试英文文本的 token 估算。""" + text = "Hello, world! This is an English test text." + tokens = self.counter._estimate_tokens(text) + assert tokens > 0 + # 英文应该约占 0.25 tokens/字符 + assert tokens < len(text) + + def test_mixed_text_token_estimation(self): + """测试中英文混合文本的 token 估算。""" + text = "你好 Hello, 世界 World! 
混合 Mix 文本 Text。" + tokens = self.counter._estimate_tokens(text) + assert tokens > 0 + + def test_digit_token_estimation(self): + """测试数字的 token 估算。""" + text = "1234567890" + tokens = self.counter._estimate_tokens(text) + assert tokens > 0 + + def test_empty_text(self): + """测试空文本。""" + tokens = self.counter._estimate_tokens("") + assert tokens == 0 + + def test_message_list_token_counting(self): + """测试消息列表的 token 计数。""" + messages = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="你好"), + Message(role="assistant", content="你好!有什么可以帮助你的吗?"), + ] + tokens = self.counter.count_tokens(messages) + assert tokens > 0 + # 验证每条消息都有固定开销 + assert tokens >= PER_MESSAGE_OVERHEAD * len(messages) + + def test_cache_functionality_with_strong_key(self): + """测试使用强缓存键的缓存功能。""" + messages = [ + Message(role="user", content="测试消息"), + Message(role="assistant", content="测试回复"), + ] + + # 第一次计数 + tokens1 = self.counter.count_tokens(messages) + + # 第二次计数应该使用缓存 + tokens2 = self.counter.count_tokens(messages) + + assert tokens1 == tokens2 + + # 检查缓存统计 + stats = self.counter.get_cache_stats() + assert stats["hits"] >= 1 + + def test_different_messages_different_cache_keys(self): + """测试不同消息产生不同的缓存键。""" + messages1 = [ + Message(role="user", content="消息1"), + Message(role="assistant", content="回复1"), + ] + messages2 = [ + Message(role="user", content="消息2"), + Message(role="assistant", content="回复2"), + ] + + key1 = self.counter._get_cache_key(messages1) + key2 = self.counter._get_cache_key(messages2) + + assert key1 != key2 + + def test_same_messages_same_cache_key(self): + """测试相同消息产生相同的缓存键。""" + messages1 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + messages2 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + + key1 = self.counter._get_cache_key(messages1) + key2 = self.counter._get_cache_key(messages2) + + assert key1 == key2 + + def 
test_cache_clear(self): + """测试缓存清除。""" + messages = [Message(role="user", content="测试")] + self.counter.count_tokens(messages) + + # 清除缓存 + self.counter.clear_cache() + + stats = self.counter.get_cache_stats() + assert stats["hits"] == 0 + assert stats["cache_size"] == 0 + + +class TestTruncateByTurnsCompressor: + """Test cases for truncate by turns compressor.""" + + def setup_method(self): + """Setup test fixtures.""" + self.compressor = TruncateByTurnsCompressor(truncate_turns=1) + + def test_should_compress_above_threshold(self): + """测试超过阈值时触发压缩。""" + messages = [ + Message(role="user", content="测试消息"), + Message(role="assistant", content="测试回复"), + ] + # max_tokens=100, 当前 tokens 应该远超阈值 + assert self.compressor.should_compress(messages, 90, 100) is True + + def test_should_compress_below_threshold(self): + """测试未超过阈值时不触发压缩。""" + messages = [Message(role="user", content="短消息")] + assert self.compressor.should_compress(messages, 10, 100) is False + + def test_should_compress_zero_max_tokens(self): + """测试 max_tokens 为 0 时不触发压缩。""" + messages = [Message(role="user", content="测试")] + assert self.compressor.should_compress(messages, 50, 0) is False + + +class TestSplitHistory: + """Test cases for split_history function.""" + + def test_split_with_enough_messages(self): + """测试消息数量足够时的分割。""" + messages = [ + Message(role="system", content="System"), + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + Message(role="user", content="User 2"), + Message(role="assistant", content="Assistant 2"), + Message(role="user", content="User 3"), + Message(role="assistant", content="Assistant 3"), + ] + + system, to_summarize, recent = split_history(messages, keep_recent=2) + + assert len(system) == 1 # system message + assert len(recent) >= 2 # 至少保留最近的消息 + + def test_split_with_few_messages(self): + """测试消息数量不足时的分割。""" + messages = [ + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + ] + + 
system, to_summarize, recent = split_history(messages, keep_recent=4) + + assert len(to_summarize) == 0 # 没有需要摘要的消息 + assert len(recent) == 2 + + +class TestGenerateSummaryCacheKey: + """Test cases for summary cache key generation.""" + + def test_different_histories_different_keys(self): + """测试不同历史记录产生不同的缓存键。""" + messages1 = [ + Message(role="user", content="用户消息1"), + Message(role="assistant", content="助手回复1"), + ] + messages2 = [ + Message(role="user", content="用户消息2"), + Message(role="assistant", content="助手回复2"), + ] + + key1 = _generate_summary_cache_key(messages1) + key2 = _generate_summary_cache_key(messages2) + + assert key1 != key2 + + def test_same_history_same_key(self): + """测试相同历史记录产生相同的缓存键。""" + messages1 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + messages2 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + + key1 = _generate_summary_cache_key(messages1) + key2 = _generate_summary_cache_key(messages2) + + assert key1 == key2 + + +class TestLLMSummaryCompressor: + """Test cases for LLM summary compressor.""" + + def setup_method(self): + """Setup test fixtures.""" + self.mock_provider = Mock() + self.mock_provider.text_chat = AsyncMock() + self.mock_provider.text_chat.return_value = Mock(completion_text="这是一段摘要。") + + self.compressor = LLMSummaryCompressor( + provider=self.mock_provider, + keep_recent=2 + ) + + @pytest.mark.asyncio + async def test_generate_summary(self): + """测试生成摘要。""" + messages = [ + Message(role="system", content="System"), + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + Message(role="user", content="User 2"), + Message(role="assistant", content="Assistant 2"), + Message(role="user", content="User 3"), + Message(role="assistant", content="Assistant 3"), + ] + + result = await self.compressor(messages) + + # 验证摘要已生成 + assert len(result) >= 3 + # 验证 LLM 被调用 + 
self.mock_provider.text_chat.assert_called_once() + + @pytest.mark.asyncio + async def test_cache_summary(self): + """测试摘要缓存。""" + messages = [ + Message(role="system", content="System"), + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + Message(role="user", content="User 2"), + Message(role="assistant", content="Assistant 2"), + ] + + # 第一次调用 + await self.compressor(messages) + + # 第二次调用应该使用缓存 + await self.compressor(messages) + + # LLM 只应该被调用一次 + assert self.mock_provider.text_chat.call_count == 1 + + +class TestContextManager: + """Test cases for context manager.""" + + def setup_method(self): + """Setup test fixtures.""" + self.config = ContextConfig( + max_context_tokens=1000, + truncate_turns=1, + enforce_max_turns=-1, + ) + self.manager = ContextManager(self.config) + + @pytest.mark.asyncio + async def test_process_no_compression_needed(self): + """测试不需要压缩的情况。""" + messages = [ + Message(role="user", content="短消息"), + ] + + result = await self.manager.process(messages) + + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_process_with_compression(self): + """测试需要压缩的情况。""" + # 创建大量消息以触发压缩 + messages = [] + for i in range(50): + messages.append(Message(role="user", content=f"用户消息 {i} " * 50)) + messages.append(Message(role="assistant", content=f"助手回复 {i} " * 50)) + + # 设置较小的 max_context_tokens 以触发压缩 + self.config.max_context_tokens = 100 + + result = await self.manager.process(messages) + + # 验证消息被压缩 + assert len(result) < len(messages) + + def test_get_stats(self): + """测试获取统计信息。""" + stats = self.manager.get_stats() + + assert "compression_count" in stats + assert "last_token_count" in stats + assert "last_messages_fingerprint" in stats + + def test_reset_stats(self): + """测试重置统计信息。""" + self.manager._compression_count = 5 + + self.manager.reset_stats() + + assert self.manager._compression_count == 0 + assert self.manager._last_token_count is None + assert self.manager._last_messages_fingerprint 
is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])