78 changes: 67 additions & 11 deletions astrbot/core/agent/context/compressor.py
@@ -88,11 +88,14 @@ def should_compress(
return usage_rate > self.compression_threshold

async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress messages by removing oldest turns."""
truncator = ContextTruncator()

truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.truncate_turns,
)

return truncated_messages


@@ -143,9 +146,30 @@ def split_history(
return system_messages, messages_to_summarize, recent_messages


def _generate_summary_cache_key(messages: list[Message]) -> str:
"""Generate a cache key for summary based on full history.

Uses role and content from all messages to create a collision-resistant key.
"""
if not messages:
return ""

key_parts = []
for msg in messages:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
key_parts.append(f"{msg.role}:{content[:50]}")

return "|".join(key_parts)


class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses an LLM to summarize the old conversation history, keeping the latest messages.

Optimizations:
- Supports incremental summarization: only the overflow portion is summarized
- Adds a summary cache to avoid regenerating summaries
- Supports a custom summarization prompt
"""

def __init__(
@@ -174,6 +198,10 @@ def __init__(
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
)

# New: summary cache
self._summary_cache: dict[str, str] = {}
self._max_cache_size = 50

def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
@@ -200,6 +228,10 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].

Optimizations:
- Adds a summary cache
- Checks whether a summary already exists to avoid regenerating it
"""
if len(messages) <= self.keep_recent + 1:
return messages
@@ -211,17 +243,37 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
if not messages_to_summarize:
return messages

# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# Generate the cache key
cache_key = _generate_summary_cache_key(messages_to_summarize)

# Try to fetch the summary from the cache
summary_content = None
if cache_key in self._summary_cache:
summary_content = self._summary_cache[cache_key]
logger.debug("Using cached summary")

# If not cached, generate a new summary
if summary_content is None:
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text

# Cache the summary
if len(self._summary_cache) < self._max_cache_size:
self._summary_cache[cache_key] = summary_content
else:
# Simple cache eviction: drop the oldest entry
self._summary_cache.pop(next(iter(self._summary_cache)))
self._summary_cache[cache_key] = summary_content

except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages

# build result
result = []
@@ -243,3 +295,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
result.extend(recent_messages)

return result

def clear_cache(self) -> None:
"""清空摘要缓存。"""
self._summary_cache.clear()
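
A minimal sketch (not part of the diff) of how the cache-key scheme above behaves; the example messages and the resulting key are assumptions for illustration, using the Message(role=..., content=...) construction seen elsewhere in this file:

msgs = [
    Message(role="user", content="Summarize our plan so far"),
    Message(role="assistant", content="We agreed to ship the cache first."),
]

key = _generate_summary_cache_key(msgs)
# key == "user:Summarize our plan so far|assistant:We agreed to ship the cache first."
# Only the first 50 characters of each message's content contribute to the key,
# so histories that differ only beyond that prefix would reuse the same cached summary.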
93 changes: 82 additions & 11 deletions astrbot/core/agent/context/manager.py
@@ -8,7 +8,13 @@


class ContextManager:
"""Context compression manager."""
"""Context compression manager.

Optimizations:
- Reduces redundant token counting
- Adds incremental compression support
- Improves log output
"""

def __init__(
self,
@@ -40,6 +46,19 @@ def __init__(
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)

# Cache the message fingerprint and token count from the previous run
self._last_messages_fingerprint: int | None = None
self._last_token_count: int | None = None
self._compression_count = 0

def _get_messages_fingerprint(self, messages: list[Message]) -> int:
"""生成消息列表的指纹,用于检测消息内容是否变化。"""
if not messages:
return 0

# Use the token counter's cache key as the fingerprint
return self.token_counter._get_cache_key(messages)

async def process(
self, messages: list[Message], trusted_token_usage: int = 0
@@ -48,6 +67,7 @@

Args:
messages: The original message list.
trusted_token_usage: The total token usage that LLM API returned.

Returns:
The processed message list.
@@ -65,14 +85,30 @@

# 2. Token-based compression
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
# Optimization: use the cached token count or compute a fresh one
Contributor
issue (complexity): Consider encapsulating token-count caching and metrics tracking into dedicated helper classes so that ContextManager remains a thin orchestration layer with minimal branching and internal state.

You can keep all new behavior (caching, stats, logging) but move most of the state/branching out of ContextManager so it stays “lean” and focused on orchestration.

1. Move token-count caching into EstimateTokenCounter

Right now process() carries a fragile heuristic and branching:

if trusted_token_usage > 0:
    total_tokens = trusted_token_usage
elif self._last_token_count is not None:
    if len(result) == len(messages):
        total_tokens = self._last_token_count
    else:
        total_tokens = self.token_counter.count_tokens(result)
else:
    total_tokens = self.token_counter.count_tokens(result)

self._last_token_count = total_tokens

This can become a simple call if the caching logic is encapsulated in EstimateTokenCounter, keyed by something stable (e.g. ids or a digest) rather than len():

# in EstimateTokenCounter
class EstimateTokenCounter:
    def __init__(self, ...):
        self._last_key: tuple[int, int] | None = None  # (len, hash)
        self._last_count: int | None = None

    def count_tokens_cached(
        self,
        messages: list[Message],
        trusted_token_usage: int = 0,
    ) -> int:
        if trusted_token_usage > 0:
            return trusted_token_usage

        key = (len(messages), hash(tuple(m.id for m in messages)))  # adjust to your msg model
        if self._last_key == key and self._last_count is not None:
            return self._last_count

        count = self.count_tokens(messages)
        self._last_key = key
        self._last_count = count
        return count

Then ContextManager.process becomes:

# in ContextManager.process
if self.config.max_context_tokens > 0:
    total_tokens = self.token_counter.count_tokens_cached(
        result,
        trusted_token_usage=trusted_token_usage,
    )

    if self.compressor.should_compress(
        result,
        total_tokens,
        self.config.max_context_tokens,
    ):
        result = await self._run_compression(result, total_tokens)

This removes branches and state from ContextManager while keeping the optimization and avoiding the fragile len(result) == len(messages) heuristic.

2. Keep stats but move them into a small metrics object

Instead of ContextManager owning _compression_count, _last_token_count, get_stats, reset_stats, and reaching into token_counter via hasattr, you can centralize stats in a dedicated, lightweight helper:

# metrics.py
@dataclass
class ContextMetrics:
    compression_count: int = 0
    last_token_count: int | None = None

    def as_dict(self, token_counter) -> dict:
        stats = {
            "compression_count": self.compression_count,
            "last_token_count": self.last_token_count,
        }
        if hasattr(token_counter, "get_cache_stats"):
            stats["token_counter_cache"] = token_counter.get_cache_stats()
        return stats

    def reset(self, token_counter) -> None:
        self.compression_count = 0
        self.last_token_count = None
        if hasattr(token_counter, "clear_cache"):
            token_counter.clear_cache()

ContextManager then becomes:

class ContextManager:
    def __init__(...):
        ...
        self.metrics = ContextMetrics()

    async def _run_compression(...):
        logger.debug("Compress triggered, starting compression...")
        self.metrics.compression_count += 1

        messages = await self.compressor(messages)
        tokens_after = self.token_counter.count_tokens(messages)
        self.metrics.last_token_count = tokens_after

        ...

        if self.compressor.should_compress(...):
            ...
            self.metrics.last_token_count = self.token_counter.count_tokens(messages)
        return messages

    def get_stats(self) -> dict:
        return self.metrics.as_dict(self.token_counter)

    def reset_stats(self) -> None:
        self.metrics.reset(self.token_counter)

This keeps all existing behavior (stats, logging, cache stats exposure) while:

  • Removing raw counters and bookkeeping fields from ContextManager.
  • Making _run_compression again “readable at a glance” as compress → count → maybe truncate.
  • Making future metric additions localized to ContextMetrics.

current_fingerprint = self._get_messages_fingerprint(messages)

if trusted_token_usage > 0:
total_tokens = trusted_token_usage
elif (self._last_messages_fingerprint is not None and
self._last_messages_fingerprint == current_fingerprint and
self._last_token_count is not None):
# Message content unchanged; use the cached token count
total_tokens = self._last_token_count
else:
# Message content changed; recompute
total_tokens = self.token_counter.count_tokens(result)
self._last_messages_fingerprint = current_fingerprint

# Update the cache
self._last_token_count = total_tokens

if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
# Update the fingerprint after compression
self._last_messages_fingerprint = self._get_messages_fingerprint(result)

return result
except Exception as e:
@@ -93,28 +129,63 @@ async def _run_compression(
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")

self._compression_count += 1

messages = await self.compressor(messages)

# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# Optimization: count tokens only once after compression
tokens_after_compression = self.token_counter.count_tokens(messages)

# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
compress_rate = (tokens_after_compression / self.config.max_context_tokens) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f"Compress #{self._compression_count} completed."
f" {prev_tokens} -> {tokens_after_compression} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)

# last check
# Update the cache
self._last_token_count = tokens_after_compression
self._last_messages_fingerprint = self._get_messages_fingerprint(messages)

# last check - optimization: avoid unnecessary recursive calls
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
messages, tokens_after_compression, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
# Update the cache
self._last_token_count = self.token_counter.count_tokens(messages)
self._last_messages_fingerprint = self._get_messages_fingerprint(messages)

return messages

def get_stats(self) -> dict:
"""获取上下文管理器的统计信息。

Returns:
Dictionary with stats including compression count and token counter stats.
"""
stats = {
"compression_count": self._compression_count,
"last_token_count": self._last_token_count,
"last_messages_fingerprint": self._last_messages_fingerprint,
}

# If the token counter exposes cache stats, include them as well
if hasattr(self.token_counter, 'get_cache_stats'):
stats["token_counter_cache"] = self.token_counter.get_cache_stats()

return stats

def reset_stats(self) -> None:
"""重置统计信息。"""
self._compression_count = 0
self._last_token_count = None
self._last_messages_fingerprint = None
if hasattr(self.token_counter, 'clear_cache'):
self.token_counter.clear_cache()
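
A hedged usage sketch (not part of the diff) of the new stats helpers; the manager construction is elided because the full constructor signature is not shown here, so `manager` and `messages` are assumed to exist already:

# `manager` is an already-constructed ContextManager; `messages` is a list[Message] history.
processed = await manager.process(messages, trusted_token_usage=0)

stats = manager.get_stats()
# e.g. {"compression_count": 1,
#       "last_token_count": 3120,
#       "last_messages_fingerprint": 123456789,
#       "token_counter_cache": {...}}  # included only if the counter exposes get_cache_stats()

manager.reset_stats()  # zeroes the counters and clears the token counter cache, if supported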