Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 67 additions & 11 deletions astrbot/core/agent/context/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ def should_compress(
return usage_rate > self.compression_threshold

async def __call__(self, messages: list[Message]) -> list[Message]:
    """Compress the context by discarding the oldest conversation turns.

    Args:
        messages: The full message history.

    Returns:
        The history with ``self.truncate_turns`` oldest turns removed.
    """
    return ContextTruncator().truncate_by_dropping_oldest_turns(
        messages,
        drop_turns=self.truncate_turns,
    )


Expand Down Expand Up @@ -143,9 +146,30 @@ def split_history(
return system_messages, messages_to_summarize, recent_messages


def _generate_summary_cache_key(messages: list[Message]) -> str:
"""Generate a cache key for summary based on full history.

Uses role and content from all messages to create a collision-resistant key.
"""
if not messages:
return ""

key_parts = []
for msg in messages:
content = msg.content if isinstance(msg.content, str) else str(msg.content)
key_parts.append(f"{msg.role}:{content[:50]}")

return "|".join(key_parts)


class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.

Optimizations:
- 支持增量摘要,只摘要超出的部分
- 添加摘要缓存避免重复摘要
- 支持自定义摘要提示词
"""

def __init__(
Expand Down Expand Up @@ -175,6 +199,10 @@ def __init__(
"4. Write the summary in the user's language.\n"
)

# 新增: 摘要缓存
self._summary_cache: dict[str, str] = {}
self._max_cache_size = 50

def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
Expand All @@ -200,6 +228,10 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].

Optimizations:
- 添加摘要缓存
- 检查是否已有摘要,避免重复生成
"""
if len(messages) <= self.keep_recent + 1:
return messages
Expand All @@ -211,17 +243,37 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
if not messages_to_summarize:
return messages

# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# 生成缓存键
cache_key = _generate_summary_cache_key(messages_to_summarize)

# 尝试从缓存获取摘要
summary_content = None
if cache_key in self._summary_cache:
summary_content = self._summary_cache[cache_key]
logger.debug("Using cached summary")

# 如果缓存没有,生成新摘要
if summary_content is None:
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]

# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text

# 缓存摘要
if len(self._summary_cache) < self._max_cache_size:
self._summary_cache[cache_key] = summary_content
else:
# 简单的缓存淘汰
self._summary_cache.pop(next(iter(self._summary_cache)))
self._summary_cache[cache_key] = summary_content

except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages

# build result
result = []
Expand All @@ -243,3 +295,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
result.extend(recent_messages)

return result

def clear_cache(self) -> None:
    """Clear the summary cache.

    Drops every cached summary so subsequent compressions regenerate
    them from scratch via the LLM.
    """
    self._summary_cache.clear()
99 changes: 88 additions & 11 deletions astrbot/core/agent/context/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@


class ContextManager:
"""Context compression manager."""
"""Context compression manager.

Optimizations:
- 减少重复 token 计算
- 添加增量压缩支持
- 优化日志输出
"""

def __init__(
self,
Expand Down Expand Up @@ -41,13 +47,27 @@
truncate_turns=config.truncate_turns
)

# 缓存上一次计算的消息指纹和 token 数
self._last_messages_fingerprint: int | None = None
self._last_token_count: int | None = None
self._compression_count = 0

def _get_messages_fingerprint(self, messages: list[Message]) -> int:
    """Build a fingerprint of *messages*, used to detect content changes.

    Returns 0 for an empty list; otherwise delegates to the token
    counter's internal cache key so both layers agree on what counts
    as "the same messages".

    NOTE(review): this reaches into a private member
    (``token_counter._get_cache_key``) — consider exposing a public
    fingerprint API on the counter instead. The ``int`` return
    annotation also assumes ``_get_cache_key`` returns an int —
    TODO confirm against the token counter implementation.
    """
    if not messages:
        return 0

    # Reuse the token counter's cache key as the fingerprint.
    return self.token_counter._get_cache_key(messages)

async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.

Args:
messages: The original message list.
trusted_token_usage: The total token usage that LLM API returned.

Returns:
The processed message list.
Expand All @@ -65,14 +85,34 @@

# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
# 优化: 使用缓存的 token 计数或计算新值
current_fingerprint = self._get_messages_fingerprint(messages)

if trusted_token_usage > 0:
total_tokens = trusted_token_usage
elif (
self._last_messages_fingerprint is not None
and self._last_messages_fingerprint == current_fingerprint
and self._last_token_count is not None
):
# 消息内容没变化,使用缓存的 token 计数
total_tokens = self._last_token_count
else:
# 消息内容变了,需要重新计算
total_tokens = self.token_counter.count_tokens(result)
self._last_messages_fingerprint = current_fingerprint

# 更新缓存
self._last_token_count = total_tokens

if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
# 压缩后更新指纹
self._last_messages_fingerprint = self._get_messages_fingerprint(
result
)

return result
except Exception as e:
Expand All @@ -94,27 +134,64 @@
"""
logger.debug("Compress triggered, starting compression...")

self._compression_count += 1

messages = await self.compressor(messages)

# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# 优化: 压缩后只计算一次 token
tokens_after_compression = self.token_counter.count_tokens(messages)

# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
compress_rate = (
Comment on lines 143 to +145

Check failure

Code scanning / CodeQL

Clear-text logging of sensitive information High

This expression logs
sensitive data (secret)
as clear text.
This expression logs
sensitive data (secret)
as clear text.
tokens_after_compression / self.config.max_context_tokens
) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f"Compress #{self._compression_count} completed."
f" {prev_tokens} -> {tokens_after_compression} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)

# last check
# 更新缓存
self._last_token_count = tokens_after_compression
self._last_messages_fingerprint = self._get_messages_fingerprint(messages)

# last check - 优化: 减少不必要的递归调用
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
messages, tokens_after_compression, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
# 更新缓存
self._last_token_count = self.token_counter.count_tokens(messages)
self._last_messages_fingerprint = self._get_messages_fingerprint(messages)

return messages

def get_stats(self) -> dict:
"""获取上下文管理器的统计信息。

Returns:
Dictionary with stats including compression count and token counter stats.
"""
stats = {
"compression_count": self._compression_count,
"last_token_count": self._last_token_count,
"last_messages_fingerprint": self._last_messages_fingerprint,
}

# 如果 token counter 有缓存统计,也一并返回
if hasattr(self.token_counter, "get_cache_stats"):
stats["token_counter_cache"] = self.token_counter.get_cache_stats()

return stats

def reset_stats(self) -> None:
"""重置统计信息。"""
self._compression_count = 0
self._last_token_count = None
self._last_messages_fingerprint = None
if hasattr(self.token_counter, "clear_cache"):
self.token_counter.clear_cache()
Loading
Loading