From 1c57601a8b13834f11d0b0edb075f65f2bd786e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nikola=20Forr=C3=B3?= Date: Mon, 11 May 2026 16:43:52 +0200 Subject: [PATCH] Allow to use model native thinking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nikola Forró --- Containerfile.c10s | 8 +- Containerfile.c9s | 8 +- beeai-reasoning.patch | 201 +++++++ openinference-reasoning.patch | 44 ++ templates/beeai-agent.env | 3 + ymir/agents/backport_agent.py | 8 +- ymir/agents/build_agent.py | 9 +- ymir/agents/cve_applicability_agent.py | 9 +- ymir/agents/log_agent.py | 9 +- ymir/agents/merge_request_agent.py | 8 +- ymir/agents/preliminary_testing_agent.py | 8 +- ymir/agents/reasoning_agent/__init__.py | 4 + ymir/agents/reasoning_agent/_runner.py | 641 +++++++++++++++++++++++ ymir/agents/reasoning_agent/agent.py | 238 +++++++++ ymir/agents/reasoning_agent/events.py | 32 ++ ymir/agents/reasoning_agent/prompts.py | 150 ++++++ ymir/agents/reasoning_agent/types.py | 83 +++ ymir/agents/rebase_agent.py | 8 +- ymir/agents/rebuild_consolidation.py | 6 +- ymir/agents/triage_agent.py | 8 +- ymir/agents/utils.py | 16 +- 21 files changed, 1467 insertions(+), 34 deletions(-) create mode 100644 beeai-reasoning.patch create mode 100644 openinference-reasoning.patch create mode 100644 ymir/agents/reasoning_agent/__init__.py create mode 100644 ymir/agents/reasoning_agent/_runner.py create mode 100644 ymir/agents/reasoning_agent/agent.py create mode 100644 ymir/agents/reasoning_agent/events.py create mode 100644 ymir/agents/reasoning_agent/prompts.py create mode 100644 ymir/agents/reasoning_agent/types.py diff --git a/Containerfile.c10s b/Containerfile.c10s index dd760632..9c34c769 100644 --- a/Containerfile.c10s +++ b/Containerfile.c10s @@ -56,6 +56,9 @@ RUN dnf -y install --allowerasing \ ${EXTRA_PACKAGES} \ && dnf clean all +COPY beeai-reasoning.patch /tmp +COPY openinference-reasoning.patch /tmp + RUN pip3 install --no-cache-dir \ "litellm!=1.82.7,!=1.82.8" \ beeai-framework[vertexai,mcp,duckduckgo]==0.1.80 \ @@ -66,7 +69,10 @@ RUN pip3 install --no-cache-dir \ specfile \ pytest \ pytest-asyncio \ - GitPython>=3.1.0 + GitPython>=3.1.0 \ + && cd /usr/local/lib/python3.12/site-packages \ + && patch -p2 -i /tmp/beeai-reasoning.patch \ + && patch -p5 -i /tmp/openinference-reasoning.patch # Verify no malicious litellm_init.pth was introduced by compromised litellm packages (e.g. 1.82.7, 1.82.8) RUN MALICIOUS=$(find /usr /opt -name "litellm_init.pth" 2>/dev/null); \ diff --git a/Containerfile.c9s b/Containerfile.c9s index 55bde80a..f9e32915 100644 --- a/Containerfile.c9s +++ b/Containerfile.c9s @@ -55,6 +55,9 @@ RUN dnf -y install --allowerasing \ ${EXTRA_PACKAGES} \ && dnf clean all +COPY beeai-reasoning.patch /tmp +COPY openinference-reasoning.patch /tmp + # Create Python 3.11 virtual environment and install Python packages RUN python3.11 -m venv --system-site-packages /opt/beeai-venv \ && /opt/beeai-venv/bin/pip install --upgrade pip \ @@ -67,7 +70,10 @@ RUN python3.11 -m venv --system-site-packages /opt/beeai-venv \ redis \ specfile \ koji \ - GitPython>=3.1.0 + GitPython>=3.1.0 \ + && cd /opt/beeai-venv/lib/python3.11/site-packages \ + && patch -p2 -i /tmp/beeai-reasoning.patch \ + && patch -p5 -i /tmp/openinference-reasoning.patch # Verify no malicious litellm_init.pth was introduced by compromised litellm packages (e.g. 1.82.7, 1.82.8) RUN MALICIOUS=$(find /usr /opt -name "litellm_init.pth" 2>/dev/null); \ diff --git a/beeai-reasoning.patch b/beeai-reasoning.patch new file mode 100644 index 00000000..0816291f --- /dev/null +++ b/beeai-reasoning.patch @@ -0,0 +1,201 @@ +diff --git a/python/beeai_framework/adapters/litellm/chat.py b/python/beeai_framework/adapters/litellm/chat.py +index b5c5a9b4..9ad9f544 100644 +--- a/python/beeai_framework/adapters/litellm/chat.py ++++ b/python/beeai_framework/adapters/litellm/chat.py +@@ -37,7 +37,10 @@ from beeai_framework.backend.chat import ( + ) + from beeai_framework.backend.errors import ChatModelError + from beeai_framework.backend.message import ( ++ AnyMessage, + AssistantMessage, ++ AssistantMessageContent, ++ MessageReasoningContent, + MessageTextContent, + MessageToolCallContent, + ToolMessage, +@@ -195,6 +198,7 @@ class LiteLLMChatModel(ChatModel, ABC): + "role": "assistant", + "content": msg_text_content or None, + "tool_calls": msg_tool_calls or None, ++ "thinking_blocks": message.meta.get("thinking_blocks"), + } + if self.model_supports_tool_calling + else { +@@ -307,27 +311,41 @@ class LiteLLMChatModel(ChatModel, ABC): + total_cost_usd=prompt_tokens_cost_usd + completion_tokens_cost_usd, + ) + +- return ChatModelOutput( +- output=( +- [ +- AssistantMessage( +- [ +- MessageToolCallContent( +- id=call.id or "", +- tool_name=call.function.name or "", +- args=call.function.arguments, +- ) +- for call in update.tool_calls +- ], +- id=chunk.id, ++ reasoning_content = getattr(update, "reasoning_content", None) if update else None ++ # Anthropic requires `thinking_blocks` (with cryptographic signatures) to be sent back ++ # in conversation history; without them LiteLLM silently disables thinking on follow-up turns ++ meta = None ++ if (thinking_blocks := (getattr(update, "thinking_blocks", None) if update else None)) and ( ++ # Streaming deltas carry partial blocks without signatures - filter those out ++ signed_thinking_blocks := [ ++ b ++ for b in thinking_blocks ++ if (b.get("signature") if isinstance(b, dict) else getattr(b, "signature", None)) ++ ] ++ ): ++ meta = {"thinking_blocks": signed_thinking_blocks} ++ ++ if update: ++ parts: list[AssistantMessageContent] = [] ++ if reasoning_content: ++ parts.append(MessageReasoningContent(text=reasoning_content)) ++ if update.tool_calls: ++ parts.extend( ++ MessageToolCallContent( ++ id=call.id or "", ++ tool_name=call.function.name or "", ++ args=call.function.arguments, + ) +- if update.tool_calls +- # pyrefly: ignore [bad-argument-type] +- else AssistantMessage(update.content or update.reasoning_content or "", id=chunk.id) +- ] +- if (update and update.model_dump(exclude_none=True)) +- else [] +- ), ++ for call in update.tool_calls ++ ) ++ if update.content: ++ parts.append(MessageTextContent(text=update.content)) ++ output: list[AnyMessage] = [AssistantMessage(parts, id=chunk.id, meta=meta)] if parts or meta else [] ++ else: ++ output: list[AnyMessage] = [] ++ ++ return ChatModelOutput( ++ output=output, + # Will be set later + output_structured=None, + finish_reason=finish_reason, +diff --git a/python/beeai_framework/backend/__init__.py b/python/beeai_framework/backend/__init__.py +index fe8e5002..f3d7fe86 100644 +--- a/python/beeai_framework/backend/__init__.py ++++ b/python/beeai_framework/backend/__init__.py +@@ -23,6 +23,7 @@ from beeai_framework.backend.message import ( + Message, + MessageFileContent, + MessageImageContent, ++ MessageReasoningContent, + MessageTextContent, + MessageToolCallContent, + MessageToolResultContent, +@@ -65,6 +66,7 @@ __all__ = [ + "MessageError", + "MessageFileContent", + "MessageImageContent", ++ "MessageReasoningContent", + "MessageTextContent", + "MessageToolCallContent", + "MessageToolResultContent", +diff --git a/python/beeai_framework/backend/chat.py b/python/beeai_framework/backend/chat.py +index 972b5246..a1ed0ca3 100644 +--- a/python/beeai_framework/backend/chat.py ++++ b/python/beeai_framework/backend/chat.py +@@ -229,6 +229,11 @@ class ChatModelOptions(RunnableOptions, total=False): + Generated chunks will be streamed without validation of the produced tool calls. + """ + ++ reasoning_effort: str | None ++ """ ++ Controls the amount of reasoning effort for models that support it (e.g., "low", "medium", "high"). ++ """ ++ + fallback_tool: AnyTool | None + """ + Tool to invoke when the model makes a malformed tool call (for example, when it forgets the name of a tool). +diff --git a/python/beeai_framework/backend/message.py b/python/beeai_framework/backend/message.py +index 3877d35b..befd7e9e 100644 +--- a/python/beeai_framework/backend/message.py ++++ b/python/beeai_framework/backend/message.py +@@ -86,6 +86,11 @@ class MessageToolResultContent(BaseModel): + tool_call_id: str + + ++class MessageReasoningContent(BaseModel): ++ type: Literal["reasoning"] = "reasoning" ++ text: str ++ ++ + class MessageToolCallContent(BaseModel): + type: Literal["tool-call"] = "tool-call" + id: str +@@ -157,7 +162,7 @@ class Message(ABC, Generic[T]): + return type(self)([c.model_copy() for c in self.content], self.meta.copy()) + + +-AssistantMessageContent = MessageTextContent | MessageToolCallContent ++AssistantMessageContent = MessageTextContent | MessageToolCallContent | MessageReasoningContent + + + class AssistantMessage(Message[AssistantMessageContent]): +@@ -175,8 +180,10 @@ class AssistantMessage(Message[AssistantMessageContent]): + ( + MessageTextContent(text=c) + if isinstance(c, str) +- # pyrefly: ignore [bad-argument-type] +- else to_any_model([MessageToolCallContent, MessageTextContent], cast(AssistantMessageContent, c)) ++ else to_any_model( ++ [MessageToolCallContent, MessageReasoningContent, MessageTextContent], ++ cast(AssistantMessageContent, c), # pyrefly: ignore [bad-argument-type] ++ ) + ) + for c in cast_list(content) + ] +@@ -189,12 +196,19 @@ class AssistantMessage(Message[AssistantMessageContent]): + id=id, + ) + ++ @property ++ def reasoning(self) -> str: ++ return "".join([x.text for x in self.get_reasoning_messages()]) ++ + def get_tool_calls(self) -> list[MessageToolCallContent]: + return [cont for cont in self.content if isinstance(cont, MessageToolCallContent)] + + def get_text_messages(self) -> list[MessageTextContent]: + return [cont for cont in self.content if isinstance(cont, MessageTextContent)] + ++ def get_reasoning_messages(self) -> list[MessageReasoningContent]: ++ return [cont for cont in self.content if isinstance(cont, MessageReasoningContent)] ++ + + class ToolMessage(Message[MessageToolResultContent]): + role = Role.TOOL +diff --git a/python/beeai_framework/backend/types.py b/python/beeai_framework/backend/types.py +index b44a0a19..222d60a7 100644 +--- a/python/beeai_framework/backend/types.py ++++ b/python/beeai_framework/backend/types.py +@@ -33,6 +33,7 @@ class ChatModelParameters(BaseModel): + seed: int | None = None + stop_sequences: list[str] | None = None + stream: bool | None = None ++ reasoning_effort: str | None = None + + + class ChatModelStructureInput(ChatModelParameters, Generic[T]): +@@ -218,6 +219,9 @@ class ChatModelOutput(RunnableOutput): + def get_text_content(self) -> str: + return "".join([x.text for x in list(filter(lambda x: isinstance(x, AssistantMessage), self.output))]) + ++ def get_reasoning_content(self) -> str: ++ return "".join([x.reasoning for x in self.output if isinstance(x, AssistantMessage)]) ++ + + ChatModelCache = BaseCache[list[ChatModelOutput]] + diff --git a/openinference-reasoning.patch b/openinference-reasoning.patch new file mode 100644 index 00000000..bdbf4243 --- /dev/null +++ b/openinference-reasoning.patch @@ -0,0 +1,44 @@ +diff --git a/python/instrumentation/openinference-instrumentation-beeai/src/openinference/instrumentation/beeai/processors/chat.py b/python/instrumentation/openinference-instrumentation-beeai/src/openinference/instrumentation/beeai/processors/chat.py +index a135b4b7..a8a7a0f4 100644 +--- a/python/instrumentation/openinference-instrumentation-beeai/src/openinference/instrumentation/beeai/processors/chat.py ++++ b/python/instrumentation/openinference-instrumentation-beeai/src/openinference/instrumentation/beeai/processors/chat.py +@@ -5,6 +5,7 @@ from beeai_framework.backend import ( + AnyMessage, + ChatModel, + MessageImageContent, ++ MessageReasoningContent, + MessageTextContent, + MessageToolCallContent, + MessageToolResultContent, +@@ -218,6 +219,11 @@ def _process_messages( + ), + } + if isinstance(content, MessageToolResultContent) ++ else { ++ MessageContentAttributes.MESSAGE_CONTENT_TYPE: "reasoning", ++ MessageContentAttributes.MESSAGE_CONTENT_TEXT: content.text, ++ } ++ if isinstance(content, MessageReasoningContent) + else None + ) + for content in msg.content +@@ -244,7 +250,7 @@ def _process_messages( + + + def _aggregate_msg_content(message: "AnyMessage") -> None: +- from beeai_framework.backend import MessageTextContent, MessageToolCallContent ++ from beeai_framework.backend import MessageReasoningContent, MessageTextContent, MessageToolCallContent + + contents = message.content.copy() + aggregated_content: list[Any] = [] +@@ -257,6 +263,10 @@ def _aggregate_msg_content(message: "AnyMessage") -> None: + content, MessageToolCallContent + ): + last_content.args += content.args ++ elif isinstance(last_content, MessageReasoningContent) and isinstance( ++ content, MessageReasoningContent ++ ): ++ last_content.text += content.text + else: + aggregated_content.append(content) + diff --git a/templates/beeai-agent.env b/templates/beeai-agent.env index 1956359a..d06c626b 100644 --- a/templates/beeai-agent.env +++ b/templates/beeai-agent.env @@ -28,6 +28,9 @@ CHAT_MODEL= # CHAT_MODEL_REBASE= # CHAT_MODEL_REBUILD= +# One of: none, minimal, low, medium, high; defaults to none +#REASONING_EFFORT=high + # ============================================================================= # CREDENTIALS - Use based on model prefix above # ============================================================================= diff --git a/ymir/agents/backport_agent.py b/ymir/agents/backport_agent.py index dc3e76eb..0ac44c50 100644 --- a/ymir/agents/backport_agent.py +++ b/ymir/agents/backport_agent.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Any -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -29,12 +28,14 @@ from ymir.agents.log_agent import get_prompt as get_log_prompt from ymir.agents.observability import setup_observability from ymir.agents.package_update_steps import PackageUpdateState, PackageUpdateStep +from ymir.agents.reasoning_agent import ReasoningAgent from ymir.agents.utils import ( check_subprocess, format_mr_justification, get_agent_execution_config, get_chat_model, get_tool_call_checker_config, + is_reasoning_enabled, mcp_tools, render_prompt, resolve_chat_model_override, @@ -885,7 +886,7 @@ async def create_backport_agent( local_tool_options: dict[str, Any], include_build_tools: bool = False, fix_version: str | None = None, -) -> RequirementAgent: +) -> ReasoningAgent: """ Create a backport agent. @@ -934,9 +935,10 @@ async def create_backport_agent( if include_build_tools: base_tools.extend([t for t in mcp_tools if t.name in ["build_package", "download_artifacts"]]) - return RequirementAgent( + return ReasoningAgent( name="BackportAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=base_tools, memory=UnconstrainedMemory(), diff --git a/ymir/agents/build_agent.py b/ymir/agents/build_agent.py index f164a39e..42fbf1eb 100644 --- a/ymir/agents/build_agent.py +++ b/ymir/agents/build_agent.py @@ -1,6 +1,5 @@ from typing import Any -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -10,7 +9,8 @@ from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool from beeai_framework.tools.think import ThinkTool -from ymir.agents.utils import get_chat_model, get_tool_call_checker_config +from ymir.agents.reasoning_agent import ReasoningAgent +from ymir.agents.utils import get_chat_model, get_tool_call_checker_config, is_reasoning_enabled from ymir.tools.unprivileged.commands import RunShellCommandTool from ymir.tools.unprivileged.filesystem import GetCWDTool from ymir.tools.unprivileged.text import ( @@ -49,10 +49,11 @@ def get_prompt() -> str: """ -def create_build_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]) -> RequirementAgent: - return RequirementAgent( +def create_build_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]) -> ReasoningAgent: + return ReasoningAgent( name="BuildAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/cve_applicability_agent.py b/ymir/agents/cve_applicability_agent.py index 2b0462d0..46e75bdb 100644 --- a/ymir/agents/cve_applicability_agent.py +++ b/ymir/agents/cve_applicability_agent.py @@ -1,7 +1,6 @@ from pathlib import Path from textwrap import dedent -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -11,7 +10,8 @@ from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool from beeai_framework.tools.think import ThinkTool -from ymir.agents.utils import get_chat_model, get_tool_call_checker_config +from ymir.agents.reasoning_agent import ReasoningAgent +from ymir.agents.utils import get_chat_model, get_tool_call_checker_config, is_reasoning_enabled from ymir.common.models import Resolution from ymir.tools.unprivileged.commands import RunShellCommandTool from ymir.tools.unprivileged.text import SearchTextTool, ViewTool @@ -20,11 +20,12 @@ def create_applicability_agent( gateway_tools: list[Tool], local_tool_options: dict, -) -> RequirementAgent: +) -> ReasoningAgent: extra_gateway_tools = [t for t in gateway_tools if t.name in ["get_jira_details", "get_maintainer_rules"]] - return RequirementAgent( + return ReasoningAgent( name="ApplicabilityAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/log_agent.py b/ymir/agents/log_agent.py index 423fab7b..fbaeaf24 100644 --- a/ymir/agents/log_agent.py +++ b/ymir/agents/log_agent.py @@ -1,6 +1,5 @@ from typing import Any -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -10,7 +9,8 @@ from beeai_framework.tools.search.duckduckgo import DuckDuckGoSearchTool from beeai_framework.tools.think import ThinkTool -from ymir.agents.utils import get_chat_model, get_tool_call_checker_config +from ymir.agents.reasoning_agent import ReasoningAgent +from ymir.agents.utils import get_chat_model, get_tool_call_checker_config, is_reasoning_enabled from ymir.tools.unprivileged.commands import RunShellCommandTool from ymir.tools.unprivileged.filesystem import GetCWDTool from ymir.tools.unprivileged.specfile import AddChangelogEntryTool @@ -90,10 +90,11 @@ def get_prompt() -> str: """ -def create_log_agent(_: list[Tool], local_tool_options: dict[str, Any]) -> RequirementAgent: - return RequirementAgent( +def create_log_agent(_: list[Tool], local_tool_options: dict[str, Any]) -> ReasoningAgent: + return ReasoningAgent( name="LogAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/merge_request_agent.py b/ymir/agents/merge_request_agent.py index 00e1d1b7..611cbaa5 100644 --- a/ymir/agents/merge_request_agent.py +++ b/ymir/agents/merge_request_agent.py @@ -8,7 +8,6 @@ from textwrap import dedent from typing import Any -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -26,10 +25,12 @@ from ymir.agents.build_agent import get_prompt as get_build_prompt from ymir.agents.constants import I_AM_YMIR from ymir.agents.observability import setup_observability +from ymir.agents.reasoning_agent import ReasoningAgent from ymir.agents.utils import ( get_agent_execution_config, get_chat_model, get_tool_call_checker_config, + is_reasoning_enabled, mcp_tools, render_prompt, ) @@ -133,10 +134,11 @@ def get_prompt() -> str: """ -def create_merge_request_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]) -> RequirementAgent: - return RequirementAgent( +def create_merge_request_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]) -> ReasoningAgent: + return ReasoningAgent( name="MergeRequestAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/preliminary_testing_agent.py b/ymir/agents/preliminary_testing_agent.py index a58daf70..6e344ca2 100644 --- a/ymir/agents/preliminary_testing_agent.py +++ b/ymir/agents/preliminary_testing_agent.py @@ -8,7 +8,6 @@ from enum import StrEnum from typing import Any -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -22,10 +21,12 @@ from pydantic import BaseModel, Field from ymir.agents.observability import setup_observability +from ymir.agents.reasoning_agent import ReasoningAgent from ymir.agents.utils import ( get_agent_execution_config, get_chat_model, get_tool_call_checker_config, + is_reasoning_enabled, mcp_tools, run_tool, ) @@ -137,14 +138,15 @@ def render_prompt(input: InputSchema) -> str: return PromptTemplate(PromptTemplateInput(schema=InputSchema, template=TEMPLATE)).render(input) -def create_preliminary_testing_agent(gateway_tools: list) -> RequirementAgent: - return RequirementAgent( +def create_preliminary_testing_agent(gateway_tools: list) -> ReasoningAgent: + return ReasoningAgent( name="PreliminaryTestingAnalyst", description=( "Agent that analyzes GreenWave gating and MR comment results" " to determine preliminary testing status" ), llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/reasoning_agent/__init__.py b/ymir/agents/reasoning_agent/__init__.py new file mode 100644 index 00000000..2d6e87cd --- /dev/null +++ b/ymir/agents/reasoning_agent/__init__.py @@ -0,0 +1,4 @@ +from ymir.agents.reasoning_agent.agent import ReasoningAgent +from ymir.agents.reasoning_agent.types import ReasoningAgentOutput, ReasoningAgentRunState + +__all__ = ["ReasoningAgent", "ReasoningAgentOutput", "ReasoningAgentRunState"] diff --git a/ymir/agents/reasoning_agent/_runner.py b/ymir/agents/reasoning_agent/_runner.py new file mode 100644 index 00000000..b87e3490 --- /dev/null +++ b/ymir/agents/reasoning_agent/_runner.py @@ -0,0 +1,641 @@ +import json +import uuid +from collections.abc import Sequence +from typing import Any, Self + +from beeai_framework.agents import AgentError, AgentExecutionConfig +from beeai_framework.agents._utils import run_tools +from beeai_framework.agents.requirement.agent import RequirementAgentRequirement +from beeai_framework.agents.requirement.requirements.events import ( + RequirementInitEvent, + requirement_event_types, +) +from beeai_framework.agents.requirement.requirements.requirement import Rule +from beeai_framework.agents.tool_calling.utils import ToolCallChecker +from beeai_framework.backend import ( + AnyMessage, + AssistantMessage, + ChatModel, + ChatModelOutput, + MessageToolCallContent, + MessageToolResultContent, + SystemMessage, + ToolMessage, + UserMessage, +) +from beeai_framework.backend.chat import ChatModelOptions +from beeai_framework.backend.errors import ChatModelToolCallError +from beeai_framework.backend.utils import parse_broken_json +from beeai_framework.context import RunContext +from beeai_framework.emitter import Emitter +from beeai_framework.memory import UnconstrainedMemory +from beeai_framework.memory.utils import TEMP_MESSAGE_META_KEY, delete_messages_by_meta_key +from beeai_framework.middleware.stream_tool_call import StreamToolCallMiddleware +from beeai_framework.tools import AnyTool, StringToolOutput, Tool, ToolRunOptions +from beeai_framework.utils.counter import RetryCounter +from beeai_framework.utils.lists import ensure_strictly_increasing, find_last_index +from beeai_framework.utils.strings import find_first_pair, generate_random_string, to_json, to_safe_word +from pydantic import BaseModel, Field + +from ymir.agents.reasoning_agent.events import ( + ReasoningAgentFinalAnswerEvent, + ReasoningAgentStartEvent, + ReasoningAgentSuccessEvent, +) +from ymir.agents.reasoning_agent.prompts import ( + ReasoningAgentToolErrorPromptInput, + ReasoningAgentToolTemplateDefinition, +) +from ymir.agents.reasoning_agent.types import ( + ReasoningAgentRunState, + ReasoningAgentRunStateStep, + ReasoningAgentTemplates, + RequirementEvaluation, +) + + +class FinalAnswerToolSchema(BaseModel): + response: str = Field(description="The final answer to the user") + + +class FinalAnswerTool(Tool[BaseModel, ToolRunOptions, StringToolOutput]): + name = "final_answer" + description = "Sends the final answer to the user" + + def __init__(self, expected_output: str | type[BaseModel] | None, state: ReasoningAgentRunState) -> None: + super().__init__() + self._expected_output = expected_output + self._state = state + self.instructions = expected_output if isinstance(expected_output, str) else None + self.custom_schema = isinstance(expected_output, type) + + def _create_emitter(self) -> Emitter: + return Emitter.root().child(namespace=["tool", "final_answer"], creator=self) + + @property + def input_schema(self) -> type[BaseModel]: + expected_output = self._expected_output + + if expected_output is None: + return FinalAnswerToolSchema + if isinstance(expected_output, type) and issubclass(expected_output, BaseModel): + return expected_output + if isinstance(expected_output, str): + + class CustomFinalAnswerToolSchema(FinalAnswerToolSchema): + response: str = Field(description=expected_output) # type: ignore + + return CustomFinalAnswerToolSchema + return FinalAnswerToolSchema + + async def _run( + self, input: BaseModel, options: ToolRunOptions | None, context: RunContext + ) -> StringToolOutput: + self._state.result = input + if self.input_schema is self._expected_output: + self._state.answer = AssistantMessage(input.model_dump_json()) + else: + self._state.answer = AssistantMessage(input.response) # type: ignore + + return StringToolOutput("Message has been sent") + + async def clone(self) -> Self: + tool = self.__class__(expected_output=self._expected_output, state=self._state.model_copy()) + tool.name = self.name + tool.description = self.description + tool._cache = await self.cache.clone() + tool.middlewares.extend(self.middlewares) + return tool + + +class ReasoningAgentRunner: + def __init__( + self, + *, + config: AgentExecutionConfig, + tool_call_cycle_checker: ToolCallChecker, + force_final_answer_as_tool: bool, + expected_output: Any, + run_context: RunContext, + tools: list[AnyTool], + templates: ReasoningAgentTemplates, + llm: ChatModel, + requirements: Sequence[RequirementAgentRequirement] | None = None, + unconstrained: bool = False, + ) -> None: + self._ctx = run_context + self._llm = llm + self._templates = templates + self._force_final_answer_as_tool = force_final_answer_as_tool + self._state = ReasoningAgentRunState( + answer=None, result=None, memory=UnconstrainedMemory(), steps=[], iteration=0 + ) + self._final_answer = FinalAnswerTool(expected_output, state=self._state) + self._tools = tools + self._all_tools: list[AnyTool] = [*tools, self._final_answer] + self._run_config = config + self._tool_call_cycle_checker = tool_call_cycle_checker + self._requirements: list[RequirementAgentRequirement] = list(requirements or []) + self._unconstrained = unconstrained + + max_retries_per_iteration = 0 if config.max_retries_per_step is None else config.max_retries_per_step + self._iteration_error_counter = RetryCounter( + error_type=AgentError, max_retries=max_retries_per_iteration + ) + + max_retries = 0 if config.total_max_retries is None else config.total_max_retries + max_retries = max(max_retries_per_iteration, max_retries) + self._global_error_counter = RetryCounter(error_type=AgentError, max_retries=max_retries) + + async def _init_requirements(self) -> None: + for requirement in self._requirements: + emitter = self._ctx.emitter.child( + group_id=to_safe_word(requirement.name), + creator=requirement, + events=requirement_event_types, + ) + emitter.namespace.append("requirement") + tools = list(self._all_tools) + await emitter.emit("init", RequirementInitEvent(tools=tools)) + await requirement.init(tools=tools, ctx=self._ctx) + + async def _evaluate_requirements(self, extra_rules: list[Rule] | None = None) -> RequirementEvaluation: + rules_by_tool: dict[str, list[tuple[int, Rule]]] = {t.name: [] for t in self._all_tools} + + for requirement in self._requirements: + if not requirement.enabled: + continue + generated_rules = await requirement.run(self._state) # type: ignore[arg-type] + for rule in generated_rules: + if rule.target not in rules_by_tool: + raise ValueError( + f"Tool '{rule.target}' not found in ({','.join(t.name for t in self._all_tools)})." + ) + rules_by_tool[rule.target].append((requirement.priority, rule)) + + for rule in extra_rules or []: + if rule.target not in rules_by_tool: + raise ValueError(f"Tool '{rule.target}' not found.") + entries = rules_by_tool[rule.target] + priority = max(e[0] for e in entries) + 1 if entries else 1 + entries.append((priority, rule)) + + allowed: list[AnyTool] = [] + hidden: list[AnyTool] = [] + forced: AnyTool | None = None + forced_priority = 0 + prevent_stop = False + prevent_step_refs: list[dict[str, Any]] = [] + reason_by_tool: dict[str, str | None] = {} + reasons: list[str] = [] + + for tool in self._all_tools: + entries = rules_by_tool.get(tool.name, []) + if not entries: + allowed.append(tool) + continue + + entries.sort(key=lambda x: x[0], reverse=True) + + is_allowed = True + is_hidden = False + is_forced = False + is_prevent_stop = False + reason: str | None = None + + for priority, rule in entries: + if not rule.allowed: + is_allowed = False + if rule.hidden: + is_hidden = True + if rule.forced: + is_forced = True + if rule.prevent_stop: + is_prevent_stop = True + prevent_step_refs.append( + { + "rule": { + "target": rule.target, + "allowed": rule.allowed, + "reason": rule.reason, + }, + "priority": priority, + } + ) + if rule.reason: + reason = rule.reason + + if is_hidden: + is_allowed = False + + if reason: + reason_by_tool[tool.name] = reason + + if is_allowed: + allowed.append(tool) + max_priority = entries[0][0] + if is_forced and (not forced or forced_priority < max_priority): + forced = tool + forced_priority = max_priority + + if not is_allowed and reason: + reasons.append(f"- {tool.name}: {reason}") + + if is_hidden: + hidden.append(tool) + if is_prevent_stop: + prevent_stop = True + + # Constrained: restrict allowed to forced + final_answer when forced + if not self._unconstrained and forced is not None: + allowed = [forced] + if self._final_answer is not forced: + allowed.append(self._final_answer) + + if prevent_stop and not isinstance(forced, FinalAnswerTool): + if self._final_answer in allowed: + allowed.remove(self._final_answer) + if self._unconstrained: + reasons.append("Do NOT call 'final_answer' yet — there are required steps remaining.") + + if not allowed: + raise AgentError( + "One of the generated rules is preventing the agent from continuing. " + "This indicates that the provided requirements may conflict with each other. " + "See the following rules that are preventing the agent from continuing.\n" + + json.dumps(prevent_step_refs, indent=2, default=str) + ) + + # Unconstrained: build prompt-based constraint text + constraint_text = None + if self._unconstrained: + if forced is not None and forced is not self._final_answer: + reasons.insert(0, f"You MUST call '{forced.name}' in your next response.") + unavailable = [r for r in reasons if r.startswith("- ")] + directives = [r for r in reasons if not r.startswith("- ")] + constraint_parts: list[str] = [] + if directives: + constraint_parts.extend(directives) + if unavailable: + constraint_parts.append("The following tools are currently unavailable:") + constraint_parts.extend(unavailable) + constraint_text = "\n".join(constraint_parts) if constraint_parts else None + + # Constrained: compute tool_choice for forcing + tool_choice: AnyTool | str = "auto" + if not self._unconstrained: + if forced is not None: + tool_choice = forced + elif len(allowed) == 1: + tool_choice = allowed[0] + else: + tool_choice = "required" + if ( + not isinstance(tool_choice, Tool) + and not self._force_final_answer_as_tool + and not prevent_stop + ): + tool_choice = "auto" + + return RequirementEvaluation( + allowed_tools=allowed, + hidden_tools=hidden, + forced_tool=forced, + can_stop=not prevent_stop, + constraint_text=constraint_text, + tool_choice=tool_choice, + reason_by_tool=reason_by_tool, + all_tools=list(self._all_tools), + ) + + def _increment_iteration(self) -> None: + self._state.iteration += 1 + + if self._run_config.max_iterations and self._state.iteration > self._run_config.max_iterations: + raise AgentError(f"Agent was not able to resolve the task in {self._state.iteration} iterations.") + + def _create_final_answer_stream(self) -> StreamToolCallMiddleware: + stream_middleware = StreamToolCallMiddleware( + self._final_answer, + "response", + match_nested=False, + force_streaming=False, + ) + stream_middleware.emitter.on( + "update", + lambda data, meta: self._ctx.emitter.emit( + "final_answer", + ReasoningAgentFinalAnswerEvent( + state=self._state, output=data.output, delta=data.delta, output_structured=None + ), + ), + ) + return stream_middleware + + async def _run_llm(self, evaluation: RequirementEvaluation) -> ChatModelOutput: + stream_middleware = self._create_final_answer_stream() + + try: + messages, options = self._prepare_llm_request(evaluation) + response = await self._llm.run(messages, **options).middleware(stream_middleware) + + self._state.usage.merge(response.usage) + self._state.cost.merge(response.cost) + + return response + except ChatModelToolCallError as e: + generated_content = e.generated_content or (e.response.get_text_content() if e.response else "") + if not generated_content: + raise e + + response = ChatModelOutput.from_chunks([e.response] if e.response else []) + response.output.clear() + response.output.append(AssistantMessage(generated_content)) + return response + finally: + stream_middleware.unbind() + + def _create_system_message( + self, + tool_constraints: str | None = None, + tools: list[ReasoningAgentToolTemplateDefinition] | None = None, + ) -> SystemMessage: + return SystemMessage( + self._templates.system.render( + final_answer_name=self._final_answer.name, + final_answer_schema=( + to_json( + self._final_answer.input_schema.model_json_schema(mode="validation"), + indent=2, + sort_keys=False, + ) + if self._final_answer.custom_schema + else None + ), + final_answer_instructions=self._final_answer.instructions, + tool_constraints=tool_constraints, + tools=tools or [], + ) + ) + + def _prepare_llm_request( + self, evaluation: RequirementEvaluation + ) -> tuple[list[AnyMessage], ChatModelOptions]: + if self._unconstrained: + messages = [ + self._create_system_message(tool_constraints=evaluation.constraint_text), + *self._state.memory.messages, + ] + tools_for_llm = [t for t in evaluation.allowed_tools if t not in evaluation.hidden_tools] + options = ChatModelOptions( + max_retries=self._run_config.max_retries_per_step, + tools=tools_for_llm, + tool_choice="auto", + stream_partial_tool_calls=True, + fallback_tool=self._final_answer if evaluation.can_stop else None, + ) + cache_index = 0 + else: + tool_defs = [ + ReasoningAgentToolTemplateDefinition.from_tool( + tool, + allowed=tool in evaluation.allowed_tools, + reason=evaluation.reason_by_tool.get(tool.name), + ) + for tool in evaluation.all_tools + if tool not in evaluation.hidden_tools + ] + messages = [ + self._create_system_message(tools=tool_defs), + *self._state.memory.messages, + ] + tools_for_llm = [t for t in evaluation.allowed_tools if t not in evaluation.hidden_tools] + options = ChatModelOptions( + max_retries=self._run_config.max_retries_per_step, + tools=tools_for_llm, + tool_choice=evaluation.tool_choice, + stream_partial_tool_calls=True, + fallback_tool=self._final_answer if evaluation.can_stop else None, + ) + cache_index = 1 if self._requirements else 0 + + cache_control_injection_points = [ + {"location": "message", "index": cache_index}, + { + "location": "message", + "index": find_last_index( + messages, + lambda msg: ( + not msg.meta.get(TEMP_MESSAGE_META_KEY) + and (self._llm.provider_id != "amazon_bedrock" or not isinstance(msg, ToolMessage)) + ), + ), + }, + ] + options["cache_control_injection_points"] = ensure_strictly_increasing( # type: ignore + cache_control_injection_points, + key=lambda v: v["index"], + ) + return messages, options + + async def _create_final_answer_tool_call(self, full_text: str) -> AssistantMessage | None: + json_object_pair = find_first_pair(full_text, ("{", "}")) + final_answer_input = parse_broken_json(json_object_pair.outer) if json_object_pair else None + if not final_answer_input and not self._final_answer.custom_schema: + final_answer_input = FinalAnswerToolSchema(response=full_text).model_dump() + + if not final_answer_input: + return None + + manual_assistant_tool_call_message = MessageToolCallContent( + type="tool-call", + id=f"call_{generate_random_string(8).lower()}", + tool_name=self._final_answer.name, + args=to_json(final_answer_input, sort_keys=False), + ) + return AssistantMessage(manual_assistant_tool_call_message) + + async def _invoke_tool_calls( + self, tool_calls: list[MessageToolCallContent], evaluation: RequirementEvaluation + ) -> list[ToolMessage]: + tool_results: list[ToolMessage] = [] + + for tool_call in await run_tools( + tools=evaluation.allowed_tools, + messages=tool_calls, + context={"state": self._state.model_dump()}, + ): + self._state.steps.append( + ReasoningAgentRunStateStep( + id=str(uuid.uuid4()), + iteration=self._state.iteration, + input=tool_call.input, + output=tool_call.output, + tool=tool_call.tool, + error=tool_call.error, + ) + ) + + if tool_call.error is not None: + result = self._templates.tool_error.render( + ReasoningAgentToolErrorPromptInput(reason=tool_call.error.explain()) + ) + else: + result = ( + tool_call.output.get_text_content() + if not tool_call.output.is_empty() + else self._templates.tool_no_result.render(tool_call=tool_call) + ) + + tool_results.append( + ToolMessage( + MessageToolResultContent( + tool_name=tool_call.tool.name if tool_call.tool else tool_call.msg.tool_name, + tool_call_id=tool_call.msg.id, + result=result, + ) + ) + ) + if tool_call.error is not None: + self._iteration_error_counter.use(tool_call.error) + self._global_error_counter.use(tool_call.error) + + return tool_results + + async def add_messages(self, messages: list[AnyMessage]) -> None: + await self._state.memory.add_many(messages) + + async def run(self) -> ReasoningAgentRunState: + if self._state.answer is not None: + return self._state + + await self._init_requirements() + + while self._state.answer is None: + self._increment_iteration() + + evaluation = await self._evaluate_requirements() + await self._ctx.emitter.emit( + "start", + ReasoningAgentStartEvent(state=self._state, evaluation=evaluation), + ) + self._iteration_error_counter.reset() + + if self._unconstrained: + response = await self._run_unconstrained(evaluation) + else: + response = await self._run_constrained(evaluation) + + await self._ctx.emitter.emit( + "success", + ReasoningAgentSuccessEvent(state=self._state, response=response), + ) + return self._state + + async def _run_constrained(self, evaluation: RequirementEvaluation) -> ChatModelOutput: + response = await self._run_llm(evaluation) + + if not response.get_tool_calls(): + text = response.get_text_content() + final_answer_tool_call = ( + await self._create_final_answer_tool_call(text) if evaluation.can_stop and text else None + ) + if final_answer_tool_call: + stream = self._create_final_answer_stream() + await stream.add(ChatModelOutput(output=[final_answer_tool_call])) + response.output_structured = None + response.output = [final_answer_tool_call] + else: + err = AgentError("Model produced an invalid final answer tool call.") + self._iteration_error_counter.use(err) + self._global_error_counter.use(err) + + if not evaluation.can_stop: + return await self._run_constrained(evaluation) + + self._requirements = [] + updated = await self._evaluate_requirements( + extra_rules=[ + Rule(target=self._final_answer.name, allowed=True, hidden=False), + ], + ) + self._force_final_answer_as_tool = True + return await self._run_constrained(updated) + + tool_calls = response.get_tool_calls() + for tool_call_msg in tool_calls: + self._tool_call_cycle_checker.register(tool_call_msg) + if self._tool_call_cycle_checker.cycle_found: + self._tool_call_cycle_checker.reset() + updated = await self._evaluate_requirements( + extra_rules=[ + Rule( + target=tool_call_msg.tool_name, + allowed=False, + hidden=False, + forced=True, + ), + ], + ) + return await self._run_constrained(updated) + + tool_results = await self._invoke_tool_calls(tool_calls, evaluation) + + await self._state.memory.add_many([*response.output, *tool_results]) + await delete_messages_by_meta_key(self._state.memory, key=TEMP_MESSAGE_META_KEY, value=True) + + return response + + async def _run_unconstrained(self, evaluation: RequirementEvaluation) -> ChatModelOutput: + response = await self._run_llm(evaluation) + + if not response.get_tool_calls(): + text = response.get_text_content() + final_answer_tool_call = await self._create_final_answer_tool_call(text) if text else None + if final_answer_tool_call: + stream = self._create_final_answer_stream() + await stream.add(ChatModelOutput(output=[final_answer_tool_call])) + response.output_structured = None + response.output = [final_answer_tool_call] + elif not self._force_final_answer_as_tool: + self._state.answer = AssistantMessage(text or "") + self._state.result = text + await self._state.memory.add_many(response.output) + return response + else: + err = AgentError("Model produced text instead of calling final_answer tool.") + self._iteration_error_counter.use(err) + self._global_error_counter.use(err) + await self._state.memory.add_many(response.output) + + await self._state.memory.add( + UserMessage( + "Please provide your final answer using the 'final_answer' tool.", + meta={TEMP_MESSAGE_META_KEY: True}, + ) + ) + return response + + tool_calls = response.get_tool_calls() + for tool_call_msg in tool_calls: + self._tool_call_cycle_checker.register(tool_call_msg) + if self._tool_call_cycle_checker.cycle_found: + self._tool_call_cycle_checker.reset() + await self._state.memory.add_many(response.output) + + await self._state.memory.add( + UserMessage( + f"You appear to be calling '{tool_call_msg.tool_name}' repeatedly " + "with the same input. Break the cycle by using a different tool " + "or different input, or call 'final_answer' to provide your final answer.", + meta={TEMP_MESSAGE_META_KEY: True}, + ) + ) + return response + + tool_results = await self._invoke_tool_calls(tool_calls, evaluation) + + await self._state.memory.add_many([*response.output, *tool_results]) + await delete_messages_by_meta_key(self._state.memory, key=TEMP_MESSAGE_META_KEY, value=True) + + return response diff --git a/ymir/agents/reasoning_agent/agent.py b/ymir/agents/reasoning_agent/agent.py new file mode 100644 index 00000000..8005af6b --- /dev/null +++ b/ymir/agents/reasoning_agent/agent.py @@ -0,0 +1,238 @@ +from collections.abc import Sequence +from typing import Any + +from beeai_framework.agents import AgentExecutionConfig, AgentMeta, AgentOptions, BaseAgent +from beeai_framework.agents.requirement.agent import RequirementAgentRequirement +from beeai_framework.agents.tool_calling.utils import ToolCallChecker, ToolCallCheckerConfig +from beeai_framework.backend import AnyMessage +from beeai_framework.backend.chat import ChatModel +from beeai_framework.backend.message import MessageTextContent, UserMessage +from beeai_framework.context import RunContext, RunMiddlewareType +from beeai_framework.emitter import Emitter +from beeai_framework.memory.base_memory import BaseMemory +from beeai_framework.memory.unconstrained_memory import UnconstrainedMemory +from beeai_framework.memory.utils import extract_last_tool_call_pair +from beeai_framework.runnable import runnable_entry +from beeai_framework.template import PromptTemplate +from beeai_framework.tools import AnyTool +from beeai_framework.tools.think import ThinkTool +from beeai_framework.utils.dicts import exclude_none +from beeai_framework.utils.lists import cast_list +from beeai_framework.utils.models import update_model +from typing_extensions import Unpack + +from ymir.agents.reasoning_agent._runner import ReasoningAgentRunner +from ymir.agents.reasoning_agent.events import reasoning_agent_event_types +from ymir.agents.reasoning_agent.prompts import ReasoningAgentTaskPromptInput +from ymir.agents.reasoning_agent.types import ( + ReasoningAgentOutput, + ReasoningAgentTemplateFactory, + ReasoningAgentTemplates, + ReasoningAgentTemplatesKeys, +) + + +class ReasoningAgent(BaseAgent[ReasoningAgentOutput]): + """ + Drop-in replacement for RequirementAgent that is also compatible with + reasoning models (e.g., Anthropic extended thinking, OpenAI o-series). + + When ``unconstrained=False`` (default), behaves like RequirementAgent: + requirements are evaluated each iteration and ThinkTool is available. + + When ``unconstrained=True``, requirements and ThinkTool are ignored and + tool_choice is always "auto", which is required by reasoning models. + """ + + def __init__( + self, + *, + llm: ChatModel | str, + memory: BaseMemory | None = None, + tools: Sequence[AnyTool] | None = None, + requirements: Sequence[RequirementAgentRequirement] | None = None, + unconstrained: bool = False, + name: str | None = None, + description: str | None = None, + role: str | None = None, + instructions: str | list[str] | None = None, + notes: str | list[str] | None = None, + tool_call_checker: ToolCallCheckerConfig | bool = True, + final_answer_as_tool: bool = True, + save_intermediate_steps: bool = True, + templates: dict[ReasoningAgentTemplatesKeys, PromptTemplate[Any] | ReasoningAgentTemplateFactory] + | ReasoningAgentTemplates + | None = None, + middlewares: list[RunMiddlewareType] | None = None, + ) -> None: + super().__init__(middlewares=middlewares) + self._llm = ChatModel.from_name(llm) if isinstance(llm, str) else llm + self._memory = memory or UnconstrainedMemory() + self._templates = self._generate_templates(templates) + self._save_intermediate_steps = save_intermediate_steps + self._tool_call_checker = tool_call_checker + self._final_answer_as_tool = final_answer_as_tool + self._unconstrained = unconstrained + self._requirements = [] if unconstrained else list(requirements or []) + if role or instructions or notes: + self._templates.system.update( + defaults=exclude_none( + { + "role": role, + "instructions": "\n -".join(cast_list(instructions)) if instructions else None, + "notes": "\n -".join(cast_list(notes)) if notes else None, + } + ) + ) + tools_list = list(tools or []) + if unconstrained: + tools_list = [t for t in tools_list if not isinstance(t, ThinkTool)] + self._tools = tools_list + self._meta = AgentMeta(name=name or "", description=description or "", tools=self._tools) + self.runner_cls: type[ReasoningAgentRunner] = ReasoningAgentRunner + + @runnable_entry + async def run( + self, input: str | list[AnyMessage], /, **kwargs: Unpack[AgentOptions] + ) -> ReasoningAgentOutput: + runner = self.runner_cls( + llm=self._llm, + config=AgentExecutionConfig( + max_retries_per_step=kwargs.get("max_retries_per_step", 3), + total_max_retries=kwargs.get("total_max_retries", 20), + max_iterations=kwargs.get("max_iterations", 20), + ), + tools=self._tools, + expected_output=kwargs.get("expected_output"), + tool_call_cycle_checker=self._create_tool_call_checker(), + run_context=RunContext.get(), + force_final_answer_as_tool=self._final_answer_as_tool, + templates=self._templates, + requirements=self._requirements, + unconstrained=self._unconstrained, + ) + new_messages = self._process_input( + input, + backstory=kwargs.get("backstory"), + expected_output=kwargs.get("expected_output"), + ) + await runner.add_messages(self.memory.messages) + await runner.add_messages(new_messages) + + final_state = await runner.run() + + if self._save_intermediate_steps: + self.memory.reset() + await self.memory.add_many(final_state.memory.messages) + else: + await self.memory.add_many(new_messages) + await self.memory.add_many(extract_last_tool_call_pair(final_state.memory) or []) + + assert final_state.answer is not None + return ReasoningAgentOutput( + output=[final_state.answer], + output_structured=final_state.result, + state=final_state, + ) + + def _process_input( + self, input: str | list[AnyMessage], backstory: str | None, expected_output: Any + ) -> list[AnyMessage]: + if not input: + return [] + + *msgs, last_message = [UserMessage(input)] if isinstance(input, str) else input + if last_message is not None and isinstance(last_message, UserMessage) and last_message.text: + user_message = UserMessage( + self._templates.task.render( + ReasoningAgentTaskPromptInput( + prompt=last_message.text, + context=backstory, + expected_output=expected_output if isinstance(expected_output, str) else None, + ) + ), + meta=last_message.meta.copy(), + ) + user_message.content.extend( + [content for content in last_message.content if not isinstance(content, MessageTextContent)] + ) + return [*msgs, user_message] + return msgs if last_message is None else [*msgs, last_message] + + def _create_emitter(self) -> Emitter: + return Emitter.root().child( + namespace=["agent", "reasoning"], creator=self, events=reasoning_agent_event_types + ) + + @property + def memory(self) -> BaseMemory: + return self._memory + + @memory.setter + def memory(self, memory: BaseMemory) -> None: + self._memory = memory + + @staticmethod + def _generate_templates( + overrides: dict[ReasoningAgentTemplatesKeys, PromptTemplate[Any] | ReasoningAgentTemplateFactory] + | ReasoningAgentTemplates + | None = None, + ) -> ReasoningAgentTemplates: + if isinstance(overrides, ReasoningAgentTemplates): + return overrides + + templates = ReasoningAgentTemplates() + if overrides is None: + return templates + + for name in ReasoningAgentTemplates.model_fields: + override: PromptTemplate[Any] | ReasoningAgentTemplateFactory | None = overrides.get(name) + if override is None: + continue + if isinstance(override, PromptTemplate): + setattr(templates, name, override) + else: + setattr(templates, name, override(getattr(templates, name))) + return templates + + async def clone(self) -> "ReasoningAgent": + cloned = ReasoningAgent( + llm=await self._llm.clone(), + memory=await self._memory.clone(), + tools=self._tools.copy(), + requirements=self._requirements.copy(), + unconstrained=self._unconstrained, + templates=self._templates.model_dump(), + tool_call_checker=( + self._tool_call_checker.config.model_copy() + if isinstance(self._tool_call_checker, ToolCallChecker) + else self._tool_call_checker + ), + save_intermediate_steps=self._save_intermediate_steps, + final_answer_as_tool=self._final_answer_as_tool, + name=self._meta.name, + description=self._meta.description, + middlewares=self.middlewares.copy(), + ) + cloned.emitter = await self.emitter.clone() + cloned.runner_cls = self.runner_cls + return cloned + + @property + def meta(self) -> AgentMeta: + parent = super().meta + + return AgentMeta( + name=self._meta.name or parent.name, + description=self._meta.description or parent.description, + extra_description=self._meta.extra_description or parent.extra_description, + tools=list(self._tools), + ) + + def _create_tool_call_checker(self) -> ToolCallChecker: + config = ToolCallCheckerConfig() + update_model(config, sources=[self._tool_call_checker]) + + instance = ToolCallChecker(config) + instance.enabled = self._tool_call_checker is not False + return instance diff --git a/ymir/agents/reasoning_agent/events.py b/ymir/agents/reasoning_agent/events.py new file mode 100644 index 00000000..fab4e8a4 --- /dev/null +++ b/ymir/agents/reasoning_agent/events.py @@ -0,0 +1,32 @@ +from typing import Any + +from beeai_framework.backend import ChatModelOutput +from pydantic import BaseModel, ConfigDict + +from ymir.agents.reasoning_agent.types import ReasoningAgentRunState, RequirementEvaluation + + +class ReasoningAgentStartEvent(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + state: ReasoningAgentRunState + evaluation: RequirementEvaluation + + +class ReasoningAgentSuccessEvent(BaseModel): + state: ReasoningAgentRunState + response: ChatModelOutput + + +class ReasoningAgentFinalAnswerEvent(BaseModel): + state: ReasoningAgentRunState + output_structured: BaseModel | Any + output: str + delta: str + + +reasoning_agent_event_types: dict[str, type] = { + "start": ReasoningAgentStartEvent, + "success": ReasoningAgentSuccessEvent, + "final_answer": ReasoningAgentFinalAnswerEvent, +} diff --git a/ymir/agents/reasoning_agent/prompts.py b/ymir/agents/reasoning_agent/prompts.py new file mode 100644 index 00000000..5eb1c765 --- /dev/null +++ b/ymir/agents/reasoning_agent/prompts.py @@ -0,0 +1,150 @@ +from datetime import UTC, datetime +from typing import Self + +from beeai_framework.agents._utils import ToolInvocationResult +from beeai_framework.template import PromptTemplate +from beeai_framework.tools import AnyTool +from beeai_framework.utils.strings import to_json +from pydantic import BaseModel, Field + + +class ReasoningAgentToolTemplateDefinition(BaseModel): + name: str + description: str + input_schema: str + allowed: str + reason: str | None + + @classmethod + def from_tool(cls, tool: AnyTool, *, allowed: bool = True, reason: str | None = None) -> Self: + return cls( + name=tool.name, + description=tool.description, + input_schema=to_json( + tool.input_schema.model_json_schema(mode="validation"), + indent=2, + sort_keys=False, + ), + allowed=str(allowed), + reason=reason, + ) + + +class ReasoningAgentSystemPromptInput(BaseModel): + role: str + instructions: str | None = None + final_answer_name: str + final_answer_schema: str | None + final_answer_instructions: str | None + notes: str | None = None + tool_constraints: str | None = None + tools: list[ReasoningAgentToolTemplateDefinition] = Field(default_factory=list) + + +ReasoningAgentSystemPrompt = PromptTemplate( + schema=ReasoningAgentSystemPromptInput, + functions={"formatDate": lambda data: datetime.now(tz=UTC).strftime("%Y-%m-%d")}, + defaults={"role": "a helpful AI assistant", "instructions": ""}, + template="""# Role +Assume the role of {{role}}. + +# Instructions +{{#instructions}} +{{&.}} +{{/instructions}} +When the user sends a message, figure out a solution and provide a final answer to the user by calling the '{{final_answer_name}}' tool. +{{#final_answer_schema}} +The final answer must fulfill the following. + +``` +{{&final_answer_schema}} +``` +{{/final_answer_schema}} +{{#final_answer_instructions}} +{{&final_answer_instructions}} +{{/final_answer_instructions}} + +IMPORTANT: The facts mentioned in the final answer must be backed by evidence provided by relevant tool outputs. + +{{#tool_constraints}} +# Tool Constraints +{{&tool_constraints}} + +{{/tool_constraints}} +# Tools +{{#tools.0}} +You must use a tool to retrieve factual or historical information. +Never use the tool twice with the same input if not stated otherwise. + +{{#tools}} +Name: {{name}} +Description: {{description}} +Allowed: {{allowed}}{{#reason}} +Reason: {{&.}}{{/reason}} + +{{/tools}} +{{/tools.0}} +{{^tools.0}} +Tools are available via the standard tool calling mechanism. +You must use a tool to retrieve factual or historical information. +Never use the tool twice with the same input if not stated otherwise. +{{/tools.0}} + +# Notes +- Use markdown syntax to format code snippets, links, JSON, tables, images, and files. +- If the provided task is unclear, ask the user for clarification. +- Do not refer to tools or tool outputs by name when responding. +- Always take it one step at a time. Don't try to do multiple things at once. +- When the tool doesn't give you what you were asking for, you must either use another tool or a different tool input. +- You should always try a few different approaches before declaring the problem unsolvable. +- If you can't fully answer the user's question, answer partially and describe what you couldn't achieve. +- You cannot do complex calculations, computations, or data manipulations without using tools. +- The current date and time is: {{formatDate}} +{{#notes}} +{{&.}} +{{/notes}} +""", # noqa: E501 +) + + +class ReasoningAgentTaskPromptInput(BaseModel): + prompt: str + context: str | None = None + expected_output: str | None = None + + +ReasoningAgentTaskPrompt = PromptTemplate( + schema=ReasoningAgentTaskPromptInput, + template="""{{#context}}This is the context relevant to the task: +{{&.}} + +{{/context}} +{{#expected_output}} +This is the expected criteria for your output: +{{.}} + +{{/expected_output}} +Your task: {{prompt}}""", +) + + +class ReasoningAgentToolErrorPromptInput(BaseModel): + reason: str + + +ReasoningAgentToolErrorPrompt = PromptTemplate( + schema=ReasoningAgentToolErrorPromptInput, + template="""The tool has failed; the error log is shown below. If the tool cannot accomplish what you want, use a different tool or explain why you can't use it. + +{{&reason}}""", # noqa: E501 +) + + +class ReasoningAgentToolNoResultTemplateInput(BaseModel): + tool_call: ToolInvocationResult + + +ReasoningAgentToolNoResultPrompt = PromptTemplate( + schema=ReasoningAgentToolNoResultTemplateInput, + template="""No results were found! Try to reformulate your query or use a different tool.""", +) diff --git a/ymir/agents/reasoning_agent/types.py b/ymir/agents/reasoning_agent/types.py new file mode 100644 index 00000000..98a1a89f --- /dev/null +++ b/ymir/agents/reasoning_agent/types.py @@ -0,0 +1,83 @@ +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Annotated, Any + +from beeai_framework.agents import AgentOutput +from beeai_framework.backend import AssistantMessage, UserMessage +from beeai_framework.backend.types import ChatModelCost, ChatModelUsage +from beeai_framework.errors import FrameworkError +from beeai_framework.memory import BaseMemory +from beeai_framework.template import PromptTemplate +from beeai_framework.tools import AnyTool, Tool, ToolOutput +from pydantic import BaseModel, ConfigDict, Field, InstanceOf + +from ymir.agents.reasoning_agent.prompts import ( + ReasoningAgentSystemPrompt, + ReasoningAgentSystemPromptInput, + ReasoningAgentTaskPrompt, + ReasoningAgentTaskPromptInput, + ReasoningAgentToolErrorPrompt, + ReasoningAgentToolErrorPromptInput, + ReasoningAgentToolNoResultPrompt, + ReasoningAgentToolNoResultTemplateInput, +) + + +class ReasoningAgentTemplates(BaseModel): + system: InstanceOf[PromptTemplate[ReasoningAgentSystemPromptInput]] = Field( + default_factory=lambda: ReasoningAgentSystemPrompt.fork(None), + ) + task: InstanceOf[PromptTemplate[ReasoningAgentTaskPromptInput]] = Field( + default_factory=lambda: ReasoningAgentTaskPrompt.fork(None), + ) + tool_error: InstanceOf[PromptTemplate[ReasoningAgentToolErrorPromptInput]] = Field( + default_factory=lambda: ReasoningAgentToolErrorPrompt.fork(None), + ) + tool_no_result: InstanceOf[PromptTemplate[ReasoningAgentToolNoResultTemplateInput]] = Field( + default_factory=lambda: ReasoningAgentToolNoResultPrompt.fork(None), + ) + + +ReasoningAgentTemplateFactory = Callable[[InstanceOf[PromptTemplate[Any]]], InstanceOf[PromptTemplate[Any]]] +ReasoningAgentTemplatesKeys = Annotated[str, lambda v: v in ReasoningAgentTemplates.model_fields] + + +class ReasoningAgentRunStateStep(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + iteration: int + tool: InstanceOf[Tool[Any, Any, Any]] | None + input: Any + output: InstanceOf[ToolOutput] + error: InstanceOf[FrameworkError] | None + + +class ReasoningAgentRunState(BaseModel): + answer: InstanceOf[AssistantMessage] | None = None + result: Any + memory: InstanceOf[BaseMemory] + iteration: int + steps: list[ReasoningAgentRunStateStep] = [] + usage: ChatModelUsage = ChatModelUsage() + cost: ChatModelCost = ChatModelCost() + + @property + def input(self) -> UserMessage: + return next(msg for msg in reversed(self.memory.messages) if isinstance(msg, UserMessage)) + + +class ReasoningAgentOutput(AgentOutput): + state: ReasoningAgentRunState + + +@dataclass +class RequirementEvaluation: + allowed_tools: list[AnyTool] = field(default_factory=list) + hidden_tools: list[AnyTool] = field(default_factory=list) + forced_tool: AnyTool | None = None + can_stop: bool = True + constraint_text: str | None = None + tool_choice: AnyTool | str = "auto" + reason_by_tool: dict[str, str | None] = field(default_factory=dict) + all_tools: list[AnyTool] = field(default_factory=list) diff --git a/ymir/agents/rebase_agent.py b/ymir/agents/rebase_agent.py index 16064f6a..5c19eeb4 100644 --- a/ymir/agents/rebase_agent.py +++ b/ymir/agents/rebase_agent.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Any -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -27,11 +26,13 @@ from ymir.agents.log_agent import get_prompt as get_log_prompt from ymir.agents.observability import setup_observability from ymir.agents.package_update_steps import PackageUpdateState, PackageUpdateStep +from ymir.agents.reasoning_agent import ReasoningAgent from ymir.agents.utils import ( format_mr_justification, get_agent_execution_config, get_chat_model, get_tool_call_checker_config, + is_reasoning_enabled, mcp_tools, render_prompt, resolve_chat_model_override, @@ -169,10 +170,11 @@ def get_prompt() -> str: """ -def create_rebase_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]) -> RequirementAgent: - return RequirementAgent( +def create_rebase_agent(mcp_tools: list[Tool], local_tool_options: dict[str, Any]) -> ReasoningAgent: + return ReasoningAgent( name="RebaseAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/rebuild_consolidation.py b/ymir/agents/rebuild_consolidation.py index d5e4004f..a2328b9d 100644 --- a/ymir/agents/rebuild_consolidation.py +++ b/ymir/agents/rebuild_consolidation.py @@ -2,16 +2,17 @@ from pathlib import Path from textwrap import dedent -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.memory import UnconstrainedMemory from beeai_framework.tools import Tool from pydantic import BaseModel, Field from ymir.agents.cve_applicability_agent import build_applicability_prompt, create_applicability_agent +from ymir.agents.reasoning_agent import ReasoningAgent from ymir.agents.utils import ( get_agent_execution_config, get_chat_model, get_tool_call_checker_config, + is_reasoning_enabled, run_tool, ) from ymir.common.models import ( @@ -141,9 +142,10 @@ async def find_rebuild_siblings( continue try: - analysis_agent = RequirementAgent( + analysis_agent = ReasoningAgent( name="SiblingRebuildAnalyzer", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=analysis_tools, memory=UnconstrainedMemory(), diff --git a/ymir/agents/triage_agent.py b/ymir/agents/triage_agent.py index 9b6b484f..e3bf1802 100644 --- a/ymir/agents/triage_agent.py +++ b/ymir/agents/triage_agent.py @@ -7,7 +7,6 @@ from pathlib import Path from textwrap import dedent -from beeai_framework.agents.requirement import RequirementAgent from beeai_framework.agents.requirement.requirements.conditional import ( ConditionalRequirement, ) @@ -24,12 +23,14 @@ import ymir.agents.tasks as tasks from ymir.agents.cve_applicability_agent import build_applicability_prompt, create_applicability_agent from ymir.agents.observability import setup_observability +from ymir.agents.reasoning_agent import ReasoningAgent from ymir.agents.rebuild_consolidation import find_rebuild_siblings from ymir.agents.utils import ( build_agent_factory_with_mock_repos, get_agent_execution_config, get_chat_model, get_tool_call_checker_config, + is_reasoning_enabled, mcp_tools, resolve_chat_model_override, run_tool, @@ -531,10 +532,11 @@ class TriageState(BaseModel): applicability_check_skipped: bool = Field(default=False) -def create_triage_agent(gateway_tools, local_tool_options=None): - return RequirementAgent( +def create_triage_agent(gateway_tools, local_tool_options=None) -> ReasoningAgent: + return ReasoningAgent( name="TriageAgent", llm=get_chat_model(), + unconstrained=is_reasoning_enabled(), tool_call_checker=get_tool_call_checker_config(), tools=[ ThinkTool(), diff --git a/ymir/agents/utils.py b/ymir/agents/utils.py index 1fc31e8c..8cabe386 100644 --- a/ymir/agents/utils.py +++ b/ymir/agents/utils.py @@ -26,22 +26,32 @@ def resolve_chat_model_override(agent_type: str) -> None: os.environ["CHAT_MODEL"] = override +def is_reasoning_enabled() -> bool: + chat_model = os.environ.get("CHAT_MODEL", "") + return "claude" in chat_model and bool(os.getenv("REASONING_EFFORT")) + + def get_chat_model() -> ChatModel: chat_model = os.environ["CHAT_MODEL"] + # lowering the temperature makes the model stop backporting too soon + # but should yield more predictable results, similar for top_p (tried 0.5) + temperature = float(os.getenv("TEMPERATURE", "0.6")) + reasoning_effort = os.getenv("REASONING_EFFORT") model = ChatModel.from_name( chat_model, # this the preferred way to set parameters, don't do options=... # it was changed in beeai 0.1.48 ChatModelParameters( - # lowering the temperature makes the model stop backporting too soon - # but should yield more predictable results, similar for top_p (tried 0.5) - temperature=0.6 + # Anthropic requires temperature=1 when extended thinking is enabled + temperature=1 if "claude" in chat_model and reasoning_effort else temperature, + reasoning_effort=reasoning_effort, ), timeout=1200, ) if "gemini" in chat_model: # disable `required` for Gemini models model.tool_choice_support = {"single", "none", "auto"} + model.allow_prompt_caching = False return model