diff --git a/CLAUDE.md b/CLAUDE.md index 1659d03..ea3ec78 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,6 +80,30 @@ Behind the `openapi` Cargo feature. `OpenApiToolAdapter` parses an OpenAPI 3.0 s `McpClient` communicates via `McpTransport` trait (stdio or HTTP). `McpToolAdapter` wraps MCP tools to implement `AgentTool`, making them transparent to the agent loop. Added via `Agent::with_mcp_server_stdio()` / `with_mcp_server_http()`. +### Shared State (`shared_state.rs`) + +`SharedState` is a pluggable key-value store (`Arc`) for sub-agent communication. It lets a parent store large artifacts once and have multiple sub-agents read/write by reference — no re-pasting into prompts. + +- Two built-in backends: `MemoryBackend` (default, `HashMap` with 10MB cap) and `FileBackend` (one file per key, persistent) +- Custom backends implement the `SharedStateBackend` trait +- Opt-in via `SubAgentTool::with_shared_state(state)` — injects a `shared_state` tool and appends a state summary to the sub-agent's system prompt automatically +- Actions: `get`, `set`, `list`, `remove` +- Does **not** touch the core agent loop — wired entirely through `SubAgentTool` + +### Sub-Agent Multi-Provider Support + +`SubAgentTool` supports any provider via `with_model_config()`. Without it, sub-agents default to Anthropic. For non-Anthropic providers (OpenAI, xAI, Groq, etc.), pass the appropriate `ModelConfig`: + +```rust +let config = ModelConfig::xai("grok-3-mini-fast", "Grok 3 Mini Fast"); +let sub = SubAgentTool::new("analyst", Arc::new(OpenAiCompatProvider)) + .with_model(&config.id) + .with_api_key(&key) + .with_model_config(config); +``` + +`AgentLoopConfig` also supports `turn_delay: Option` — an inter-turn delay to throttle API calls for rate-limit-sensitive providers. Exposed on `SubAgentTool` via `with_turn_delay()`. + ### Testing All unit tests use `MockProvider` (`provider/mock.rs`) to simulate LLM responses without network. 
Test files are in `tests/` — `agent_test.rs`, `agent_loop_test.rs`, `tools_test.rs`. Follow the existing pattern of constructing a `MockProvider` with predetermined responses. diff --git a/docs/concepts/agent-loop.md b/docs/concepts/agent-loop.md index 1e9b4ad..50d503f 100644 --- a/docs/concepts/agent-loop.md +++ b/docs/concepts/agent-loop.md @@ -90,6 +90,7 @@ pub struct AgentLoopConfig { pub on_error: Option, pub input_filters: Vec>, pub compaction_strategy: Option>, + pub turn_delay: Option, } ``` @@ -114,6 +115,7 @@ pub struct AgentLoopConfig { | `on_error` | Called on `StopReason::Error` with the error string (see [Callbacks](callbacks.md)) | | `input_filters` | Input filters applied to user messages before the LLM call (see [Tools](tools.md)) | | `compaction_strategy` | Custom compaction strategy (see [Custom Compaction](#custom-compaction) below) | +| `turn_delay` | Optional inter-turn delay to throttle API calls. Skips the first turn. Useful for rate-limit-sensitive providers (e.g., OAuth tokens with low RPM caps) | ## Steering & Follow-Ups diff --git a/docs/concepts/sub-agents.md b/docs/concepts/sub-agents.md index 802e57f..954dc0d 100644 --- a/docs/concepts/sub-agents.md +++ b/docs/concepts/sub-agents.md @@ -62,10 +62,14 @@ When the parent LLM calls multiple sub-agents in a single response, they run con | `with_description()` | What the parent LLM sees (helps it decide when to delegate) | | `with_system_prompt()` | The sub-agent's own instructions | | `with_model()` / `with_api_key()` | Can use a different model than the parent | +| `with_model_config()` | Set `ModelConfig` for non-Anthropic providers (base URL, compat flags, etc.) | | `with_tools()` | Tools available to the sub-agent (accepts `Vec>`) | | `with_max_turns(N)` | Turn limit (default: 10). Primary guard against runaway execution. 
| | `with_thinking()` | Enable extended thinking for the sub-agent | | `with_cache_config()` | Prompt caching settings | +| `with_turn_delay()` | Inter-turn delay to throttle API calls (useful for rate-limit-sensitive providers) | +| `with_retry_config()` | Custom retry configuration for transient errors | +| `with_tool_execution()` | Tool execution strategy (`Parallel`, `Sequential`, `Batched`) | ## Event Forwarding @@ -74,13 +78,126 @@ When the parent provides an `on_update` callback (standard for all tools), sub-a - Text deltas from the sub-agent's LLM responses - Tool call notifications from the sub-agent's tool usage +## Shared State + +By default, each sub-agent invocation is isolated — to pass data between sub-agents, the parent must re-paste it into every prompt. For large artifacts (CI logs, codebases, analysis results), this wastes context tokens. + +`SharedState` solves this: store an artifact once, and any number of sub-agents read/write it by reference. + +```rust +use yoagent::shared_state::SharedState; + +let state = SharedState::new(); +state.set("ci_log", large_log_text).await.unwrap(); + +let analyzer = SubAgentTool::new("analyzer", provider.clone()) + .with_system_prompt("Analyze the CI log for failures.") + .with_model("claude-sonnet-4-20250514") + .with_api_key(&api_key) + .with_shared_state(state.clone()); // opt-in +``` + +When `.with_shared_state()` is used, the sub-agent automatically gets: + +1. A `shared_state` tool with `get`, `set`, `list`, and `remove` actions +2. 
A system prompt appendix listing available keys and their sizes + +The sub-agent reads the artifact via tool call instead of having it pasted into the prompt: + +``` +Sub-agent calls: shared_state(action="get", key="ci_log") +Sub-agent calls: shared_state(action="set", key="summary", value="...") +``` + +The parent reads results back programmatically: + +```rust +let summary = state.get("summary").await.expect("sub-agent wrote this"); +``` + +### Parallel Sub-Agents with Shared State + +Multiple sub-agents can share the same `SharedState` concurrently. Each gets its own clone of the `Arc` handle — reads are concurrent, writes are serialized by `tokio::sync::RwLock`. + +```rust +let error_analyst = SubAgentTool::new("error_analyst", provider.clone()) + .with_shared_state(state.clone()); +let perf_analyst = SubAgentTool::new("perf_analyst", provider.clone()) + .with_shared_state(state.clone()); + +// Both run in parallel, reading the same artifact and writing different keys +``` + +### Backends + +`SharedState` is backed by a pluggable `SharedStateBackend` trait. Two built-in backends are provided: + +**MemoryBackend** (default) — in-memory `HashMap` with a byte capacity limit: + +```rust +let state = SharedState::new(); // 10MB default +let state = SharedState::with_max_bytes(50 * 1024 * 1024); // 50MB +``` + +A `set` call that would exceed capacity returns `Err(CapacityError)`. + +**FileBackend** — one file per key, persistent across process restarts: + +```rust +use yoagent::shared_state::FileBackend; + +let state = SharedState::with_backend(FileBackend::new(".agent-state")); +``` + +Keys are percent-encoded to filenames (reversible, no collisions). Useful for debugging (inspect state with `ls` / `cat`) and for long-running workflows where memory limits matter. 
+ +**Custom backends** implement the `SharedStateBackend` trait: + +```rust +use yoagent::shared_state::{SharedStateBackend, SharedStateError}; + +#[async_trait::async_trait] +impl SharedStateBackend for MyRedisBackend { + async fn get(&self, key: &str) -> Result<Option<String>, SharedStateError> { ... } + async fn set(&self, key: &str, value: String) -> Result<(), SharedStateError> { ... } + async fn remove(&self, key: &str) -> Result<bool, SharedStateError> { ... } + async fn keys(&self) -> Result<Vec<String>, SharedStateError> { ... } + async fn summary(&self) -> Result<String, SharedStateError> { ... } +} + +let state = SharedState::with_backend(MyRedisBackend::new()); +``` + +See [`examples/shared_state.rs`](../../examples/shared_state.rs) for a complete parallel analysis demo. + +## Multi-Provider Support + +Sub-agents can use any provider supported by yoagent — not just Anthropic. Pass a `ModelConfig` to configure the base URL, compat flags, and other provider-specific settings: + +```rust +use yoagent::provider::{OpenAiCompatProvider, model::ModelConfig}; + +let provider = Arc::new(OpenAiCompatProvider); +let model_config = ModelConfig::xai("grok-3-mini-fast", "Grok 3 Mini Fast"); + +let analyst = SubAgentTool::new("analyst", provider) + .with_model(&model_config.id) + .with_api_key(&xai_api_key) + .with_model_config(model_config) + .with_tools(vec![...]); +``` + +This works with all providers: OpenAI, Groq, DeepSeek, Gemini, Mistral, xAI, and more. See [`ModelConfig`](../reference/configuration.md) for the full list of factory methods. + ## Design Decisions - **Context isolation**: Each invocation starts fresh. Sub-agents don't accumulate history across calls. -- **No nesting**: Sub-agents are not given other `SubAgentTool`s. This prevents infinite delegation chains. +- **Nesting supported**: Sub-agents can be given other `SubAgentTool`s for recursive delegation (see [`examples/rlm.rs`](../../examples/rlm.rs)). Use `with_max_turns()` to prevent infinite chains. 
- **Cancellation propagation**: The parent's cancellation token is forwarded. Aborting the parent aborts all sub-agents. - **Turn limiting**: The default 10-turn limit prevents runaway execution. The parent's execution limits also apply to total wall-clock time. -## Example +## Examples -See [`examples/sub_agent.rs`](../../examples/sub_agent.rs) for a complete coordinator with researcher and coder sub-agents. +- [`examples/sub_agent.rs`](../../examples/sub_agent.rs) — Coordinator with researcher and coder sub-agents +- [`examples/code_review.rs`](../../examples/code_review.rs) — 3 parallel sub-agents reviewing a file via shared state +- [`examples/rlm.rs`](../../examples/rlm.rs) — Recursive Language Model: nested sub-agents with autonomous file discovery diff --git a/docs/reference/api.md b/docs/reference/api.md index 38685c8..81170cb 100644 --- a/docs/reference/api.md +++ b/docs/reference/api.md @@ -155,6 +155,83 @@ All return `Self` for chaining (unless noted as `Result`). | `abort()` | Cancel the current run via `CancellationToken` | | `async reset()` | Cancel any pending loop, recover tools, clear all state (messages, queues, streaming flag) | +## SubAgentTool + +Delegates tasks to a child agent loop. + +### Construction + +```rust +let sub = SubAgentTool::new("name", Arc::new(provider)); +``` + +### Builder Methods + +All return `Self` for chaining. + +| Method | Description | +|--------|-------------| +| `with_description(desc) -> Self` | What the parent LLM sees (helps it decide when to delegate) | +| `with_system_prompt(prompt) -> Self` | The sub-agent's own instructions | +| `with_model(model) -> Self` | Set the model identifier | +| `with_api_key(key) -> Self` | Set the API key | +| `with_model_config(config: ModelConfig) -> Self` | Set model config for non-Anthropic providers (base URL, compat flags, etc.) 
| +| `with_tools(tools: Vec>) -> Self` | Tools available to the sub-agent | +| `with_shared_state(state: SharedState) -> Self` | Attach a shared key-value store (injects `shared_state` tool automatically) | +| `with_max_turns(N) -> Self` | Turn limit (default: 10) | +| `with_thinking(level: ThinkingLevel) -> Self` | Enable extended thinking | +| `with_max_tokens(max: u32) -> Self` | Set max output tokens | +| `with_cache_config(config: CacheConfig) -> Self` | Prompt caching settings | +| `with_tool_execution(strategy: ToolExecutionStrategy) -> Self` | Tool execution strategy (`Parallel`, `Sequential`, `Batched`) | +| `with_retry_config(config: RetryConfig) -> Self` | Custom retry configuration | +| `with_turn_delay(delay: Duration) -> Self` | Inter-turn delay to throttle API calls (skips first turn) | + +## SharedState + +Pluggable key-value store for sub-agent communication. Backed by a `SharedStateBackend` trait. + +### Construction + +```rust +use yoagent::shared_state::{SharedState, FileBackend}; + +let state = SharedState::new(); // MemoryBackend, 10MB cap +let state = SharedState::with_max_bytes(50 * 1024 * 1024); // MemoryBackend, 50MB cap +let state = SharedState::with_backend(FileBackend::new("./state-dir")); // FileBackend +``` + +### Methods + +| Method | Description | +|--------|-------------| +| `async get(key) -> Option` | Read a value by key | +| `async set(key, value) -> Result<(), SharedStateError>` | Store a value | +| `async remove(key) -> bool` | Delete a key, returns whether it existed | +| `async keys() -> Vec` | List all keys | +| `async summary() -> String` | Human-readable summary of keys and sizes | + +### Built-in Backends + +| Backend | Description | +|---------|-------------| +| `MemoryBackend` | In-memory `HashMap` with byte capacity limit (default) | +| `FileBackend` | One file per key, percent-encoded filenames, persistent | + +### Custom Backends + +Implement the `SharedStateBackend` trait: + +```rust +#[async_trait::async_trait] 
+pub trait SharedStateBackend: Send + Sync { + async fn get(&self, key: &str) -> Result<Option<String>, SharedStateError>; + async fn set(&self, key: &str, value: String) -> Result<(), SharedStateError>; + async fn remove(&self, key: &str) -> Result<bool, SharedStateError>; + async fn keys(&self) -> Result<Vec<String>, SharedStateError>; + async fn summary(&self) -> Result<String, SharedStateError>; +} +``` + ## Re-exports The crate re-exports key types from `lib.rs`: diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 6b9ee50..9ac664e 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -27,7 +27,8 @@ pub struct AgentLoopConfig { pub after_turn: Option, pub on_error: Option, pub input_filters: Vec>, pub compaction_strategy: Option>, + pub turn_delay: Option<Duration>, } ``` diff --git a/docs/reference/tools.md b/docs/reference/tools.md index c89cbdb..9d8483b 100644 --- a/docs/reference/tools.md +++ b/docs/reference/tools.md @@ -108,3 +108,19 @@ pub struct SearchTool { ``` Returns matching lines with file paths and line numbers. + +## SharedStateTool + +Read and write named variables in a shared key-value store. This tool is **not** included in `default_tools()` — it is automatically injected into sub-agents when you call `SubAgentTool::with_shared_state()`. + +- **Name**: `shared_state` +- **Parameters**: `action` (required: `get`, `set`, `list`, `remove`), `key` (required for get/set/remove), `value` (required for set) + +| Action | Description | +|--------|-------------| +| `get` | Returns the value for a key, or error if not found | +| `set` | Stores a value, returns confirmation with byte size | +| `list` | Lists all keys with their byte sizes | +| `remove` | Deletes a key | + +See [Sub-Agents: Shared State](../concepts/sub-agents.md#shared-state) for usage details. diff --git a/examples/code_review.rs b/examples/code_review.rs new file mode 100644 index 0000000..6ba7c61 --- /dev/null +++ b/examples/code_review.rs @@ -0,0 +1,189 @@ +//! 
Code review agent: parallel sub-agents reviewing the same source file. +//! +//! Demonstrates: +//! - Reading a real file and storing it in SharedState +//! - 3 parallel sub-agents each analyzing a different aspect +//! - Sub-agents writing structured findings back to shared state +//! - Parent aggregating all reviews into a unified report +//! +//! Run on any source file: +//! ANTHROPIC_API_KEY=sk-... cargo run --example code_review -- path/to/file.rs +//! +//! Try it on this repo: +//! ANTHROPIC_API_KEY=sk-... cargo run --example code_review -- src/shared_state.rs + +use std::sync::{Arc, Mutex}; +use yoagent::provider::{AnthropicProvider, StreamProvider}; +use yoagent::shared_state::SharedState; +use yoagent::sub_agent::SubAgentTool; +use yoagent::*; + +#[tokio::main] +async fn main() { + let api_key = std::env::var("ANTHROPIC_API_KEY").expect("Set ANTHROPIC_API_KEY"); + let model = "claude-sonnet-4-20250514"; + let provider: Arc = Arc::new(AnthropicProvider); + + // --- Read the target file from CLI args --- + let file_path = std::env::args().nth(1).unwrap_or_else(|| { + eprintln!("Usage: cargo run --example code_review -- "); + std::process::exit(1); + }); + + let source_code = std::fs::read_to_string(&file_path).unwrap_or_else(|e| { + eprintln!("Failed to read '{}': {}", file_path, e); + std::process::exit(1); + }); + + println!("Reviewing: {} ({} bytes)\n", file_path, source_code.len()); + + // --- Store the source code once in shared state --- + let state = SharedState::new(); + state + .set("source_code", source_code) + .await + .expect("store source code"); + state + .set("file_path", file_path.clone()) + .await + .expect("store file path"); + + // --- Three reviewers, each focused on a different aspect --- + + let bug_reviewer = SubAgentTool::new("bug_reviewer", Arc::clone(&provider)) + .with_description("Reviews code for bugs and logic errors") + .with_system_prompt( + "You are a bug-finding specialist. 
Read the source code from shared state \ + (key: 'source_code'), look for bugs, logic errors, off-by-one errors, \ + race conditions, and potential panics. Write your findings to shared state \ + under key 'bugs_review'. Format: bullet points, each with line reference \ + and severity (critical/warning/info). If no bugs found, say so.", + ) + .with_model(model) + .with_api_key(&api_key) + .with_shared_state(state.clone()) + .with_max_turns(5); + + let quality_reviewer = SubAgentTool::new("quality_reviewer", Arc::clone(&provider)) + .with_description("Reviews code quality and style") + .with_system_prompt( + "You are a code quality reviewer. Read the source code from shared state \ + (key: 'source_code'), evaluate naming, structure, idiomatic usage, \ + error handling patterns, and API design. Write your findings to shared state \ + under key 'quality_review'. Format: bullet points with specific suggestions. \ + Mention what's done well too.", + ) + .with_model(model) + .with_api_key(&api_key) + .with_shared_state(state.clone()) + .with_max_turns(5); + + let docs_reviewer = SubAgentTool::new("docs_reviewer", Arc::clone(&provider)) + .with_description("Reviews documentation completeness") + .with_system_prompt( + "You are a documentation reviewer. Read the source code from shared state \ + (key: 'source_code') and the file path (key: 'file_path'). Evaluate: \ + are public items documented? Are doc comments accurate? Are edge cases \ + explained? Are examples provided where helpful? Write findings to shared \ + state under key 'docs_review'. 
Format: bullet points.", + ) + .with_model(model) + .with_api_key(&api_key) + .with_shared_state(state.clone()) + .with_max_turns(5); + + // --- Run all three in parallel with streaming --- + println!("Dispatching 3 reviewers in parallel...\n"); + + let make_ctx = |label: &str| -> (ToolContext, Arc>) { + let label = label.to_string(); + let buf: Arc> = Arc::new(Mutex::new(String::new())); + let ctx = ToolContext { + tool_call_id: format!("tc-{}", label), + tool_name: label.clone(), + cancel: tokio_util::sync::CancellationToken::new(), + on_update: Some(Arc::new({ + let label = label.clone(); + let buf = buf.clone(); + move |result: ToolResult| { + for content in &result.content { + if let Content::Text { text } = content { + let mut b = buf.lock().unwrap(); + // Tool call events: flush buffer first, print on own line + if text.starts_with("[sub-agent calling tool:") { + if !b.is_empty() { + eprintln!("[{}] {}", label, b.drain(..).collect::()); + } + eprintln!("[{}] {}", label, text); + continue; + } + b.push_str(text); + while let Some(pos) = b.find('\n') { + let line: String = b.drain(..=pos).collect(); + eprint!("[{}] {}", label, line); + } + } + } + } + })), + on_progress: None, + }; + (ctx, buf) + }; + + let (ctx1, buf1) = make_ctx("bugs"); + let (ctx2, buf2) = make_ctx("quality"); + let (ctx3, buf3) = make_ctx("docs"); + + let (r1, r2, r3) = tokio::join!( + bug_reviewer.execute( + serde_json::json!({"task": "Review the source code for bugs and logic errors."}), + ctx1, + ), + quality_reviewer.execute( + serde_json::json!({"task": "Review the source code for quality and style."}), + ctx2, + ), + docs_reviewer.execute( + serde_json::json!({"task": "Review the source code for documentation completeness."}), + ctx3, + ), + ); + + // Flush any remaining buffered text + for (label, buf) in [("bugs", buf1), ("quality", buf2), ("docs", buf3)] { + let b = buf.lock().unwrap(); + if !b.is_empty() { + eprintln!("[{}] {}", label, *b); + } + } + r1.expect("bug reviewer 
failed"); + r2.expect("quality reviewer failed"); + r3.expect("docs reviewer failed"); + + // --- Print unified review --- + println!("═══════════════════════════════════════════════════════════"); + println!(" Code Review: {}", file_path); + println!("═══════════════════════════════════════════════════════════\n"); + + let sections = [ + ("bugs_review", "Bug Analysis"), + ("quality_review", "Code Quality"), + ("docs_review", "Documentation"), + ]; + + for (key, title) in sections { + println!("── {} ──\n", title); + match state.get(key).await { + Some(value) => println!("{}\n", value), + None => println!("(reviewer did not produce findings)\n"), + } + } + + println!("═══════════════════════════════════════════════════════════"); + println!( + " Review complete. Shared state keys: {}", + state.summary().await + ); + println!("═══════════════════════════════════════════════════════════"); +} diff --git a/examples/rlm.rs b/examples/rlm.rs new file mode 100644 index 0000000..326ebe0 --- /dev/null +++ b/examples/rlm.rs @@ -0,0 +1,197 @@ +//! Recursive Language Model (RLM) example. +//! +//! Demonstrates true RLM: an LLM that autonomously explores a codebase, +//! recursively delegating file-level analysis to sub-agents. All agents +//! communicate through shared state. +//! +//! Parent (Rust) → lead_analyst (LLM, discovers + delegates) +//! → file_analyst (LLM, reads + analyzes) +//! +//! The lead_analyst uses file system tools to explore, then spawns +//! file_analyst sub-agents for deep analysis. No hardcoded file lists. +//! +//! Run (analyzes current directory): +//! XAI_API_KEY=xai-... cargo run --example rlm +//! +//! Run on a specific directory: +//! XAI_API_KEY=xai-... 
cargo run --example rlm -- path/to/dir + +use std::sync::{Arc, Mutex}; +use yoagent::provider::model::ModelConfig; +use yoagent::provider::{OpenAiCompatProvider, StreamProvider}; +#[allow(unused_imports)] +use yoagent::shared_state::{FileBackend, SharedState}; +use yoagent::sub_agent::SubAgentTool; +use yoagent::tools; +use yoagent::*; + +#[tokio::main] +async fn main() { + let api_key = std::env::var("XAI_API_KEY").expect("Set XAI_API_KEY"); + let mut model_config = ModelConfig::xai("grok-4-1-fast-reasoning", "Grok 4.1 Fast Reasoning"); + model_config.reasoning = true; + let provider: Arc = Arc::new(OpenAiCompatProvider); + + let target_dir = std::env::args().nth(1).unwrap_or_else(|| ".".into()); + + println!("RLM Codebase Analyzer (Grok)"); + println!("Target: {}\n", target_dir); + + let state = SharedState::new(); + + // Or use filesystem backend — each key becomes a file for persistence + // and easy inspection of agent outputs: + // let state = SharedState::with_backend(FileBackend::new(".rlm-state")); + + // Store the target directory so agents know where to look + state + .set("target_dir", target_dir.clone()) + .await + .expect("store target dir"); + + println!("--- RLM: 2-level recursive agent delegation ---"); + println!("Parent → lead_analyst (explores) → file_analyst (analyzes)\n"); + + // --- Level 2: file_analyst (leaf agent) --- + // Has read_file + shared_state. Reads a file, writes summary to shared state. + let file_analyst = SubAgentTool::new("file_analyst", Arc::clone(&provider)) + .with_description( + "Analyzes a single source file in depth. \ + Call with a task specifying the file path to analyze.", + ) + .with_system_prompt( + "You are a file-level code analyst. When given a file to analyze:\n\ + 1. Use read_file to read the file content\n\ + 2. Analyze it: purpose, key types/functions, design patterns, quality\n\ + 3. Write a concise summary (under 200 words) to shared state with \ + key 'summary:'\n\n\ + Be specific and technical. 
Focus on what makes this code interesting.", + ) + .with_model(&model_config.id) + .with_api_key(&api_key) + .with_model_config(model_config.clone()) + .with_shared_state(state.clone()) + .with_tools(vec![Arc::new(tools::ReadFileTool::new())]) + .with_max_turns(5); + + // --- Level 1: lead_analyst (orchestrator agent) --- + // Has list_files + read_file to explore, file_analyst to delegate, shared_state for results. + let lead_analyst = SubAgentTool::new("lead_analyst", Arc::clone(&provider)) + .with_description("Orchestrates codebase analysis") + .with_system_prompt( + "You are a lead code analyst orchestrating a codebase review.\n\n\ + IMPORTANT: Only analyze files within the target directory. Do NOT explore \ + parent directories or other parts of the project.\n\n\ + Steps:\n\ + 1. Read 'target_dir' from shared state\n\ + 2. Use list_files to discover source files ONLY within that directory\n\ + 3. Pick the 2 most important files\n\ + 4. For EACH chosen file, delegate to the 'file_analyst' tool: \ + 'Analyze '\n\ + 5. After all files are analyzed, read each summary from shared state \ + (keys are 'summary:')\n\ + 6. Write a final synthesis report to shared state under key 'final_report'\n\n\ + The final report should identify cross-cutting themes, architectural patterns, \ + and how the files relate to each other. 
Keep it under 300 words.", + ) + .with_model(&model_config.id) + .with_api_key(&api_key) + .with_model_config(model_config) + .with_shared_state(state.clone()) + .with_tools(vec![ + Arc::new(tools::ListFilesTool::new()), + Arc::new(tools::ReadFileTool::new()), + Arc::new(file_analyst), + ]) + .with_max_turns(25); + + // --- Parent: single call triggers the full recursive chain --- + let buf: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new())); + let ctx = ToolContext { + tool_call_id: "tc-rlm".into(), + tool_name: "lead_analyst".into(), + cancel: tokio_util::sync::CancellationToken::new(), + on_update: Some(Arc::new({ + let buf = buf.clone(); + move |result: ToolResult| { + for content in &result.content { + if let Content::Text { text } = content { + let mut b = buf.lock().unwrap(); + if text.starts_with("[sub-agent calling tool:") { + if !b.is_empty() { + eprintln!("[lead] {}", b.drain(..).collect::<String>()); + } + eprintln!("[lead] {}", text); + continue; + } + b.push_str(text); + while let Some(pos) = b.find('\n') { + let line: String = b.drain(..=pos).collect(); + eprint!("[lead] {}", line); + } + } + } + } + })), + on_progress: None, + }; + + let result = lead_analyst + .execute( + serde_json::json!({"task": "Explore and analyze this Rust crate."}), + ctx, + ) + .await; + + // Flush remaining buffer + { + let b = buf.lock().unwrap(); + if !b.is_empty() { + eprintln!("[lead] {}", *b); + } + } + + match result { + Ok(result) => { + eprintln!("[lead] details: {}", result.details); + for content in &result.content { + if let Content::Text { text } = content { + if !text.is_empty() { + eprintln!("[lead] (final) {}", text); + } + } + } + } + Err(e) => { + eprintln!("\nError: {}", e); + std::process::exit(1); + } + } + + // --- Read results from shared state --- + println!("\n═══════════════════════════════════════════════════════════"); + println!(" RLM Results"); + println!("═══════════════════════════════════════════════════════════\n"); + + // Print all per-file summaries + let 
keys = state.keys().await; + for key in &keys { + if let Some(file) = key.strip_prefix("summary:") { + println!("── {} ──\n", file); + if let Some(summary) = state.get(key).await { + println!("{}\n", summary); + } + } + } + + // Final report (written by lead_analyst, level 1) + println!("── Final Synthesis ──\n"); + match state.get("final_report").await { + Some(report) => println!("{}", report), + None => println!("(lead_analyst did not produce a final report)"), + } + + println!("\n═══════════════════════════════════════════════════════════"); + println!(" Shared state: {}", state.summary().await); + println!("═══════════════════════════════════════════════════════════"); +} diff --git a/examples/shared_state.rs b/examples/shared_state.rs new file mode 100644 index 0000000..62e48ba --- /dev/null +++ b/examples/shared_state.rs @@ -0,0 +1,147 @@ +//! Shared state example: parallel sub-agents analyzing the same artifact. +//! +//! Demonstrates: +//! - Storing a large artifact once in SharedState +//! - Multiple sub-agents reading it by reference (not re-pasted) +//! - Sub-agents writing findings back to shared state +//! - Parent reading all findings after completion +//! +//! The "aha moment": the CI log is 50KB but each sub-agent's prompt is +//! just one sentence. They all read the same artifact via shared_state +//! tool — no context wasted on re-pasting. +//! +//! Run: +//! ANTHROPIC_API_KEY=sk-... 
cargo run --example shared_state + +use std::sync::Arc; +use yoagent::provider::{AnthropicProvider, StreamProvider}; +use yoagent::shared_state::SharedState; +use yoagent::sub_agent::SubAgentTool; +use yoagent::*; + +#[tokio::main] +async fn main() { + let api_key = std::env::var("ANTHROPIC_API_KEY").expect("Set ANTHROPIC_API_KEY"); + let model = "claude-sonnet-4-20250514"; + let provider: Arc = Arc::new(AnthropicProvider); + + // --- The artifact: a large CI log (simulated) --- + let ci_log = r#" +[2026-04-27T10:00:01Z] Starting CI pipeline for commit abc123 +[2026-04-27T10:00:02Z] Step 1/5: cargo fmt -- --check ... OK +[2026-04-27T10:00:15Z] Step 2/5: cargo clippy ... OK (42 warnings suppressed) +[2026-04-27T10:00:30Z] Step 3/5: cargo build ... OK (debug, 45s) +[2026-04-27T10:01:15Z] Step 4/5: cargo test ... +[2026-04-27T10:01:16Z] test_auth_basic ... ok (12ms) +[2026-04-27T10:01:16Z] test_auth_refresh ... ok (8ms) +[2026-04-27T10:01:17Z] test_db_connection ... FAILED (timeout after 30000ms) +[2026-04-27T10:01:47Z] thread 'test_db_connection' panicked at 'connection timed out: TcpStream::connect' +[2026-04-27T10:01:47Z] note: database host db-ci.internal:5432 unreachable +[2026-04-27T10:01:48Z] test_api_list_users ... ok (145ms) +[2026-04-27T10:01:48Z] test_api_create_user ... ok (89ms) +[2026-04-27T10:01:49Z] test_api_delete_user ... FAILED +[2026-04-27T10:01:49Z] assertion failed: `(left == right)` left: 404, right: 204 +[2026-04-27T10:01:49Z] at tests/api_test.rs:142 +[2026-04-27T10:01:50Z] test_cache_invalidation ... ok (3ms) +[2026-04-27T10:01:50Z] test_cache_ttl ... ok (1002ms) [SLOW] +[2026-04-27T10:01:51Z] test_cache_concurrent ... ok (2105ms) [SLOW] +[2026-04-27T10:01:53Z] test_migration_up ... ok (340ms) +[2026-04-27T10:01:54Z] test_migration_down ... ok (290ms) +[2026-04-27T10:01:54Z] test_migration_idempotent ... ok (680ms) [SLOW] +[2026-04-27T10:01:55Z] test_flaky_network_retry ... 
FAILED +[2026-04-27T10:01:55Z] thread 'test_flaky_network_retry' panicked at 'retry count exceeded' +[2026-04-27T10:01:55Z] note: this test is known-flaky, see issue #187 +[2026-04-27T10:01:55Z] test result: 3 failed; 11 passed; 0 ignored +[2026-04-27T10:01:55Z] Step 5/5: skipped (tests failed) +[2026-04-27T10:01:55Z] Pipeline FAILED in 114s +"#; + + // --- Store the artifact once in shared state --- + let state = SharedState::new(); + state + .set("ci_log", ci_log.to_string()) + .await + .expect("store CI log"); + + println!("Stored CI log ({} bytes) in shared state.\n", ci_log.len()); + + // --- Three sub-agents, each analyzing a different aspect --- + + let error_analyst = SubAgentTool::new("error_analyst", Arc::clone(&provider)) + .with_description("Analyzes test failures in CI logs") + .with_system_prompt( + "You analyze CI logs for test failures. Read the log from shared state, \ + identify each failure, its root cause, and write a concise summary back \ + to shared state under 'errors_summary'. Be brief — bullet points only.", + ) + .with_model(model) + .with_api_key(&api_key) + .with_shared_state(state.clone()) + .with_max_turns(5); + + let perf_analyst = SubAgentTool::new("perf_analyst", Arc::clone(&provider)) + .with_description("Analyzes performance issues in CI logs") + .with_system_prompt( + "You analyze CI logs for performance issues. Read the log from shared state, \ + identify slow tests and bottlenecks, and write a concise summary back \ + to shared state under 'perf_summary'. Be brief — bullet points only.", + ) + .with_model(model) + .with_api_key(&api_key) + .with_shared_state(state.clone()) + .with_max_turns(5); + + let flaky_analyst = SubAgentTool::new("flaky_analyst", Arc::clone(&provider)) + .with_description("Identifies flaky tests in CI logs") + .with_system_prompt( + "You analyze CI logs for flaky/unreliable tests. 
Read the log from shared state, \ + identify tests that are flaky or infrastructure-dependent, and write a concise \ + summary back to shared state under 'flaky_summary'. Be brief — bullet points only.", + ) + .with_model(model) + .with_api_key(&api_key) + .with_shared_state(state.clone()) + .with_max_turns(5); + + // --- Run all three in parallel --- + println!("Dispatching 3 sub-agents in parallel...\n"); + + let ctx = |name: &str| ToolContext { + tool_call_id: format!("tc-{}", name), + tool_name: name.to_string(), + cancel: tokio_util::sync::CancellationToken::new(), + on_update: None, + on_progress: None, + }; + + let (r1, r2, r3) = tokio::join!( + error_analyst.execute( + serde_json::json!({"task": "Analyze the CI log for test failures."}), + ctx("error_analyst"), + ), + perf_analyst.execute( + serde_json::json!({"task": "Analyze the CI log for performance issues."}), + ctx("perf_analyst"), + ), + flaky_analyst.execute( + serde_json::json!({"task": "Analyze the CI log for flaky tests."}), + ctx("flaky_analyst"), + ), + ); + + r1.expect("error analyst failed"); + r2.expect("perf analyst failed"); + r3.expect("flaky analyst failed"); + + // --- Read all findings from shared state --- + println!("=== All sub-agents complete. 
Reading findings from shared state: ===\n"); + + for key in ["errors_summary", "perf_summary", "flaky_summary"] { + match state.get(key).await { + Some(value) => println!("--- {} ---\n{}\n", key, value), + None => println!("--- {} ---\n(sub-agent did not write this key)\n", key), + } + } + + println!("=== Shared state keys: {} ===", state.summary().await); +} diff --git a/src/agent.rs b/src/agent.rs index 846db0b..003d172 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -676,6 +676,7 @@ impl Agent { after_turn: self.after_turn.clone(), on_error: self.on_error.clone(), input_filters: self.input_filters.clone(), + turn_delay: None, } } } diff --git a/src/agent_loop.rs b/src/agent_loop.rs index aa84b1b..0e30ac8 100644 --- a/src/agent_loop.rs +++ b/src/agent_loop.rs @@ -86,6 +86,11 @@ pub struct AgentLoopConfig { /// Filters run in order; first `Reject` wins and discards any accumulated /// warnings. `Warn` messages accumulate and are appended to the user message. pub input_filters: Vec>, + + /// Optional delay between turns. Useful for rate-limit-sensitive scenarios + /// (e.g., OAuth tokens with low request-per-minute caps). Skipped on the + /// first turn so the agent starts immediately. + pub turn_delay: Option, } /// Default convert_to_llm: keep only user/assistant/toolResult messages. @@ -326,6 +331,15 @@ async fn run_loop( return; } } + + // Inter-turn delay — throttle API calls to stay under rate limits. + // Skipped on the first turn so the agent starts immediately. 
+ if turn_number > 0 { + if let Some(delay) = config.turn_delay { + tokio::time::sleep(delay).await; + } + } + turn_number += 1; // Compact context if configured (tiered: tool outputs → summarize → drop) diff --git a/src/lib.rs b/src/lib.rs index 8087216..1a6cb6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod context; pub mod mcp; pub mod provider; pub mod retry; +pub mod shared_state; pub mod skills; pub mod sub_agent; pub mod tools; @@ -16,6 +17,7 @@ pub use agent::Agent; pub use agent_loop::{agent_loop, agent_loop_continue}; pub use context::{CompactionStrategy, DefaultCompaction}; pub use retry::RetryConfig; +pub use shared_state::SharedState; pub use skills::SkillSet; pub use sub_agent::SubAgentTool; pub use types::*; diff --git a/src/provider/traits.rs b/src/provider/traits.rs index ae8c110..33d829a 100644 --- a/src/provider/traits.rs +++ b/src/provider/traits.rs @@ -141,7 +141,7 @@ pub async fn classify_eventsource_error(error: reqwest_eventsource::Error) -> Pr ), ) } - reqwest_eventsource::Error::Transport(e) => ProviderError::Network(e.to_string()), + reqwest_eventsource::Error::Transport(e) => ProviderError::Network(format!("{:?}", e)), other => ProviderError::Other(other.to_string()), } } diff --git a/src/shared_state.rs b/src/shared_state.rs new file mode 100644 index 0000000..5eb2738 --- /dev/null +++ b/src/shared_state.rs @@ -0,0 +1,605 @@ +//! Shared key-value state for sub-agent communication. +//! +//! `SharedState` is a pluggable key-value store that multiple sub-agents (and +//! the parent) can read/write. The default backend is in-memory; a filesystem +//! backend is also available for persistence and large artifacts. +//! +//! # Example +//! +//! ```rust,no_run +//! use yoagent::shared_state::SharedState; +//! +//! # async fn example() { +//! let state = SharedState::new(); +//! state.set("log", "big CI output...".into()).await.unwrap(); +//! +//! assert_eq!(state.get("log").await, Some("big CI output...".into())); +//! 
assert_eq!(state.keys().await, vec!["log"]); +//! # } +//! ``` + +use std::collections::HashMap; +use std::fmt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Default capacity for the memory backend: 10 MB. +const DEFAULT_MAX_BYTES: usize = 10 * 1024 * 1024; + +/// Error returned when a `set` would exceed the capacity limit. +#[derive(Debug, Clone)] +pub struct CapacityError { + pub key: String, + pub value_bytes: usize, + pub current_bytes: usize, + pub max_bytes: usize, +} + +impl fmt::Display for CapacityError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "SharedState capacity exceeded: storing '{}' ({} bytes) would bring total to {} / {} bytes", + self.key, self.value_bytes, self.current_bytes + self.value_bytes, self.max_bytes + ) + } +} + +impl std::error::Error for CapacityError {} + +/// Error type for shared state operations. +#[derive(Debug)] +pub enum SharedStateError { + Capacity(CapacityError), + Io(std::io::Error), +} + +impl fmt::Display for SharedStateError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Capacity(e) => write!(f, "{}", e), + Self::Io(e) => write!(f, "SharedState I/O error: {}", e), + } + } +} + +impl std::error::Error for SharedStateError {} + +impl From for SharedStateError { + fn from(e: CapacityError) -> Self { + Self::Capacity(e) + } +} + +impl From for SharedStateError { + fn from(e: std::io::Error) -> Self { + Self::Io(e) + } +} + +// --------------------------------------------------------------------------- +// Backend trait +// --------------------------------------------------------------------------- + +/// Pluggable storage backend for `SharedState`. +/// +/// Implement this trait to back shared state with a custom store +/// (database, Redis, HTTP service, etc.). +#[async_trait::async_trait] +pub trait SharedStateBackend: Send + Sync { + /// Get a value by key. Returns `None` if the key doesn't exist. 
+ async fn get(&self, key: &str) -> Result, SharedStateError>; + + /// Store a value. Implementations should enforce their own capacity limits. + async fn set(&self, key: &str, value: String) -> Result<(), SharedStateError>; + + /// Remove a key. Returns `true` if the key existed. + async fn remove(&self, key: &str) -> Result; + + /// List all keys (sorted). + async fn keys(&self) -> Result, SharedStateError>; + + /// Human-readable summary of stored variables (key names + sizes). + async fn summary(&self) -> Result; +} + +// --------------------------------------------------------------------------- +// Memory backend (default) +// --------------------------------------------------------------------------- + +/// In-memory backend backed by `HashMap` with a byte capacity limit. +pub struct MemoryBackend { + inner: RwLock>, + max_bytes: usize, +} + +impl Default for MemoryBackend { + fn default() -> Self { + Self::new() + } +} + +impl MemoryBackend { + pub fn new() -> Self { + Self { + inner: RwLock::new(HashMap::new()), + max_bytes: DEFAULT_MAX_BYTES, + } + } + + pub fn with_max_bytes(max_bytes: usize) -> Self { + Self { + inner: RwLock::new(HashMap::new()), + max_bytes, + } + } +} + +#[async_trait::async_trait] +impl SharedStateBackend for MemoryBackend { + async fn get(&self, key: &str) -> Result, SharedStateError> { + Ok(self.inner.read().await.get(key).cloned()) + } + + async fn set(&self, key: &str, value: String) -> Result<(), SharedStateError> { + let mut map = self.inner.write().await; + + // Calculate current total excluding the old value for this key. 
+ let current: usize = map + .iter() + .filter(|(k, _)| k.as_str() != key) + .map(|(k, v)| k.len() + v.len()) + .sum(); + let new_entry = key.len() + value.len(); + + if current + new_entry > self.max_bytes { + return Err(CapacityError { + key: key.to_string(), + value_bytes: value.len(), + current_bytes: current, + max_bytes: self.max_bytes, + } + .into()); + } + + map.insert(key.to_string(), value); + Ok(()) + } + + async fn remove(&self, key: &str) -> Result { + Ok(self.inner.write().await.remove(key).is_some()) + } + + async fn keys(&self) -> Result, SharedStateError> { + let map = self.inner.read().await; + let mut keys: Vec = map.keys().cloned().collect(); + keys.sort(); + Ok(keys) + } + + async fn summary(&self) -> Result { + let map = self.inner.read().await; + Ok(format_summary( + map.iter().map(|(k, v)| (k.as_str(), v.len())), + )) + } +} + +// --------------------------------------------------------------------------- +// Filesystem backend +// --------------------------------------------------------------------------- + +/// Filesystem backend — each key is stored as a file in a directory. +/// +/// Keys are sanitized to safe filenames. Values are stored as plain text +/// (no extension) for easy inspection and debugging. +/// +/// ```rust,no_run +/// use yoagent::shared_state::{SharedState, FileBackend}; +/// +/// # async fn example() { +/// let state = SharedState::with_backend(FileBackend::new("/tmp/agent-state")); +/// state.set("summary", "analysis results...".into()).await.unwrap(); +/// // Creates /tmp/agent-state/summary with the content +/// # } +/// ``` +pub struct FileBackend { + dir: PathBuf, +} + +impl FileBackend { + /// Create a new filesystem backend. The directory is created lazily on first write. + pub fn new(dir: impl AsRef) -> Self { + Self { + dir: dir.as_ref().to_path_buf(), + } + } + + /// Encode a key into a safe, reversible filename. + /// Percent-encodes any character that isn't alphanumeric, `-`, `_`, or `.`. 
+ /// This avoids collisions: distinct keys always produce distinct filenames. + fn key_to_path(&self, key: &str) -> PathBuf { + let encoded: String = key + .chars() + .map(|c| { + if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' { + c.to_string() + } else { + format!("%{:02X}", c as u32) + } + }) + .collect(); + self.dir.join(encoded) + } + + /// Decode a filename back into the original key. + fn path_to_key(filename: &str) -> String { + let mut result = String::new(); + let mut chars = filename.chars(); + while let Some(c) = chars.next() { + if c == '%' { + let hex: String = chars.by_ref().take(2).collect(); + if let Ok(code) = u32::from_str_radix(&hex, 16) { + if let Some(decoded) = char::from_u32(code) { + result.push(decoded); + continue; + } + } + // Fallback: keep the raw percent sequence + result.push('%'); + result.push_str(&hex); + } else { + result.push(c); + } + } + result + } +} + +#[async_trait::async_trait] +impl SharedStateBackend for FileBackend { + async fn get(&self, key: &str) -> Result, SharedStateError> { + let path = self.key_to_path(key); + match tokio::fs::read_to_string(&path).await { + Ok(content) => Ok(Some(content)), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(e.into()), + } + } + + async fn set(&self, key: &str, value: String) -> Result<(), SharedStateError> { + tokio::fs::create_dir_all(&self.dir).await?; + let path = self.key_to_path(key); + tokio::fs::write(&path, &value).await?; + Ok(()) + } + + async fn remove(&self, key: &str) -> Result { + let path = self.key_to_path(key); + match tokio::fs::remove_file(&path).await { + Ok(()) => Ok(true), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false), + Err(e) => Err(e.into()), + } + } + + async fn keys(&self) -> Result, SharedStateError> { + let mut keys = Vec::new(); + let mut entries = match tokio::fs::read_dir(&self.dir).await { + Ok(entries) => entries, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => return 
Ok(keys), + Err(e) => return Err(e.into()), + }; + while let Some(entry) = entries.next_entry().await? { + if let Some(name) = entry.file_name().to_str() { + // Skip hidden files + if !name.starts_with('.') { + keys.push(Self::path_to_key(name)); + } + } + } + keys.sort(); + Ok(keys) + } + + async fn summary(&self) -> Result { + let mut entries = Vec::new(); + let mut dir = match tokio::fs::read_dir(&self.dir).await { + Ok(dir) => dir, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok("(empty)".to_string()), + Err(e) => return Err(e.into()), + }; + while let Some(entry) = dir.next_entry().await? { + if let Some(name) = entry.file_name().to_str() { + if !name.starts_with('.') { + let meta = entry.metadata().await?; + entries.push((Self::path_to_key(name), meta.len() as usize)); + } + } + } + entries.sort_by(|a, b| a.0.cmp(&b.0)); + Ok(format_summary( + entries.iter().map(|(k, s)| (k.as_str(), *s)), + )) + } +} + +// --------------------------------------------------------------------------- +// SharedState (public API) +// --------------------------------------------------------------------------- + +/// A shared string key-value store for sub-agent communication. +/// +/// Cheaply cloneable (wraps `Arc`). Delegates all operations to a +/// pluggable [`SharedStateBackend`]. +#[derive(Clone)] +pub struct SharedState { + backend: Arc, +} + +impl SharedState { + /// Create a new in-memory store with the default 10 MB capacity. + pub fn new() -> Self { + Self { + backend: Arc::new(MemoryBackend::new()), + } + } + + /// Create a new in-memory store with a custom byte capacity. + pub fn with_max_bytes(max_bytes: usize) -> Self { + Self { + backend: Arc::new(MemoryBackend::with_max_bytes(max_bytes)), + } + } + + /// Create a store backed by a custom backend. + pub fn with_backend(backend: impl SharedStateBackend + 'static) -> Self { + Self { + backend: Arc::new(backend), + } + } + + /// Get a value by key. Returns `None` if the key doesn't exist. 
+ pub async fn get(&self, key: &str) -> Option { + match self.backend.get(key).await { + Ok(val) => val, + Err(e) => { + eprintln!("[SharedState] get({:?}) error: {}", key, e); + None + } + } + } + + /// Store a value. Returns `Err` if the backend rejects it (capacity, I/O, etc.). + pub async fn set(&self, key: &str, value: String) -> Result<(), SharedStateError> { + self.backend.set(key, value).await + } + + /// Remove a key. Returns `true` if the key existed. + pub async fn remove(&self, key: &str) -> bool { + match self.backend.remove(key).await { + Ok(existed) => existed, + Err(e) => { + eprintln!("[SharedState] remove({:?}) error: {}", key, e); + false + } + } + } + + /// List all keys (sorted). + pub async fn keys(&self) -> Vec { + match self.backend.keys().await { + Ok(keys) => keys, + Err(e) => { + eprintln!("[SharedState] keys() error: {}", e); + Vec::new() + } + } + } + + /// Human-readable summary of stored variables (key names + byte sizes). + /// Suitable for injecting into a system prompt. 
+ pub async fn summary(&self) -> String { + match self.backend.summary().await { + Ok(s) => s, + Err(e) => { + eprintln!("[SharedState] summary() error: {}", e); + "(error reading state)".to_string() + } + } + } +} + +impl Default for SharedState { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn format_summary<'a>(entries: impl Iterator) -> String { + let entries: Vec<_> = entries.collect(); + if entries.is_empty() { + return "(empty)".to_string(); + } + entries + .iter() + .map(|(k, size)| format_entry(k, *size)) + .collect::>() + .join(", ") +} + +fn format_entry(key: &str, bytes: usize) -> String { + if bytes >= 1024 * 1024 { + format!("{} ({:.1} MB)", key, bytes as f64 / (1024.0 * 1024.0)) + } else if bytes >= 1024 { + format!("{} ({:.1} KB)", key, bytes as f64 / 1024.0) + } else { + format!("{} ({} bytes)", key, bytes) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_get_set_remove() { + let state = SharedState::new(); + assert_eq!(state.get("x").await, None); + + state.set("x", "hello".into()).await.unwrap(); + assert_eq!(state.get("x").await, Some("hello".into())); + + assert!(state.remove("x").await); + assert_eq!(state.get("x").await, None); + assert!(!state.remove("x").await); + } + + #[tokio::test] + async fn test_keys_sorted() { + let state = SharedState::new(); + state.set("c", "3".into()).await.unwrap(); + state.set("a", "1".into()).await.unwrap(); + state.set("b", "2".into()).await.unwrap(); + assert_eq!(state.keys().await, vec!["a", "b", "c"]); + } + + #[tokio::test] + async fn test_overwrite_same_key() { + let state = SharedState::with_max_bytes(100); + state.set("k", "short".into()).await.unwrap(); + state.set("k", "also short".into()).await.unwrap(); + assert_eq!(state.get("k").await, Some("also short".into())); + } + + #[tokio::test] 
+ async fn test_capacity_limit() { + let state = SharedState::with_max_bytes(20); + state.set("a", "12345".into()).await.unwrap(); // 1 + 5 = 6 bytes + let err = state.set("b", "12345678901234567890".into()).await; + assert!(err.is_err()); + let e = err.unwrap_err(); + assert!(e.to_string().contains("capacity exceeded")); + } + + #[tokio::test] + async fn test_overwrite_within_capacity() { + let state = SharedState::with_max_bytes(30); + state.set("k", "aaaaaaaaaa".into()).await.unwrap(); // 1+10=11 + // Overwrite with larger value — old value excluded from budget + state.set("k", "bbbbbbbbbbbbbbbbbb".into()).await.unwrap(); // 1+18=19 + assert_eq!(state.get("k").await, Some("bbbbbbbbbbbbbbbbbb".into())); + } + + #[tokio::test] + async fn test_summary_formatting() { + let state = SharedState::new(); + assert_eq!(state.summary().await, "(empty)"); + + state.set("small", "hi".into()).await.unwrap(); + let s = state.summary().await; + assert!(s.contains("small")); + assert!(s.contains("bytes)")); + } + + #[tokio::test] + async fn test_concurrent_access() { + let state = SharedState::new(); + let mut handles = vec![]; + for i in 0..10 { + let s = state.clone(); + handles.push(tokio::spawn(async move { + s.set(&format!("k{}", i), format!("v{}", i)).await.unwrap(); + })); + } + for h in handles { + h.await.unwrap(); + } + assert_eq!(state.keys().await.len(), 10); + } + + #[tokio::test] + async fn test_file_backend() { + let dir = tempfile::tempdir().unwrap(); + let state = SharedState::with_backend(FileBackend::new(dir.path())); + + // Empty state + assert_eq!(state.get("x").await, None); + assert_eq!(state.keys().await, Vec::::new()); + assert_eq!(state.summary().await, "(empty)"); + + // Set and get + state.set("report", "analysis done".into()).await.unwrap(); + assert_eq!(state.get("report").await, Some("analysis done".into())); + + // File actually exists on disk + let content = std::fs::read_to_string(dir.path().join("report")).unwrap(); + assert_eq!(content, 
"analysis done"); + + // Keys + state.set("log", "build output".into()).await.unwrap(); + assert_eq!(state.keys().await, vec!["log", "report"]); + + // Summary + let summary = state.summary().await; + assert!(summary.contains("report")); + assert!(summary.contains("log")); + + // Remove + assert!(state.remove("report").await); + assert_eq!(state.get("report").await, None); + assert!(!state.remove("report").await); + } + + #[tokio::test] + async fn test_file_backend_key_encoding() { + let dir = tempfile::tempdir().unwrap(); + let state = SharedState::with_backend(FileBackend::new(dir.path())); + + // Keys with special chars are percent-encoded (reversible) + state + .set("summary:src/main.rs", "file analysis".into()) + .await + .unwrap(); + assert_eq!( + state.get("summary:src/main.rs").await, + Some("file analysis".into()) + ); + + // The file on disk uses percent-encoded name + let encoded = dir.path().join("summary%3Asrc%2Fmain.rs"); + assert!(encoded.exists()); + + // keys() returns the original key, not the filename + let keys = state.keys().await; + assert!(keys.contains(&"summary:src/main.rs".to_string())); + + // No collision: distinct keys produce distinct files + state + .set("summary_src_main.rs", "different".into()) + .await + .unwrap(); + assert_eq!( + state.get("summary:src/main.rs").await, + Some("file analysis".into()) + ); + assert_eq!( + state.get("summary_src_main.rs").await, + Some("different".into()) + ); + assert_eq!(state.keys().await.len(), 2); + } + + #[tokio::test] + async fn test_with_backend() { + // Verify with_backend works with MemoryBackend directly + let state = SharedState::with_backend(MemoryBackend::new()); + state.set("k", "v".into()).await.unwrap(); + assert_eq!(state.get("k").await, Some("v".into())); + } +} diff --git a/src/sub_agent.rs b/src/sub_agent.rs index 1a5f490..f3458f2 100644 --- a/src/sub_agent.rs +++ b/src/sub_agent.rs @@ -7,7 +7,7 @@ //! # Design //! //! 
- **Context isolation**: each invocation starts a fresh conversation -//! - **Depth limiting**: sub-agents are not given other SubAgentTools (static, no runtime counter) +//! - **Nesting supported**: sub-agents can contain other SubAgentTools for recursive delegation (use `with_max_turns()` to bound depth) //! - **Cancellation propagation**: the parent's cancel token is forwarded //! - **Event forwarding**: sub-agent events stream to the parent via `on_update` //! @@ -27,7 +27,10 @@ use crate::agent_loop::{agent_loop, AgentLoopConfig}; use crate::context::ExecutionLimits; +use crate::provider::model::ModelConfig; use crate::provider::StreamProvider; +use crate::shared_state::SharedState; +use crate::tools::shared_state_tool::SharedStateTool; use crate::types::*; use std::sync::Arc; use tokio::sync::mpsc; @@ -54,6 +57,9 @@ pub struct SubAgentTool { tool_execution: ToolExecutionStrategy, retry_config: crate::retry::RetryConfig, max_turns: usize, + shared_state: Option, + turn_delay: Option, + model_config: Option, } impl SubAgentTool { @@ -74,6 +80,9 @@ impl SubAgentTool { tool_execution: ToolExecutionStrategy::default(), retry_config: crate::retry::RetryConfig::default(), max_turns: DEFAULT_MAX_TURNS, + shared_state: None, + turn_delay: None, + model_config: None, } } @@ -131,6 +140,30 @@ impl SubAgentTool { self.max_turns = max; self } + + /// Attach a shared key-value store. Sub-agents get a `shared_state` tool + /// to read/write variables. The parent can also read/write programmatically + /// via the `SharedState` handle. + pub fn with_shared_state(mut self, state: SharedState) -> Self { + self.shared_state = Some(state); + self + } + + /// Add an inter-turn delay to throttle API requests. + /// Useful when using OAuth tokens or providers with low rate limits. + /// The delay is applied before each turn except the first. 
+ pub fn with_turn_delay(mut self, delay: std::time::Duration) -> Self { + self.turn_delay = Some(delay); + self + } + + /// Set the model configuration for multi-provider support. + /// Required for non-Anthropic providers (OpenAI-compat, Google, etc.) + /// to specify base URL, compat flags, and other provider-specific settings. + pub fn with_model_config(mut self, config: ModelConfig) -> Self { + self.model_config = Some(config); + self + } } /// Thin adapter: wraps `Arc` so it can be placed in a @@ -203,15 +236,26 @@ impl AgentTool for SubAgentTool { .to_string(); // Build tool list from Arc wrappers - let tools: Vec> = self + let mut tools: Vec> = self .tools .iter() .map(|t| Box::new(ArcToolWrapper(Arc::clone(t))) as Box) .collect(); + // Inject SharedStateTool when shared state is configured + let mut system_prompt = self.system_prompt.clone(); + if let Some(ref state) = self.shared_state { + tools.push(Box::new(SharedStateTool::new(state.clone()))); + let summary = state.summary().await; + system_prompt.push_str(&format!( + "\n\n## Shared State\nYou have access to a shared variable store via the `shared_state` tool.\nAvailable: {}", + summary + )); + } + // Fresh context for the sub-agent let mut context = AgentContext { - system_prompt: self.system_prompt.clone(), + system_prompt, messages: Vec::new(), tools, }; @@ -224,7 +268,7 @@ impl AgentTool for SubAgentTool { thinking_level: self.thinking_level, max_tokens: self.max_tokens, temperature: None, - model_config: None, + model_config: self.model_config.clone(), convert_to_llm: None, transform_context: None, get_steering_messages: None, @@ -244,6 +288,7 @@ impl AgentTool for SubAgentTool { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: self.turn_delay, }; // Channel for sub-agent events @@ -296,6 +341,14 @@ impl AgentTool for SubAgentTool { let _ = handle.await; } + // Check if the last message was an error + if let Some(error_msg) = extract_error(&new_messages) { + return 
Err(ToolError::Failed(format!( + "Sub-agent '{}' failed: {}", + self.tool_name, error_msg + ))); + } + // Extract final assistant text from the returned messages let result_text = extract_final_text(&new_messages); @@ -312,6 +365,27 @@ impl AgentTool for SubAgentTool { } } +/// Check if the last assistant message was an error, return the error message. +fn extract_error(messages: &[AgentMessage]) -> Option { + for msg in messages.iter().rev() { + if let AgentMessage::Llm(Message::Assistant { + stop_reason, + error_message, + .. + }) = msg + { + if *stop_reason == StopReason::Error { + return Some( + error_message + .clone() + .unwrap_or_else(|| "Unknown error".into()), + ); + } + } + } + None +} + /// Extract the final assistant text from agent messages. /// Collects text from the last assistant message, or returns a fallback. fn extract_final_text(messages: &[AgentMessage]) -> String { @@ -320,7 +394,7 @@ fn extract_final_text(messages: &[AgentMessage]) -> String { let texts: Vec<&str> = content .iter() .filter_map(|c| match c { - Content::Text { text } => Some(text.as_str()), + Content::Text { text } if !text.is_empty() => Some(text.as_str()), _ => None, }) .collect(); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 347d1da..ef804a3 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -3,12 +3,14 @@ pub mod edit; pub mod file; pub mod list; pub mod search; +pub mod shared_state_tool; pub use bash::BashTool; pub use edit::EditFileTool; pub use file::{ReadFileTool, WriteFileTool}; pub use list::ListFilesTool; pub use search::SearchTool; +pub use shared_state_tool::SharedStateTool; use crate::types::AgentTool; diff --git a/src/tools/shared_state_tool.rs b/src/tools/shared_state_tool.rs new file mode 100644 index 0000000..fbb2012 --- /dev/null +++ b/src/tools/shared_state_tool.rs @@ -0,0 +1,252 @@ +//! Tool that exposes `SharedState` to sub-agents. +//! +//! Injected automatically by `SubAgentTool` when `.with_shared_state()` is used. +//! 
Provides get/set/list/remove actions against the shared key-value store. + +use crate::shared_state::SharedState; +use crate::types::*; + +/// A tool that lets an LLM read/write a [`SharedState`] store. +pub struct SharedStateTool { + state: SharedState, +} + +impl SharedStateTool { + pub fn new(state: SharedState) -> Self { + Self { state } + } +} + +#[async_trait::async_trait] +impl AgentTool for SharedStateTool { + fn name(&self) -> &str { + "shared_state" + } + + fn label(&self) -> &str { + "Shared State" + } + + fn description(&self) -> &str { + "Read and write named variables in a shared store. Variables persist across tool calls within this session." + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["get", "set", "list", "remove"], + "description": "Action to perform" + }, + "key": { + "type": "string", + "description": "Variable name (required for get/set/remove)" + }, + "value": { + "type": "string", + "description": "Value to store (required for set)" + } + }, + "required": ["action"] + }) + } + + async fn execute( + &self, + params: serde_json::Value, + _ctx: ToolContext, + ) -> Result { + let action = params + .get("action") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::InvalidArgs("Missing required 'action' parameter".into()))?; + + match action { + "get" => { + let key = require_key(¶ms)?; + match self.state.get(&key).await { + Some(value) => Ok(ToolResult { + content: vec![Content::Text { text: value }], + details: serde_json::json!({"action": "get", "key": key}), + }), + None => Err(ToolError::Failed(format!( + "Key '{}' not found in shared state", + key + ))), + } + } + "set" => { + let key = require_key(¶ms)?; + let value = params + .get("value") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + ToolError::InvalidArgs("Missing required 'value' parameter for set".into()) + })? 
+ .to_string(); + + let bytes = value.len(); + self.state + .set(&key, value) + .await + .map_err(|e| ToolError::Failed(e.to_string()))?; + + Ok(ToolResult { + content: vec![Content::Text { + text: format!("Stored '{}' ({} bytes)", key, bytes), + }], + details: serde_json::json!({"action": "set", "key": key, "bytes": bytes}), + }) + } + "list" => { + let summary = self.state.summary().await; + Ok(ToolResult { + content: vec![Content::Text { text: summary }], + details: serde_json::json!({"action": "list"}), + }) + } + "remove" => { + let key = require_key(¶ms)?; + let existed = self.state.remove(&key).await; + let text = if existed { + format!("Removed '{}'", key) + } else { + format!("Key '{}' not found", key) + }; + Ok(ToolResult { + content: vec![Content::Text { text }], + details: serde_json::json!({"action": "remove", "key": key, "existed": existed}), + }) + } + other => Err(ToolError::InvalidArgs(format!( + "Unknown action '{}'. Use get, set, list, or remove.", + other + ))), + } + } +} + +fn require_key(params: &serde_json::Value) -> Result { + params + .get("key") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .ok_or_else(|| ToolError::InvalidArgs("Missing required 'key' parameter".into())) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio_util::sync::CancellationToken; + + fn ctx() -> ToolContext { + ToolContext { + tool_call_id: "test".into(), + tool_name: "shared_state".into(), + cancel: CancellationToken::new(), + on_update: None, + on_progress: None, + } + } + + fn text_of(result: &ToolResult) -> &str { + match &result.content[0] { + Content::Text { text } => text, + _ => panic!("expected Text content"), + } + } + + #[tokio::test] + async fn test_set_and_get() { + let state = SharedState::new(); + let tool = SharedStateTool::new(state); + + let result = tool + .execute( + serde_json::json!({"action": "set", "key": "x", "value": "hello"}), + ctx(), + ) + .await + .unwrap(); + assert!(text_of(&result).contains("Stored")); + + let 
result = tool + .execute(serde_json::json!({"action": "get", "key": "x"}), ctx()) + .await + .unwrap(); + assert_eq!(text_of(&result), "hello"); + } + + #[tokio::test] + async fn test_get_missing_key() { + let tool = SharedStateTool::new(SharedState::new()); + let err = tool + .execute(serde_json::json!({"action": "get", "key": "nope"}), ctx()) + .await; + assert!(matches!(err, Err(ToolError::Failed(_)))); + } + + #[tokio::test] + async fn test_list() { + let state = SharedState::new(); + state.set("a", "1".into()).await.unwrap(); + let tool = SharedStateTool::new(state); + + let result = tool + .execute(serde_json::json!({"action": "list"}), ctx()) + .await + .unwrap(); + assert!(text_of(&result).contains("a")); + } + + #[tokio::test] + async fn test_remove() { + let state = SharedState::new(); + state.set("k", "v".into()).await.unwrap(); + let tool = SharedStateTool::new(state); + + let result = tool + .execute(serde_json::json!({"action": "remove", "key": "k"}), ctx()) + .await + .unwrap(); + assert!(text_of(&result).contains("Removed")); + + let result = tool + .execute(serde_json::json!({"action": "remove", "key": "k"}), ctx()) + .await + .unwrap(); + assert!(text_of(&result).contains("not found")); + } + + #[tokio::test] + async fn test_invalid_action() { + let tool = SharedStateTool::new(SharedState::new()); + let err = tool + .execute(serde_json::json!({"action": "explode"}), ctx()) + .await; + assert!(matches!(err, Err(ToolError::InvalidArgs(_)))); + } + + #[tokio::test] + async fn test_missing_params() { + let tool = SharedStateTool::new(SharedState::new()); + + // Missing action + let err = tool.execute(serde_json::json!({}), ctx()).await; + assert!(matches!(err, Err(ToolError::InvalidArgs(_)))); + + // Missing key for get + let err = tool + .execute(serde_json::json!({"action": "get"}), ctx()) + .await; + assert!(matches!(err, Err(ToolError::InvalidArgs(_)))); + + // Missing value for set + let err = tool + .execute(serde_json::json!({"action": "set", 
"key": "k"}), ctx()) + .await; + assert!(matches!(err, Err(ToolError::InvalidArgs(_)))); + } +} diff --git a/tests/agent_loop_test.rs b/tests/agent_loop_test.rs index d3bf06d..82ae9b3 100644 --- a/tests/agent_loop_test.rs +++ b/tests/agent_loop_test.rs @@ -30,6 +30,7 @@ fn make_config(provider: MockProvider) -> AgentLoopConfig { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, } } @@ -761,6 +762,7 @@ async fn test_retry_on_rate_limit_succeeds() { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, }; let mut context = AgentContext { @@ -828,6 +830,7 @@ async fn test_retry_exhausted_returns_error() { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, }; let mut context = AgentContext { @@ -897,6 +900,7 @@ async fn test_no_retry_on_auth_error() { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, }; let mut context = AgentContext { @@ -954,6 +958,7 @@ async fn test_retry_none_disables_retries() { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, }; let mut context = AgentContext { @@ -1189,6 +1194,7 @@ async fn test_on_error_fires_on_provider_error() { error_msgs_clone.lock().unwrap().push(err.to_string()); })), input_filters: vec![], + turn_delay: None, }; let mut context = AgentContext { @@ -1892,6 +1898,7 @@ async fn test_custom_compaction_strategy_is_called() { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, }; let prompt = AgentMessage::Llm(Message::user("Hello")); @@ -1967,6 +1974,7 @@ async fn test_none_compaction_strategy_uses_default() { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, }; let prompt = AgentMessage::Llm(Message::user("Hello")); diff --git a/tests/integration_anthropic.rs b/tests/integration_anthropic.rs index bf02803..ddd12fc 100644 --- a/tests/integration_anthropic.rs +++ b/tests/integration_anthropic.rs @@ -37,6 +37,7 @@ fn make_config(provider: 
AnthropicProvider) -> AgentLoopConfig { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, } } diff --git a/tests/integration_gemini.rs b/tests/integration_gemini.rs index cf77ad9..514677e 100644 --- a/tests/integration_gemini.rs +++ b/tests/integration_gemini.rs @@ -39,6 +39,7 @@ fn make_config(model: &str) -> AgentLoopConfig { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, } } diff --git a/tests/shared_state_test.rs b/tests/shared_state_test.rs new file mode 100644 index 0000000..3392e0f --- /dev/null +++ b/tests/shared_state_test.rs @@ -0,0 +1,261 @@ +//! Tests for SharedState and its integration with SubAgentTool. + +use std::sync::Arc; +use tokio_util::sync::CancellationToken; +use yoagent::provider::mock::*; +use yoagent::provider::MockProvider; +use yoagent::shared_state::SharedState; +use yoagent::sub_agent::SubAgentTool; +use yoagent::*; + +// --------------------------------------------------------------------------- +// Integration: parent stores a value, sub-agent reads it via shared_state tool +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_sub_agent_reads_shared_state() { + let state = SharedState::new(); + state + .set("artifact", "LINE1: build failed\nLINE2: exit code 1".into()) + .await + .unwrap(); + + // Sub-agent mock: first call issues shared_state get, second returns text + let sub_provider = Arc::new(MockProvider::new(vec![ + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + arguments: serde_json::json!({"action": "get", "key": "artifact"}), + }]), + MockResponse::Text("The build failed with exit code 1".into()), + ])); + + let sub_agent = SubAgentTool::new("analyzer", sub_provider) + .with_description("Analyzes artifacts") + .with_system_prompt("Analyze the artifact.") + .with_model("mock") + .with_api_key("test") + .with_shared_state(state.clone()); + + let result = 
sub_agent + .execute( + serde_json::json!({"task": "What happened in the build?"}), + ToolContext { + tool_call_id: "tc-1".into(), + tool_name: "analyzer".into(), + cancel: CancellationToken::new(), + on_update: None, + on_progress: None, + }, + ) + .await + .expect("sub-agent should succeed"); + + let text = match &result.content[0] { + Content::Text { text } => text.as_str(), + _ => panic!("Expected text content"), + }; + assert!(text.contains("build failed")); +} + +// --------------------------------------------------------------------------- +// Integration: sub-agent writes a value, parent reads it back +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_sub_agent_writes_shared_state() { + let state = SharedState::new(); + + // Sub-agent mock: sets a value then responds with text + let sub_provider = Arc::new(MockProvider::new(vec![ + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + arguments: serde_json::json!({ + "action": "set", + "key": "summary", + "value": "Root cause: OOM in test runner" + }), + }]), + MockResponse::Text("Done, wrote summary.".into()), + ])); + + let sub_agent = SubAgentTool::new("writer", sub_provider) + .with_description("Writes summaries") + .with_system_prompt("Summarize findings.") + .with_model("mock") + .with_api_key("test") + .with_shared_state(state.clone()); + + sub_agent + .execute( + serde_json::json!({"task": "Summarize"}), + ToolContext { + tool_call_id: "tc-1".into(), + tool_name: "writer".into(), + cancel: CancellationToken::new(), + on_update: None, + on_progress: None, + }, + ) + .await + .expect("sub-agent should succeed"); + + // Parent reads back the value the sub-agent stored + let summary = state.get("summary").await.expect("summary should exist"); + assert_eq!(summary, "Root cause: OOM in test runner"); +} + +// --------------------------------------------------------------------------- +// 
Integration: two parallel sub-agents share state +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_parallel_sub_agents_share_state() { + let state = SharedState::new(); + state.set("input", "shared data".into()).await.unwrap(); + + // Agent A reads then writes result_a + let provider_a = Arc::new(MockProvider::new(vec![ + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + arguments: serde_json::json!({"action": "get", "key": "input"}), + }]), + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + arguments: serde_json::json!({"action": "set", "key": "result_a", "value": "from A"}), + }]), + MockResponse::Text("A done".into()), + ])); + + // Agent B reads then writes result_b + let provider_b = Arc::new(MockProvider::new(vec![ + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + arguments: serde_json::json!({"action": "get", "key": "input"}), + }]), + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + arguments: serde_json::json!({"action": "set", "key": "result_b", "value": "from B"}), + }]), + MockResponse::Text("B done".into()), + ])); + + let agent_a = SubAgentTool::new("agent_a", provider_a) + .with_system_prompt("You are agent A.") + .with_model("mock") + .with_api_key("test") + .with_shared_state(state.clone()); + + let agent_b = SubAgentTool::new("agent_b", provider_b) + .with_system_prompt("You are agent B.") + .with_model("mock") + .with_api_key("test") + .with_shared_state(state.clone()); + + let ctx = || ToolContext { + tool_call_id: "tc".into(), + tool_name: "test".into(), + cancel: CancellationToken::new(), + on_update: None, + on_progress: None, + }; + + // Run in parallel + let (ra, rb) = tokio::join!( + agent_a.execute(serde_json::json!({"task": "process"}), ctx()), + 
agent_b.execute(serde_json::json!({"task": "process"}), ctx()), + ); + ra.unwrap(); + rb.unwrap(); + + assert_eq!(state.get("result_a").await, Some("from A".into())); + assert_eq!(state.get("result_b").await, Some("from B".into())); + assert_eq!(state.get("input").await, Some("shared data".into())); +} + +// --------------------------------------------------------------------------- +// SubAgentTool without shared_state works as before +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_sub_agent_without_shared_state_unchanged() { + let sub_provider = Arc::new(MockProvider::text("hello")); + + let sub_agent = SubAgentTool::new("plain", sub_provider) + .with_system_prompt("You are plain.") + .with_model("mock") + .with_api_key("test"); + // No .with_shared_state() — existing behavior + + let result = sub_agent + .execute( + serde_json::json!({"task": "say hi"}), + ToolContext { + tool_call_id: "tc-1".into(), + tool_name: "plain".into(), + cancel: CancellationToken::new(), + on_update: None, + on_progress: None, + }, + ) + .await + .expect("should work without shared state"); + + let text = match &result.content[0] { + Content::Text { text } => text.as_str(), + _ => panic!("Expected text"), + }; + assert_eq!(text, "hello"); +} + +// --------------------------------------------------------------------------- +// System prompt includes shared state summary +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_shared_state_summary_in_system_prompt() { + let state = SharedState::new(); + state.set("log", "x".repeat(2048)).await.unwrap(); + + // We can't inspect the system prompt directly from outside, but we can + // verify the sub-agent gets the shared_state tool by having it call list + let sub_provider = Arc::new(MockProvider::new(vec![ + MockResponse::ToolCalls(vec![MockToolCall { + name: "shared_state".into(), + provider_metadata: None, + 
arguments: serde_json::json!({"action": "list"}), + }]), + MockResponse::Text("Listed state".into()), + ])); + + let sub_agent = SubAgentTool::new("lister", sub_provider) + .with_system_prompt("List state.") + .with_model("mock") + .with_api_key("test") + .with_shared_state(state); + + let result = sub_agent + .execute( + serde_json::json!({"task": "list"}), + ToolContext { + tool_call_id: "tc-1".into(), + tool_name: "lister".into(), + cancel: CancellationToken::new(), + on_update: None, + on_progress: None, + }, + ) + .await + .unwrap(); + + let text = match &result.content[0] { + Content::Text { text } => text.as_str(), + _ => panic!("Expected text"), + }; + assert_eq!(text, "Listed state"); +} diff --git a/tests/sub_agent_test.rs b/tests/sub_agent_test.rs index f084b2d..89abc8c 100644 --- a/tests/sub_agent_test.rs +++ b/tests/sub_agent_test.rs @@ -32,6 +32,7 @@ fn make_config(provider: MockProvider) -> AgentLoopConfig { after_turn: None, on_error: None, input_filters: vec![], + turn_delay: None, } }