Conversation
- Improved token estimation algorithm (Chinese, English, digit, and special characters counted separately)
- Added token-count cache and summary cache
- Added a fingerprint mechanism to ContextManager to reduce redundant computation
- Fixed a cache-key collision bug and an overhead double-counting bug
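The per-character-class estimation mentioned above can be sketched as follows. The ratios here are illustrative assumptions, not the values used in the PR; the point is that CJK characters, English words, digits, and other characters are scored separately instead of using a single chars-per-token divisor:

```python
import re

# Rough per-class ratios -- illustrative assumptions, not the PR's actual constants.
_CJK_RE = re.compile(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]")
_WORD_RE = re.compile(r"[A-Za-z]+")
_DIGIT_RE = re.compile(r"\d")

def estimate_tokens(text: str) -> int:
    """Estimate token count by scoring CJK, English, digit, and other characters separately."""
    cjk_chars = len(_CJK_RE.findall(text))
    words = _WORD_RE.findall(text)
    word_chars = sum(len(w) for w in words)
    digit_chars = len(_DIGIT_RE.findall(text))
    other_chars = len(text) - cjk_chars - word_chars - digit_chars
    # One CJK char is roughly one token; an English word averages ~1.3 tokens;
    # digits, punctuation, and whitespace are cheaper.
    return int(cjk_chars * 1.0 + len(words) * 1.3 + digit_chars * 0.5 + other_chars * 0.5)
```

This kind of split matters because a single ratio tuned for English badly overcounts or undercounts CJK-heavy conversations.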
Summary of Changes (Gemini Code Assist)

This pull request significantly optimizes the performance and accuracy of the context compression module. By introducing an improved token estimation algorithm, comprehensive caching, and internal improvements to ContextManager, the system manages conversation context more efficiently, avoids unnecessary computation, and fixes critical caching and calculation bugs, improving overall responsiveness and resource utilization.

Highlights
Hey - I've found 2 issues and left some high-level feedback:
- ContextManager currently relies on the token counter’s private method _get_cache_key for fingerprints; consider exposing a public fingerprint/cache-key API on EstimateTokenCounter instead so the manager doesn’t depend on another class’s internals.
- _generate_summary_cache_key only uses the first 50 characters of each message’s content and ignores provider/instruction configuration, which can lead to summary cache collisions or stale summaries if those change; consider hashing full content plus relevant config to make the cache safer.
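A safer summary cache key along the lines suggested would hash the full message content together with the provider and instruction configuration. The sketch below is a minimal illustration assuming plain-dict messages and hypothetical parameter names (the real code uses `Message` objects and its own config fields):

```python
import hashlib
import json

def generate_summary_cache_key(
    messages: list[dict], provider_id: str, instruction: str
) -> str:
    """Build a collision-resistant cache key from full content plus relevant config."""
    h = hashlib.sha256()
    # Include provider/instruction config so changing either invalidates the cache.
    h.update(provider_id.encode("utf-8"))
    h.update(b"\x00")
    h.update(instruction.encode("utf-8"))
    for msg in messages:
        # Hash the full content, not a 50-char prefix, so long histories that
        # share a common prefix never collide.
        h.update(b"\x00")
        h.update(msg["role"].encode("utf-8"))
        h.update(json.dumps(msg["content"], ensure_ascii=False, sort_keys=True).encode("utf-8"))
    return h.hexdigest()
```

The separator bytes keep adjacent fields from concatenating into the same digest input, and `sort_keys=True` makes structured content serialize deterministically.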
## Individual Comments
### Comment 1
<location path="tests/test_context_compression.py" line_range="241-250" />
<code_context>
+class TestLLMSummaryCompressor:
</code_context>
<issue_to_address>
**suggestion (testing):** Add tests for error handling, cache clearing, and `should_compress` behavior in `LLMSummaryCompressor`.
Right now we only cover the happy path and basic caching. It would be valuable to: (1) add a test where `provider.text_chat` raises to confirm `__call__` returns the original messages; (2) verify `clear_cache()` actually invalidates entries by asserting the provider is called again (e.g., `call_count` increases after clearing); and (3) add focused tests for `should_compress` using edge values around the threshold for `current_tokens` vs `max_tokens` to validate the decision boundary.
Suggested implementation:
```python
class TestLLMSummaryCompressor:
"""Test cases for LLM summary compressor."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_provider = Mock()
self.mock_provider.text_chat = AsyncMock()
self.mock_provider.text_chat.return_value = Mock(completion_text="这是一段摘要。")
# NOTE: adjust arguments here if LLMSummaryCompressor requires more params.
self.compressor = LLMSummaryCompressor(
provider=self.mock_provider,
)
@pytest.mark.asyncio
async def test_error_handling_returns_original_messages_on_provider_failure(self):
"""When provider.text_chat raises, __call__ should return original messages unchanged."""
messages = [
Message(role="user", content="需要被压缩的对话内容"),
Message(role="assistant", content="助手的长回复"),
]
# Force the provider to fail
self.mock_provider.text_chat.side_effect = RuntimeError("LLM error")
# Call compressor (adjust call signature if needed)
result_messages = await self.compressor(
messages=messages,
max_tokens=2048,
language="zh",
)
# Should fall back to returning the original messages
assert result_messages == messages
@pytest.mark.asyncio
async def test_clear_cache_causes_provider_to_be_called_again(self):
"""clear_cache should invalidate cached summaries so provider is called again."""
messages = [
Message(role="user", content="第一条消息"),
Message(role="assistant", content="第一条回复"),
]
# First call: should hit provider
result1 = await self.compressor(
messages=messages,
max_tokens=2048,
language="zh",
)
assert self.mock_provider.text_chat.call_count == 1
assert "这是一段摘要" in result1[-1].content
# Change provider output to detect cache behavior
self.mock_provider.text_chat.reset_mock()
self.mock_provider.text_chat.return_value = Mock(completion_text="新的摘要内容")
# Second call without clearing cache: should use cached result, not call provider
result2 = await self.compressor(
messages=messages,
max_tokens=2048,
language="zh",
)
assert self.mock_provider.text_chat.call_count == 0
# Still old summary because of cache
assert "这是一段摘要" in result2[-1].content
# Clear cache and call again: provider should be hit and new summary used
self.compressor.clear_cache()
result3 = await self.compressor(
messages=messages,
max_tokens=2048,
language="zh",
)
assert self.mock_provider.text_chat.call_count == 1
assert "新的摘要内容" in result3[-1].content
def test_should_compress_below_threshold(self):
"""should_compress should return False when current_tokens is just below threshold."""
max_tokens = 1000
# Derive threshold ratio from compressor if available; otherwise fall back to a default.
ratio = getattr(self.compressor, "compression_ratio_threshold", 0.8)
threshold_tokens = int(max_tokens * ratio)
current_tokens = threshold_tokens - 1 # just below
assert self.compressor.should_compress(current_tokens=current_tokens, max_tokens=max_tokens) is False
def test_should_compress_at_threshold(self):
"""Validate decision boundary when current_tokens is at the threshold."""
max_tokens = 1000
ratio = getattr(self.compressor, "compression_ratio_threshold", 0.8)
threshold_tokens = int(max_tokens * ratio)
current_tokens = threshold_tokens # exactly at threshold
# Most implementations consider reaching the threshold as needing compression.
assert self.compressor.should_compress(current_tokens=current_tokens, max_tokens=max_tokens) is True
def test_should_compress_above_threshold(self):
"""should_compress should return True when current_tokens is just above threshold."""
max_tokens = 1000
ratio = getattr(self.compressor, "compression_ratio_threshold", 0.8)
threshold_tokens = int(max_tokens * ratio)
current_tokens = threshold_tokens + 1 # just above
assert self.compressor.should_compress(current_tokens=current_tokens, max_tokens=max_tokens) is True
```
1. Ensure `pytest` is imported at the top of `tests/test_context_compression.py` if it is not already:
- `import pytest`
2. Confirm the actual `LLMSummaryCompressor` constructor signature:
- If it requires additional parameters (e.g., `tokenizer`, `summary_language`, `summary_max_tokens`, etc.), add them in `setup_method` where `self.compressor = LLMSummaryCompressor(...)` is created.
3. Adjust the `__call__` signature used in the tests:
- Update the calls to `await self.compressor(messages=..., max_tokens=..., language=...)` to match the real `__call__` parameters (names and count).
4. If `LLMSummaryCompressor` uses a different attribute name than `compression_ratio_threshold` for the threshold ratio, update the tests:
- Replace `"compression_ratio_threshold"` with the actual attribute name that represents the compression threshold.
5. If `__call__` returns something other than a list of `Message` objects (e.g., a tuple with metadata), adapt the assertions to check the correct part of the return value that contains the compressed messages.
</issue_to_address>
### Comment 2
<location path="astrbot/core/agent/context/token_counter.py" line_range="65" />
<code_context>
+ self._hit_count = 0
+ self._miss_count = 0
+
+ def _get_cache_key(self, messages: list[Message]) -> int:
+ """Generate a cache key for messages based on full history structure.
+
</code_context>
<issue_to_address>
**issue (complexity):** Consider simplifying the cache key computation and cache eviction policy by using an explicit structural fingerprint plus a small LRU cache for more deterministic and maintainable behavior.
You can keep all new features (better heuristic, caching, stats) but reduce complexity and improve determinism by:
1. Replacing the rolling hash with an explicit, structural fingerprint.
2. Using a simple LRU-style cache with `OrderedDict` instead of half-clearing arbitrary keys.
### 1) Simplify `_get_cache_key` with an explicit fingerprint
Instead of nested re‑hashing and `str(...)`, build a simple, immutable structure and hash that. This is easier to reason about and avoids depending on `str()` representations:
```python
from typing import Hashable
def _message_fingerprint(msg: Message) -> Hashable:
# Normalize content
if isinstance(msg.content, str):
content_repr: Hashable = msg.content
elif isinstance(msg.content, list):
# Keep explicit structure, don’t rely on str(list)
parts = []
for part in msg.content:
if isinstance(part, TextPart):
parts.append(("text", part.text))
elif isinstance(part, ThinkPart):
parts.append(("think", part.think))
elif isinstance(part, ImageURLPart):
parts.append(("image", part.image_url))
elif isinstance(part, AudioURLPart):
parts.append(("audio", part.audio_url))
else:
parts.append(("other", repr(part)))
content_repr = tuple(parts)
else:
content_repr = repr(msg.content)
# Normalize tool_calls (dicts sorted by key, others by repr)
if msg.tool_calls:
tools = []
for tc in msg.tool_calls:
if isinstance(tc, dict):
tools.append(("dict", tuple(sorted(tc.items()))))
else:
tools.append(("other", repr(tc)))
tool_repr: Hashable = tuple(tools)
else:
tool_repr = ()
return (msg.role, content_repr, tool_repr)
def _get_cache_key(self, messages: list[Message]) -> int:
if not messages:
return 0
fingerprint = tuple(_message_fingerprint(m) for m in messages)
return hash(fingerprint)
```
This keeps the “full-history structure” idea while making the identity explicit and deterministic.
### 2) Replace custom eviction with a tiny LRU cache
Use `collections.OrderedDict` to get deterministic, straightforward eviction instead of clearing half of the cache based on dict key order:
```python
from collections import OrderedDict
class EstimateTokenCounter:
def __init__(self, cache_size: int = 100) -> None:
self._cache: "OrderedDict[int, int]" = OrderedDict()
self._cache_size = cache_size
self._hit_count = 0
self._miss_count = 0
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage
cache_key = self._get_cache_key(messages)
if cache_key in self._cache:
self._hit_count += 1
# mark as recently used
self._cache.move_to_end(cache_key)
return self._cache[cache_key]
self._miss_count += 1
total = self._count_tokens_internal(messages)
if self._cache_size > 0:
if len(self._cache) >= self._cache_size:
# evict least-recently-used
self._cache.popitem(last=False)
self._cache[cache_key] = total
return total
```
This keeps caching and stats behavior but removes the non-obvious “clear half of the cache” policy and reliance on incidental dict ordering.
</issue_to_address>
CodeQL False-Positive Fix: fix PR #6736 (renamed). Once #6736 is merged, the CodeQL check should pass.
Optimizations
1. Improved token estimation algorithm
2. Added caching
3. ContextManager optimization
4. Bug fixes
5. Unit tests
Testing
Closes old PR #6655 (superseded due to merge conflicts)
Summary by Sourcery
Optimize the context compression pipeline with improved token estimation, caching, and runtime statistics.
New Features:
Bug Fixes:
Enhancements:
Tests: