Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 142 additions & 22 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
ToolCallOutputItem,
TResponseInputItem,
)
from .lifecycle import RunHooks
from .lifecycle import RunHooks, _call_hook_with_data
from .logger import logger
from .model_settings import ModelSettings
from .models.interface import ModelTracing
Expand Down Expand Up @@ -909,9 +909,21 @@ async def _execute_tool_with_hooks(
The result from the tool execution.
"""
await asyncio.gather(
hooks.on_tool_start(tool_context, agent, func_tool),
_call_hook_with_data(
hooks.on_tool_start,
tool_context,
agent,
func_tool,
data={"arguments": tool_call.arguments},
),
(
agent.hooks.on_tool_start(tool_context, agent, func_tool)
_call_hook_with_data(
agent.hooks.on_tool_start,
tool_context,
agent,
func_tool,
data={"arguments": tool_call.arguments},
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -979,10 +991,22 @@ async def run_single_tool(

# 4) Tool end hooks (with final result, which may have been overridden)
await asyncio.gather(
hooks.on_tool_end(tool_context, agent, func_tool, final_result),
_call_hook_with_data(
hooks.on_tool_end,
tool_context,
agent,
func_tool,
final_result,
data={"arguments": tool_call.arguments},
),
(
agent.hooks.on_tool_end(
tool_context, agent, func_tool, final_result
_call_hook_with_data(
agent.hooks.on_tool_end,
tool_context,
agent,
func_tool,
final_result,
data={"arguments": tool_call.arguments},
)
if agent.hooks
else _coro.noop_coroutine()
Expand Down Expand Up @@ -1553,20 +1577,44 @@ async def execute(
else cls._get_screenshot_sync(computer, action.tool_call)
)

tool_args = getattr(action.tool_call, "arguments", None)
tool_data = {"arguments": tool_args} if tool_args else None
_, _, output = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
_call_hook_with_data(
hooks.on_tool_start, context_wrapper, agent, action.computer_tool, data=tool_data
),
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
_call_hook_with_data(
agent.hooks.on_tool_start,
context_wrapper,
agent,
action.computer_tool,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
output_func,
)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
_call_hook_with_data(
hooks.on_tool_end,
context_wrapper,
agent,
action.computer_tool,
output,
data=tool_data,
),
(
agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
_call_hook_with_data(
agent.hooks.on_tool_end,
context_wrapper,
agent,
action.computer_tool,
output,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -1656,10 +1704,20 @@ async def execute(
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> RunItem:
tool_args = getattr(call.tool_call, "arguments", None)
tool_data = {"arguments": tool_args} if tool_args else None
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
_call_hook_with_data(
hooks.on_tool_start, context_wrapper, agent, call.local_shell_tool, data=tool_data
),
(
agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
_call_hook_with_data(
agent.hooks.on_tool_start,
context_wrapper,
agent,
call.local_shell_tool,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand All @@ -1676,9 +1734,23 @@ async def execute(
result = output

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
_call_hook_with_data(
hooks.on_tool_end,
context_wrapper,
agent,
call.local_shell_tool,
result,
data=tool_data,
),
(
agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
_call_hook_with_data(
agent.hooks.on_tool_end,
context_wrapper,
agent,
call.local_shell_tool,
result,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -1707,10 +1779,20 @@ async def execute(
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> RunItem:
tool_args = getattr(call.tool_call, "arguments", None)
tool_data = {"arguments": tool_args} if tool_args else None
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, call.shell_tool),
_call_hook_with_data(
hooks.on_tool_start, context_wrapper, agent, call.shell_tool, data=tool_data
),
(
agent.hooks.on_tool_start(context_wrapper, agent, call.shell_tool)
_call_hook_with_data(
agent.hooks.on_tool_start,
context_wrapper,
agent,
call.shell_tool,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -1744,9 +1826,23 @@ async def execute(
logger.error("Shell executor failed: %s", exc, exc_info=True)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text),
_call_hook_with_data(
hooks.on_tool_end,
context_wrapper,
agent,
call.shell_tool,
output_text,
data=tool_data,
),
(
agent.hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text)
_call_hook_with_data(
agent.hooks.on_tool_end,
context_wrapper,
agent,
call.shell_tool,
output_text,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -1831,10 +1927,20 @@ async def execute(
config: RunConfig,
) -> RunItem:
apply_patch_tool = call.apply_patch_tool
tool_args = getattr(call.tool_call, "arguments", None)
tool_data = {"arguments": tool_args} if tool_args else None
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, apply_patch_tool),
_call_hook_with_data(
hooks.on_tool_start, context_wrapper, agent, apply_patch_tool, data=tool_data
),
(
agent.hooks.on_tool_start(context_wrapper, agent, apply_patch_tool)
_call_hook_with_data(
agent.hooks.on_tool_start,
context_wrapper,
agent,
apply_patch_tool,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -1871,9 +1977,23 @@ async def execute(
logger.error("Apply patch editor failed: %s", exc, exc_info=True)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text),
_call_hook_with_data(
hooks.on_tool_end,
context_wrapper,
agent,
apply_patch_tool,
output_text,
data=tool_data,
),
(
agent.hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text)
_call_hook_with_data(
agent.hooks.on_tool_end,
context_wrapper,
agent,
apply_patch_tool,
output_text,
data=tool_data,
)
if agent.hooks
else _coro.noop_coroutine()
),
Expand Down
33 changes: 30 additions & 3 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Generic, Optional
import inspect
from typing import Any, Callable, Generic, Optional

from typing_extensions import TypeVar

Expand Down Expand Up @@ -34,7 +35,12 @@ async def on_llm_end(
"""Called immediately after the LLM call returns for this agent."""
pass

async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
async def on_agent_start(
self,
context: RunContextWrapper[TContext],
agent: TAgent,
data: Optional[dict[str, Any]] = None,
) -> None:
"""Called before the agent is invoked. Called each time the current agent changes."""
pass

Expand All @@ -61,6 +67,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
data: Optional[dict[str, Any]] = None,
) -> None:
"""Called immediately before a local tool is invoked."""
pass
Expand All @@ -71,6 +78,7 @@ async def on_tool_end(
agent: TAgent,
tool: Tool,
result: str,
data: Optional[dict[str, Any]] = None,
) -> None:
"""Called immediately after a local tool is invoked."""
pass
Expand All @@ -83,7 +91,12 @@ class AgentHooksBase(Generic[TContext, TAgent]):
Subclass and override the methods you need.
"""

async def on_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
async def on_start(
self,
context: RunContextWrapper[TContext],
agent: TAgent,
data: Optional[dict[str, Any]] = None,
) -> None:
"""Called before the agent is invoked. Called each time the running agent is changed to this
agent."""
pass
Expand Down Expand Up @@ -112,6 +125,7 @@ async def on_tool_start(
context: RunContextWrapper[TContext],
agent: TAgent,
tool: Tool,
data: Optional[dict[str, Any]] = None,
) -> None:
"""Called immediately before a local tool is invoked."""
pass
Expand All @@ -122,6 +136,7 @@ async def on_tool_end(
agent: TAgent,
tool: Tool,
result: str,
data: Optional[dict[str, Any]] = None,
) -> None:
"""Called immediately after a local tool is invoked."""
pass
Expand Down Expand Up @@ -151,3 +166,15 @@ async def on_llm_end(

AgentHooks = AgentHooksBase[TContext, Agent]
"""Agent hooks for `Agent`s."""


async def _call_hook_with_data(
hook_method: Callable[..., Any],
*args: Any,
data: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
sig = inspect.signature(hook_method)
if "data" in sig.parameters:
return await hook_method(*args, data=data, **kwargs)
return await hook_method(*args, **kwargs)
Loading
Loading