diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 4b6b6df64..42a2bd620 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -71,7 +71,7 @@ ToolCallOutputItem, TResponseInputItem, ) -from .lifecycle import RunHooks +from .lifecycle import RunHooks, _call_hook_with_data from .logger import logger from .model_settings import ModelSettings from .models.interface import ModelTracing @@ -909,9 +909,21 @@ async def _execute_tool_with_hooks( The result from the tool execution. """ await asyncio.gather( - hooks.on_tool_start(tool_context, agent, func_tool), + _call_hook_with_data( + hooks.on_tool_start, + tool_context, + agent, + func_tool, + data={"arguments": tool_call.arguments}, + ), ( - agent.hooks.on_tool_start(tool_context, agent, func_tool) + _call_hook_with_data( + agent.hooks.on_tool_start, + tool_context, + agent, + func_tool, + data={"arguments": tool_call.arguments}, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -979,10 +991,22 @@ async def run_single_tool( # 4) Tool end hooks (with final result, which may have been overridden) await asyncio.gather( - hooks.on_tool_end(tool_context, agent, func_tool, final_result), + _call_hook_with_data( + hooks.on_tool_end, + tool_context, + agent, + func_tool, + final_result, + data={"arguments": tool_call.arguments}, + ), ( - agent.hooks.on_tool_end( - tool_context, agent, func_tool, final_result + _call_hook_with_data( + agent.hooks.on_tool_end, + tool_context, + agent, + func_tool, + final_result, + data={"arguments": tool_call.arguments}, ) if agent.hooks else _coro.noop_coroutine() @@ -1553,10 +1577,20 @@ async def execute( else cls._get_screenshot_sync(computer, action.tool_call) ) + tool_args = getattr(action.tool_call, "arguments", None) + tool_data = {"arguments": tool_args} if tool_args else None _, _, output = await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, action.computer_tool), + _call_hook_with_data( + hooks.on_tool_start, context_wrapper, agent, action.computer_tool, data=tool_data + ), ( - agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) + _call_hook_with_data( + agent.hooks.on_tool_start, + context_wrapper, + agent, + action.computer_tool, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1564,9 +1598,23 @@ async def execute( ) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), + _call_hook_with_data( + hooks.on_tool_end, + context_wrapper, + agent, + action.computer_tool, + output, + data=tool_data, + ), ( - agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) + _call_hook_with_data( + agent.hooks.on_tool_end, + context_wrapper, + agent, + action.computer_tool, + output, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1656,10 +1704,20 @@ async def execute( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> RunItem: + tool_args = getattr(call.tool_call, "arguments", None) + tool_data = {"arguments": tool_args} if tool_args else None await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + _call_hook_with_data( + hooks.on_tool_start, context_wrapper, agent, call.local_shell_tool, data=tool_data + ), ( - agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + _call_hook_with_data( + agent.hooks.on_tool_start, + context_wrapper, + agent, + call.local_shell_tool, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1676,9 +1734,23 @@ async def execute( result = output await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + _call_hook_with_data( + hooks.on_tool_end, + context_wrapper, + agent, + call.local_shell_tool, + result, + data=tool_data, + ), ( - agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + _call_hook_with_data( + agent.hooks.on_tool_end, + context_wrapper, + agent, + call.local_shell_tool, + result, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1707,10 +1779,20 @@ async def execute( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> RunItem: + tool_args = getattr(call.tool_call, "arguments", None) + tool_data = {"arguments": tool_args} if tool_args else None await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, call.shell_tool), + _call_hook_with_data( + hooks.on_tool_start, context_wrapper, agent, call.shell_tool, data=tool_data + ), ( - agent.hooks.on_tool_start(context_wrapper, agent, call.shell_tool) + _call_hook_with_data( + agent.hooks.on_tool_start, + context_wrapper, + agent, + call.shell_tool, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1744,9 +1826,23 @@ async def execute( logger.error("Shell executor failed: %s", exc, exc_info=True) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text), + _call_hook_with_data( + hooks.on_tool_end, + context_wrapper, + agent, + call.shell_tool, + output_text, + data=tool_data, + ), ( - agent.hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text) + _call_hook_with_data( + agent.hooks.on_tool_end, + context_wrapper, + agent, + call.shell_tool, + output_text, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1831,10 +1927,20 @@ async def execute( config: RunConfig, ) -> RunItem: apply_patch_tool = call.apply_patch_tool + tool_args = getattr(call.tool_call, "arguments", None) + tool_data = {"arguments": tool_args} if tool_args else None await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, apply_patch_tool), + _call_hook_with_data( + hooks.on_tool_start, context_wrapper, agent, apply_patch_tool, data=tool_data + ), ( - agent.hooks.on_tool_start(context_wrapper, agent, apply_patch_tool) + _call_hook_with_data( + agent.hooks.on_tool_start, + context_wrapper, + agent, + apply_patch_tool, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1871,9 +1977,23 @@ async def execute( logger.error("Apply patch editor failed: %s", exc, exc_info=True) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), + _call_hook_with_data( + hooks.on_tool_end, + context_wrapper, + agent, + apply_patch_tool, + output_text, + data=tool_data, + ), ( - agent.hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text) + _call_hook_with_data( + agent.hooks.on_tool_end, + context_wrapper, + agent, + apply_patch_tool, + output_text, + data=tool_data, + ) if agent.hooks else _coro.noop_coroutine() ), diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 85ea26bc8..2d3ef9135 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -1,4 +1,5 @@ -from typing import Any, Generic, Optional +import inspect +from typing import Any, Callable, Generic, Optional from typing_extensions import TypeVar @@ -34,7 +35,12 @@ async def on_llm_end( """Called immediately after the LLM call returns for this agent.""" pass - async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: + async def on_agent_start( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + data: Optional[dict[str, Any]] = None, + ) -> None: """Called before the agent is invoked. Called each time the current agent changes.""" pass @@ -61,6 +67,7 @@ async def on_tool_start( context: RunContextWrapper[TContext], agent: TAgent, tool: Tool, + data: Optional[dict[str, Any]] = None, ) -> None: """Called immediately before a local tool is invoked.""" pass @@ -71,6 +78,7 @@ async def on_tool_end( agent: TAgent, tool: Tool, result: str, + data: Optional[dict[str, Any]] = None, ) -> None: """Called immediately after a local tool is invoked.""" pass @@ -83,7 +91,12 @@ class AgentHooksBase(Generic[TContext, TAgent]): Subclass and override the methods you need. """ - async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None: + async def on_start( + self, + context: RunContextWrapper[TContext], + agent: TAgent, + data: Optional[dict[str, Any]] = None, + ) -> None: """Called before the agent is invoked. Called each time the running agent is changed to this agent.""" pass @@ -112,6 +125,7 @@ async def on_tool_start( context: RunContextWrapper[TContext], agent: TAgent, tool: Tool, + data: Optional[dict[str, Any]] = None, ) -> None: """Called immediately before a local tool is invoked.""" pass @@ -122,6 +136,7 @@ async def on_tool_end( agent: TAgent, tool: Tool, result: str, + data: Optional[dict[str, Any]] = None, ) -> None: """Called immediately after a local tool is invoked.""" pass @@ -151,3 +166,15 @@ async def on_llm_end( AgentHooks = AgentHooksBase[TContext, Agent] """Agent hooks for `Agent`s.""" + + +async def _call_hook_with_data( + hook_method: Callable[..., Any], + *args: Any, + data: Optional[dict[str, Any]] = None, + **kwargs: Any, +) -> Any: + sig = inspect.signature(hook_method) + if "data" in sig.parameters: + return await hook_method(*args, data=data, **kwargs) + return await hook_method(*args, **kwargs) diff --git a/src/agents/run.py b/src/agents/run.py index f6707b33b..a08023c0f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -57,7 +57,7 @@ ToolCallItemTypes, TResponseInputItem, ) -from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase, _call_hook_with_data from .logger import logger from .memory import Session, SessionInputCallback from .model_settings import ModelSettings @@ -1367,10 +1367,24 @@ async def _run_single_turn_streamed( emitted_reasoning_item_ids: set[str] = set() if should_run_agent_start_hooks: + if server_conversation_tracker is not None: + turn_input = server_conversation_tracker.prepare_input( + streamed_result.input, streamed_result.new_items + ) + else: + turn_input = ItemHelpers.input_to_new_input_list(streamed_result.input) + turn_input.extend([item.to_input_item() for item in streamed_result.new_items]) await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), + _call_hook_with_data( + hooks.on_agent_start, context_wrapper, agent, data={"turn_input": turn_input} + ), ( - agent.hooks.on_start(context_wrapper, agent) + _call_hook_with_data( + agent.hooks.on_start, + context_wrapper, + agent, + data={"turn_input": turn_input}, + ) if agent.hooks else _coro.noop_coroutine() ), @@ -1591,10 +1605,26 @@ async def _run_single_turn( ) -> SingleStepResult: # Ensure we run the hooks before anything else if should_run_agent_start_hooks: + if server_conversation_tracker is not None: + turn_input = server_conversation_tracker.prepare_input( + original_input, generated_items + ) + else: + turn_input = ItemHelpers.input_to_new_input_list(original_input) + turn_input.extend( + [generated_item.to_input_item() for generated_item in generated_items] + ) await asyncio.gather( - hooks.on_agent_start(context_wrapper, agent), + _call_hook_with_data( + hooks.on_agent_start, context_wrapper, agent, data={"turn_input": turn_input} + ), ( - agent.hooks.on_start(context_wrapper, agent) + _call_hook_with_data( + agent.hooks.on_start, + context_wrapper, + agent, + data={"turn_input": turn_input}, + ) if agent.hooks else _coro.noop_coroutine() ), diff --git a/tests/test_run_hooks.py b/tests/test_run_hooks.py index f5a2ed478..f430c0f96 100644 --- a/tests/test_run_hooks.py +++ b/tests/test_run_hooks.py @@ -15,6 +15,7 @@ from .fake_model import FakeModel from .test_responses import ( get_function_tool, + get_function_tool_call, get_text_message, ) @@ -244,3 +245,72 @@ async def test_streamed_run_hooks_llm_error(monkeypatch): assert hooks.events["on_llm_start"] == 1 assert hooks.events["on_llm_end"] == 0 assert hooks.events["on_agent_end"] == 0 + + +class RunHooksWithData(RunHooks): + """Hooks that accept the optional data parameter to verify new functionality.""" + + def __init__(self): + self.captured_turn_inputs: list[list[TResponseInputItem]] = [] + self.captured_tool_arguments: list[str] = [] + + async def on_agent_start( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + data: Optional[dict[str, Any]] = None, + ) -> None: + if data and "turn_input" in data: + self.captured_turn_inputs.append(data["turn_input"]) + + async def on_tool_start( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + tool: Tool, + data: Optional[dict[str, Any]] = None, + ) -> None: + if data and "arguments" in data: + self.captured_tool_arguments.append(data["arguments"]) + + async def on_tool_end( + self, + context: RunContextWrapper[TContext], + agent: Agent[TContext], + tool: Tool, + result: str, + data: Optional[dict[str, Any]] = None, + ) -> None: + if data and "arguments" in data: + # Verify arguments are also available in on_tool_end + assert data["arguments"] in self.captured_tool_arguments + + +@pytest.mark.asyncio +async def test_hooks_receive_turn_input_and_arguments(): + """Verify that hooks with data parameter receive turn_input and arguments.""" + hooks = RunHooksWithData() + model = FakeModel() + agent = Agent(name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[]) + + # First turn: tool call + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("f", '{"test": "arg"}')], + [get_text_message("done")], + ] + ) + + await Runner.run(agent, input="test input", hooks=hooks) + + # Verify turn_input was captured + assert len(hooks.captured_turn_inputs) == 1 + turn_input = hooks.captured_turn_inputs[0] + assert len(turn_input) == 1 + assert turn_input[0]["role"] == "user" + # For string input, content is the string itself + assert turn_input[0]["content"] == "test input" + + # Verify tool arguments were captured + assert len(hooks.captured_tool_arguments) == 1 + assert hooks.captured_tool_arguments[0] == '{"test": "arg"}'