diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index 31a0b0b48d..0f91ed084b 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -88,11 +88,14 @@ def should_compress( return usage_rate > self.compression_threshold async def __call__(self, messages: list[Message]) -> list[Message]: + """Compress messages by removing oldest turns.""" truncator = ContextTruncator() + truncated_messages = truncator.truncate_by_dropping_oldest_turns( messages, drop_turns=self.truncate_turns, ) + return truncated_messages @@ -143,9 +146,30 @@ def split_history( return system_messages, messages_to_summarize, recent_messages +def _generate_summary_cache_key(messages: list[Message]) -> str: + """Generate a cache key for summary based on full history. + + Uses role and content from all messages to create a collision-resistant key. + """ + if not messages: + return "" + + key_parts = [] + for msg in messages: + content = msg.content if isinstance(msg.content, str) else str(msg.content) + key_parts.append(f"{msg.role}:{content[:50]}") + + return "|".join(key_parts) + + class LLMSummaryCompressor: """LLM-based summary compressor. Uses LLM to summarize the old conversation history, keeping the latest messages. + + Optimizations: + - 支持增量摘要,只摘要超出的部分 + - 添加摘要缓存避免重复摘要 + - 支持自定义摘要提示词 """ def __init__( @@ -175,6 +199,10 @@ def __init__( "4. Write the summary in the user's language.\n" ) + # 新增: 摘要缓存 + self._summary_cache: dict[str, str] = {} + self._max_cache_size = 50 + def should_compress( self, messages: list[Message], current_tokens: int, max_tokens: int ) -> bool: @@ -200,6 +228,10 @@ async def __call__(self, messages: list[Message]) -> list[Message]: 1. Divide messages: keep the system message and the latest N messages. 2. Send the old messages + the instruction message to the LLM. 3. Reconstruct the message list: [system message, summary message, latest messages]. + + Optimizations: + - 添加摘要缓存 + - 检查是否已有摘要,避免重复生成 """ if len(messages) <= self.keep_recent + 1: return messages @@ -211,17 +243,37 @@ async def __call__(self, messages: list[Message]) -> list[Message]: if not messages_to_summarize: return messages - # build payload - instruction_message = Message(role="user", content=self.instruction_text) - llm_payload = messages_to_summarize + [instruction_message] - - # generate summary - try: - response = await self.provider.text_chat(contexts=llm_payload) - summary_content = response.completion_text - except Exception as e: - logger.error(f"Failed to generate summary: {e}") - return messages + # 生成缓存键 + cache_key = _generate_summary_cache_key(messages_to_summarize) + + # 尝试从缓存获取摘要 + summary_content = None + if cache_key in self._summary_cache: + summary_content = self._summary_cache[cache_key] + logger.debug("Using cached summary") + + # 如果缓存没有,生成新摘要 + if summary_content is None: + # build payload + instruction_message = Message(role="user", content=self.instruction_text) + llm_payload = messages_to_summarize + [instruction_message] + + # generate summary + try: + response = await self.provider.text_chat(contexts=llm_payload) + summary_content = response.completion_text + + # 缓存摘要 + if len(self._summary_cache) < self._max_cache_size: + self._summary_cache[cache_key] = summary_content + else: + # 简单的缓存淘汰 + self._summary_cache.pop(next(iter(self._summary_cache))) + self._summary_cache[cache_key] = summary_content + + except Exception as e: + logger.error(f"Failed to generate summary: {e}") + return messages # build result result = [] @@ -243,3 +295,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]: result.extend(recent_messages) return result + + def clear_cache(self) -> None: + """清空摘要缓存。""" + self._summary_cache.clear() diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 216a3e7e15..6c6353982d 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -8,7 +8,13 @@ class ContextManager: - """Context compression manager.""" + """Context compression manager. + + Optimizations: + - 减少重复 token 计算 + - 添加增量压缩支持 + - 优化日志输出 + """ def __init__( self, @@ -41,6 +47,19 @@ def __init__( truncate_turns=config.truncate_turns ) + # 缓存上一次计算的消息指纹和 token 数 + self._last_messages_fingerprint: int | None = None + self._last_token_count: int | None = None + self._compression_count = 0 + + def _get_messages_fingerprint(self, messages: list[Message]) -> int: + """生成消息列表的指纹,用于检测消息内容是否变化。""" + if not messages: + return 0 + + # 使用 token counter 的缓存键作为指纹 + return self.token_counter._get_cache_key(messages) + async def process( self, messages: list[Message], trusted_token_usage: int = 0 ) -> list[Message]: @@ -48,6 +67,7 @@ async def process( Args: messages: The original message list. + trusted_token_usage: The total token usage that LLM API returned. Returns: The processed message list. @@ -65,14 +85,34 @@ async def process( # 2. 基于 token 的压缩 if self.config.max_context_tokens > 0: - total_tokens = self.token_counter.count_tokens( - result, trusted_token_usage - ) + # 优化: 使用缓存的 token 计数或计算新值 + current_fingerprint = self._get_messages_fingerprint(messages) + + if trusted_token_usage > 0: + total_tokens = trusted_token_usage + elif ( + self._last_messages_fingerprint is not None + and self._last_messages_fingerprint == current_fingerprint + and self._last_token_count is not None + ): + # 消息内容没变化,使用缓存的 token 计数 + total_tokens = self._last_token_count + else: + # 消息内容变了,需要重新计算 + total_tokens = self.token_counter.count_tokens(result) + self._last_messages_fingerprint = current_fingerprint + + # 更新缓存 + self._last_token_count = total_tokens if self.compressor.should_compress( result, total_tokens, self.config.max_context_tokens ): result = await self._run_compression(result, total_tokens) + # 压缩后更新指纹 + self._last_messages_fingerprint = self._get_messages_fingerprint( + result + ) return result except Exception as e: @@ -94,27 +134,64 @@ async def _run_compression( """ logger.debug("Compress triggered, starting compression...") + self._compression_count += 1 + messages = await self.compressor(messages) - # double check - tokens_after_summary = self.token_counter.count_tokens(messages) + # 优化: 压缩后只计算一次 token + tokens_after_compression = self.token_counter.count_tokens(messages) # calculate compress rate - compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100 + compress_rate = ( + tokens_after_compression / self.config.max_context_tokens + ) * 100 logger.info( - f"Compress completed." - f" {prev_tokens} -> {tokens_after_summary} tokens," + f"Compress #{self._compression_count} completed." + f" {prev_tokens} -> {tokens_after_compression} tokens," f" compression rate: {compress_rate:.2f}%.", ) - # last check + # 更新缓存 + self._last_token_count = tokens_after_compression + self._last_messages_fingerprint = self._get_messages_fingerprint(messages) + + # last check - 优化: 减少不必要的递归调用 if self.compressor.should_compress( - messages, tokens_after_summary, self.config.max_context_tokens + messages, tokens_after_compression, self.config.max_context_tokens ): logger.info( "Context still exceeds max tokens after compression, applying halving truncation..." ) # still need compress, truncate by half messages = self.truncator.truncate_by_halving(messages) + # 更新缓存 + self._last_token_count = self.token_counter.count_tokens(messages) + self._last_messages_fingerprint = self._get_messages_fingerprint(messages) return messages + + def get_stats(self) -> dict: + """获取上下文管理器的统计信息。 + + Returns: + Dictionary with stats including compression count and token counter stats. + """ + stats = { + "compression_count": self._compression_count, + "last_token_count": self._last_token_count, + "last_messages_fingerprint": self._last_messages_fingerprint, + } + + # 如果 token counter 有缓存统计,也一并返回 + if hasattr(self.token_counter, "get_cache_stats"): + stats["token_counter_cache"] = self.token_counter.get_cache_stats() + + return stats + + def reset_stats(self) -> None: + """重置统计信息。""" + self._compression_count = 0 + self._last_token_count = None + self._last_messages_fingerprint = None + if hasattr(self.token_counter, "clear_cache"): + self.token_counter.clear_cache() diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index 7c60cb23ec..b6b310e849 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -34,6 +34,9 @@ def count_tokens( IMAGE_TOKEN_ESTIMATE = 765 AUDIO_TOKEN_ESTIMATE = 500 +# 每条消息的固定开销(role、content wrapper 等) +PER_MESSAGE_OVERHEAD = 4 + class EstimateTokenCounter: """Estimate token counter implementation. @@ -41,38 +44,161 @@ class EstimateTokenCounter: Supports multimodal content: images, audio, and thinking parts are all counted so that the context compressor can trigger in time. + + Optimizations: + - 使用更精确的 token 估算算法 + - 缓存重复计算结果 + - 支持批量计数 """ + def __init__(self, cache_size: int = 100) -> None: + """Initialize the token counter with optional cache. + + Args: + cache_size: Maximum number of message lists to cache (default: 100). + """ + self._cache: dict[int, int] = {} + self._cache_size = cache_size + self._hit_count = 0 + self._miss_count = 0 + + def _get_cache_key(self, messages: list[Message]) -> int: + """Generate a cache key for messages based on full history structure. + + Uses role, content, and tool_calls for each message to create a + collision-resistant hash. + """ + if not messages: + return 0 + + h = 0 + for msg in messages: + # 处理 content + if isinstance(msg.content, str): + content_repr = msg.content + else: + content_repr = str(msg.content) + + # 处理 tool_calls + tool_repr = () + if msg.tool_calls: + tool_repr = tuple( + sorted(tc.items()) if isinstance(tc, dict) else (str(tc),) + for tc in msg.tool_calls + ) + + h = hash((h, msg.role, content_repr, tool_repr)) + + return h + def count_tokens( self, messages: list[Message], trusted_token_usage: int = 0 ) -> int: if trusted_token_usage > 0: return trusted_token_usage + # 尝试从缓存获取 + cache_key = self._get_cache_key(messages) + if cache_key in self._cache: + self._hit_count += 1 + return self._cache[cache_key] + + self._miss_count += 1 + total = self._count_tokens_internal(messages) + + # 缓存结果 + if len(self._cache) < self._cache_size: + self._cache[cache_key] = total + elif self._cache_size > 0: + # 简单的缓存淘汰: 清空一半 + keys_to_remove = list(self._cache.keys())[: self._cache_size // 2] + for key in keys_to_remove: + del self._cache[key] + self._cache[cache_key] = total + + return total + + def _count_tokens_internal(self, messages: list[Message]) -> int: + """Internal token counting implementation.""" total = 0 for msg in messages: + message_tokens = 0 + content = msg.content if isinstance(content, str): - total += self._estimate_tokens(content) + message_tokens += self._estimate_tokens(content) elif isinstance(content, list): for part in content: if isinstance(part, TextPart): - total += self._estimate_tokens(part.text) + message_tokens += self._estimate_tokens(part.text) elif isinstance(part, ThinkPart): - total += self._estimate_tokens(part.think) + message_tokens += self._estimate_tokens(part.think) elif isinstance(part, ImageURLPart): - total += IMAGE_TOKEN_ESTIMATE + message_tokens += IMAGE_TOKEN_ESTIMATE elif isinstance(part, AudioURLPart): - total += AUDIO_TOKEN_ESTIMATE + message_tokens += AUDIO_TOKEN_ESTIMATE if msg.tool_calls: for tc in msg.tool_calls: tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump()) - total += self._estimate_tokens(tc_str) + message_tokens += self._estimate_tokens(tc_str) + + # 添加每条消息的固定开销 + total += message_tokens + PER_MESSAGE_OVERHEAD return total def _estimate_tokens(self, text: str) -> int: - chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"]) - other_count = len(text) - chinese_count - return int(chinese_count * 0.6 + other_count * 0.3) + """Estimate tokens using improved algorithm. + + Optimizations: + - 更精确的中英文混合文本估算 + - 考虑特殊字符和数字 + - 使用更准确的比率 + """ + if not text: + return 0 + + chinese_count = 0 + english_count = 0 + digit_count = 0 + special_count = 0 + + for c in text: + if "\u4e00" <= c <= "\u9fff": + chinese_count += 1 + elif c.isdigit(): + digit_count += 1 + elif c.isalpha(): + english_count += 1 + else: + special_count += 1 + + # 使用更精确的估算比率 + chinese_tokens = int(chinese_count * 0.55) + english_tokens = int(english_count * 0.25) + digit_tokens = int(digit_count * 0.4) + special_tokens = int(special_count * 0.2) + + return chinese_tokens + english_tokens + digit_tokens + special_tokens + + def get_cache_stats(self) -> dict: + """Get cache hit/miss statistics. + + Returns: + Dictionary with cache stats. + """ + total = self._hit_count + self._miss_count + hit_rate = (self._hit_count / total * 100) if total > 0 else 0 + return { + "hits": self._hit_count, + "misses": self._miss_count, + "hit_rate": f"{hit_rate:.1f}%", + "cache_size": len(self._cache), + } + + def clear_cache(self) -> None: + """Clear the token count cache.""" + self._cache.clear() + self._hit_count = 0 + self._miss_count = 0 diff --git a/tests/test_context_compression.py b/tests/test_context_compression.py new file mode 100644 index 0000000000..5313083f0a --- /dev/null +++ b/tests/test_context_compression.py @@ -0,0 +1,356 @@ +"""Tests for context compression optimizations. + +这些测试验证了上下文压缩模块的优化功能: +1. Token 估算算法的精确性 +2. 缓存机制的有效性(使用强缓存键) +3. 压缩器的功能 +""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from astrbot.core.agent.context.compressor import ( + LLMSummaryCompressor, + TruncateByTurnsCompressor, + _generate_summary_cache_key, + split_history, +) +from astrbot.core.agent.context.config import ContextConfig +from astrbot.core.agent.context.manager import ContextManager +from astrbot.core.agent.context.token_counter import ( + PER_MESSAGE_OVERHEAD, + EstimateTokenCounter, +) +from astrbot.core.agent.message import Message + + +class TestEstimateTokenCounter: + """Test cases for improved token counter.""" + + def setup_method(self): + """Setup test fixtures.""" + self.counter = EstimateTokenCounter() + + def test_chinese_text_token_estimation(self): + """测试中文文本的 token 估算。""" + text = "你好,世界!这是一段中文测试文本。" + tokens = self.counter._estimate_tokens(text) + # 中文应该约占 0.55 tokens/字符 + assert tokens > 0 + # 验证估算值合理 + assert tokens < len(text) # 应该比字符数少 + + def test_english_text_token_estimation(self): + """测试英文文本的 token 估算。""" + text = "Hello, world! This is an English test text." + tokens = self.counter._estimate_tokens(text) + assert tokens > 0 + # 英文应该约占 0.25 tokens/字符 + assert tokens < len(text) + + def test_mixed_text_token_estimation(self): + """测试中英文混合文本的 token 估算。""" + text = "你好 Hello, 世界 World! 混合 Mix 文本 Text。" + tokens = self.counter._estimate_tokens(text) + assert tokens > 0 + + def test_digit_token_estimation(self): + """测试数字的 token 估算。""" + text = "1234567890" + tokens = self.counter._estimate_tokens(text) + assert tokens > 0 + + def test_empty_text(self): + """测试空文本。""" + tokens = self.counter._estimate_tokens("") + assert tokens == 0 + + def test_message_list_token_counting(self): + """测试消息列表的 token 计数。""" + messages = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="你好"), + Message(role="assistant", content="你好!有什么可以帮助你的吗?"), + ] + tokens = self.counter.count_tokens(messages) + assert tokens > 0 + # 验证每条消息都有固定开销 + assert tokens >= PER_MESSAGE_OVERHEAD * len(messages) + + def test_cache_functionality_with_strong_key(self): + """测试使用强缓存键的缓存功能。""" + messages = [ + Message(role="user", content="测试消息"), + Message(role="assistant", content="测试回复"), + ] + + # 第一次计数 + tokens1 = self.counter.count_tokens(messages) + + # 第二次计数应该使用缓存 + tokens2 = self.counter.count_tokens(messages) + + assert tokens1 == tokens2 + + # 检查缓存统计 + stats = self.counter.get_cache_stats() + assert stats["hits"] >= 1 + + def test_different_messages_different_cache_keys(self): + """测试不同消息产生不同的缓存键。""" + messages1 = [ + Message(role="user", content="消息1"), + Message(role="assistant", content="回复1"), + ] + messages2 = [ + Message(role="user", content="消息2"), + Message(role="assistant", content="回复2"), + ] + + key1 = self.counter._get_cache_key(messages1) + key2 = self.counter._get_cache_key(messages2) + + assert key1 != key2 + + def test_same_messages_same_cache_key(self): + """测试相同消息产生相同的缓存键。""" + messages1 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + messages2 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + + key1 = self.counter._get_cache_key(messages1) + key2 = self.counter._get_cache_key(messages2) + + assert key1 == key2 + + def test_cache_clear(self): + """测试缓存清除。""" + messages = [Message(role="user", content="测试")] + self.counter.count_tokens(messages) + + # 清除缓存 + self.counter.clear_cache() + + stats = self.counter.get_cache_stats() + assert stats["hits"] == 0 + assert stats["cache_size"] == 0 + + +class TestTruncateByTurnsCompressor: + """Test cases for truncate by turns compressor.""" + + def setup_method(self): + """Setup test fixtures.""" + self.compressor = TruncateByTurnsCompressor(truncate_turns=1) + + def test_should_compress_above_threshold(self): + """测试超过阈值时触发压缩。""" + messages = [ + Message(role="user", content="测试消息"), + Message(role="assistant", content="测试回复"), + ] + # max_tokens=100, 当前 tokens 应该远超阈值 + assert self.compressor.should_compress(messages, 90, 100) is True + + def test_should_compress_below_threshold(self): + """测试未超过阈值时不触发压缩。""" + messages = [Message(role="user", content="短消息")] + assert self.compressor.should_compress(messages, 10, 100) is False + + def test_should_compress_zero_max_tokens(self): + """测试 max_tokens 为 0 时不触发压缩。""" + messages = [Message(role="user", content="测试")] + assert self.compressor.should_compress(messages, 50, 0) is False + + +class TestSplitHistory: + """Test cases for split_history function.""" + + def test_split_with_enough_messages(self): + """测试消息数量足够时的分割。""" + messages = [ + Message(role="system", content="System"), + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + Message(role="user", content="User 2"), + Message(role="assistant", content="Assistant 2"), + Message(role="user", content="User 3"), + Message(role="assistant", content="Assistant 3"), + ] + + system, to_summarize, recent = split_history(messages, keep_recent=2) + + assert len(system) == 1 # system message + assert len(recent) >= 2 # 至少保留最近的消息 + + def test_split_with_few_messages(self): + """测试消息数量不足时的分割。""" + messages = [ + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + ] + + system, to_summarize, recent = split_history(messages, keep_recent=4) + + assert len(to_summarize) == 0 # 没有需要摘要的消息 + assert len(recent) == 2 + + +class TestGenerateSummaryCacheKey: + """Test cases for summary cache key generation.""" + + def test_different_histories_different_keys(self): + """测试不同历史记录产生不同的缓存键。""" + messages1 = [ + Message(role="user", content="用户消息1"), + Message(role="assistant", content="助手回复1"), + ] + messages2 = [ + Message(role="user", content="用户消息2"), + Message(role="assistant", content="助手回复2"), + ] + + key1 = _generate_summary_cache_key(messages1) + key2 = _generate_summary_cache_key(messages2) + + assert key1 != key2 + + def test_same_history_same_key(self): + """测试相同历史记录产生相同的缓存键。""" + messages1 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + messages2 = [ + Message(role="user", content="相同消息"), + Message(role="assistant", content="相同回复"), + ] + + key1 = _generate_summary_cache_key(messages1) + key2 = _generate_summary_cache_key(messages2) + + assert key1 == key2 + + +class TestLLMSummaryCompressor: + """Test cases for LLM summary compressor.""" + + def setup_method(self): + """Setup test fixtures.""" + self.mock_provider = Mock() + self.mock_provider.text_chat = AsyncMock() + self.mock_provider.text_chat.return_value = Mock(completion_text="这是一段摘要。") + + self.compressor = LLMSummaryCompressor( + provider=self.mock_provider, + keep_recent=2 + ) + + @pytest.mark.asyncio + async def test_generate_summary(self): + """测试生成摘要。""" + messages = [ + Message(role="system", content="System"), + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + Message(role="user", content="User 2"), + Message(role="assistant", content="Assistant 2"), + Message(role="user", content="User 3"), + Message(role="assistant", content="Assistant 3"), + ] + + result = await self.compressor(messages) + + # 验证摘要已生成 + assert len(result) >= 3 + # 验证 LLM 被调用 + self.mock_provider.text_chat.assert_called_once() + + @pytest.mark.asyncio + async def test_cache_summary(self): + """测试摘要缓存。""" + messages = [ + Message(role="system", content="System"), + Message(role="user", content="User 1"), + Message(role="assistant", content="Assistant 1"), + Message(role="user", content="User 2"), + Message(role="assistant", content="Assistant 2"), + ] + + # 第一次调用 + await self.compressor(messages) + + # 第二次调用应该使用缓存 + await self.compressor(messages) + + # LLM 只应该被调用一次 + assert self.mock_provider.text_chat.call_count == 1 + + +class TestContextManager: + """Test cases for context manager.""" + + def setup_method(self): + """Setup test fixtures.""" + self.config = ContextConfig( + max_context_tokens=1000, + truncate_turns=1, + enforce_max_turns=-1, + ) + self.manager = ContextManager(self.config) + + @pytest.mark.asyncio + async def test_process_no_compression_needed(self): + """测试不需要压缩的情况。""" + messages = [ + Message(role="user", content="短消息"), + ] + + result = await self.manager.process(messages) + + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_process_with_compression(self): + """测试需要压缩的情况。""" + # 创建大量消息以触发压缩 + messages = [] + for i in range(50): + messages.append(Message(role="user", content=f"用户消息 {i} " * 50)) + messages.append(Message(role="assistant", content=f"助手回复 {i} " * 50)) + + # 设置较小的 max_context_tokens 以触发压缩 + self.config.max_context_tokens = 100 + + result = await self.manager.process(messages) + + # 验证消息被压缩 + assert len(result) < len(messages) + + def test_get_stats(self): + """测试获取统计信息。""" + stats = self.manager.get_stats() + + assert "compression_count" in stats + assert "last_token_count" in stats + assert "last_messages_fingerprint" in stats + + def test_reset_stats(self): + """测试重置统计信息。""" + self.manager._compression_count = 5 + + self.manager.reset_stats() + + assert self.manager._compression_count == 0 + assert self.manager._last_token_count is None + assert self.manager._last_messages_fingerprint is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])