Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

268 changes: 266 additions & 2 deletions rust/crates/runtime/src/conversation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use serde_json::{Map, Value};
use telemetry::SessionTracer;
Expand Down Expand Up @@ -53,6 +57,41 @@ pub struct PromptCacheEvent {
pub token_drop: u32,
}

/// Shared flag used to request graceful interruption of a running turn.
///
/// Cloning shares the underlying flag, mirroring
/// [`HookAbortSignal`](crate::hooks::HookAbortSignal). An input listener
/// (e.g. Esc or Ctrl+C handling in the CLI) sets the flag while the
/// conversation loop and the streaming API client poll it at safe points.
/// When the flag is observed, the turn winds down without treating the
/// stop as a failure: pending tool calls receive synthesized error
/// results so the session stays consistent, and [`TurnSummary`] reports
/// `interrupted: true`.
#[derive(Debug, Clone, Default)]
pub struct TurnInterruptSignal {
interrupted: Arc<AtomicBool>,
}

impl TurnInterruptSignal {
#[must_use]
pub fn new() -> Self {
Self::default()
}

pub fn interrupt(&self) {
self.interrupted.store(true, Ordering::SeqCst);
}

#[must_use]
pub fn is_interrupted(&self) -> bool {
self.interrupted.load(Ordering::SeqCst)
}

pub fn reset(&self) {
self.interrupted.store(false, Ordering::SeqCst);
}
}

/// Minimal streaming API contract required by [`ConversationRuntime`].
pub trait ApiClient {
fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
Expand Down Expand Up @@ -118,6 +157,7 @@ pub struct TurnSummary {
pub iterations: usize,
pub usage: TokenUsage,
pub auto_compaction: Option<AutoCompactionEvent>,
pub interrupted: bool,
}

/// Details about automatic session compaction applied during a turn.
Expand All @@ -138,6 +178,7 @@ pub struct ConversationRuntime<C, T> {
hook_runner: HookRunner,
auto_compaction_input_tokens_threshold: u32,
hook_abort_signal: HookAbortSignal,
turn_interrupt_signal: TurnInterruptSignal,
hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
session_tracer: Option<SessionTracer>,
}
Expand Down Expand Up @@ -187,6 +228,7 @@ where
hook_runner: HookRunner::from_feature_config(feature_config),
auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
hook_abort_signal: HookAbortSignal::default(),
turn_interrupt_signal: TurnInterruptSignal::default(),
hook_progress_reporter: None,
session_tracer: None,
}
Expand Down Expand Up @@ -217,6 +259,15 @@ where
self
}

#[must_use]
pub fn with_turn_interrupt_signal(
mut self,
turn_interrupt_signal: TurnInterruptSignal,
) -> Self {
self.turn_interrupt_signal = turn_interrupt_signal;
self
}

#[must_use]
pub fn with_hook_progress_reporter(
mut self,
Expand Down Expand Up @@ -350,8 +401,14 @@ where
let mut prompt_cache_events = Vec::new();
let mut iterations = 0;
let mut auto_compaction = None;
let mut interrupted = false;

loop {
if self.turn_interrupt_signal.is_interrupted() {
self.record_turn_interrupted(iterations, "before_request");
interrupted = true;
break;
}
iterations += 1;
if iterations > self.max_iterations {
let error = RuntimeError::new(
Expand All @@ -368,6 +425,14 @@ where
let events = match self.api_client.stream(request) {
Ok(events) => events,
Err(error) => {
if self.turn_interrupt_signal.is_interrupted() {
// The client aborted because the user interrupted the
// turn; any partial response is discarded and the stop
// is reported as an interruption rather than a failure.
self.record_turn_interrupted(iterations, "during_request");
interrupted = true;
break;
}
self.record_turn_failed(iterations, &error);
return Err(error);
}
Expand Down Expand Up @@ -416,6 +481,25 @@ where
}

for (tool_use_id, tool_name, input) in pending_tool_uses {
if interrupted || self.turn_interrupt_signal.is_interrupted() {
// Every pending tool_use must still receive a tool_result
// so the session stays valid for the next request.
if !interrupted {
self.record_turn_interrupted(iterations, "before_tool");
interrupted = true;
}
let result_message = ConversationMessage::tool_result(
tool_use_id,
tool_name,
"Interrupted by user before this tool could run.",
true,
);
self.session
.push_message(result_message.clone())
.map_err(|error| RuntimeError::new(error.to_string()))?;
tool_results.push(result_message);
continue;
}
let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
let effective_input = pre_hook_result
.updated_input()
Expand Down Expand Up @@ -515,6 +599,10 @@ where
self.record_tool_finished(iterations, &result_message);
tool_results.push(result_message);
}

if interrupted {
break;
}
}

let summary = TurnSummary {
Expand All @@ -524,8 +612,11 @@ where
iterations,
usage: self.usage_tracker.cumulative_usage(),
auto_compaction,
interrupted,
};
self.record_turn_completed(&summary);
if !interrupted {
self.record_turn_completed(&summary);
}

Ok(summary)
}
Expand Down Expand Up @@ -689,6 +780,17 @@ where
session_tracer.record("turn_completed", attributes);
}

fn record_turn_interrupted(&self, iteration: usize, phase: &str) {
let Some(session_tracer) = &self.session_tracer else {
return;
};

let mut attributes = Map::new();
attributes.insert("iteration".to_string(), Value::from(iteration as u64));
attributes.insert("phase".to_string(), Value::String(phase.to_string()));
session_tracer.record("turn_interrupted", attributes);
}

fn record_turn_failed(&self, iteration: usize, error: &RuntimeError) {
let Some(session_tracer) = &self.session_tracer else {
return;
Expand Down Expand Up @@ -850,7 +952,8 @@ mod tests {
use super::{
build_assistant_message, parse_auto_compaction_threshold, ApiClient, ApiRequest,
AssistantEvent, AutoCompactionEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
StaticToolExecutor, ToolExecutor, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
StaticToolExecutor, ToolExecutor, TurnInterruptSignal,
DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
};
use crate::compact::CompactionConfig;
use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
Expand Down Expand Up @@ -1875,4 +1978,165 @@ mod tests {
// then
assert_eq!(error.to_string(), "upstream failed");
}

#[test]
fn interrupt_before_first_request_skips_the_api_call() {
struct UnreachableApi;

impl ApiClient for UnreachableApi {
fn stream(
&mut self,
_request: ApiRequest,
) -> Result<Vec<AssistantEvent>, RuntimeError> {
unreachable!("interrupted turn must not reach the API")
}
}

// given
let interrupt = TurnInterruptSignal::new();
interrupt.interrupt();
let mut runtime = ConversationRuntime::new(
Session::new(),
UnreachableApi,
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
)
.with_turn_interrupt_signal(interrupt);

// when
let summary = runtime
.run_turn("hello", None)
.expect("interruption should not be reported as a failure");

// then
assert!(summary.interrupted);
assert_eq!(summary.iterations, 0);
assert!(summary.assistant_messages.is_empty());
assert!(summary.tool_results.is_empty());
assert_eq!(summary.auto_compaction, None);
assert_eq!(runtime.session().messages.len(), 1);
assert_eq!(runtime.session().messages[0].role, MessageRole::User);
}

#[test]
fn interrupt_after_stream_synthesizes_results_for_pending_tools() {
struct ToolUseApi {
interrupt: TurnInterruptSignal,
}

impl ApiClient for ToolUseApi {
fn stream(
&mut self,
_request: ApiRequest,
) -> Result<Vec<AssistantEvent>, RuntimeError> {
// Simulate the user pressing Esc while the response streams in.
self.interrupt.interrupt();
Ok(vec![
AssistantEvent::TextDelta("Running the tool.".to_string()),
AssistantEvent::ToolUse {
id: "tool-1".to_string(),
name: "add".to_string(),
input: "2,2".to_string(),
},
AssistantEvent::MessageStop,
])
}
}

// given
let interrupt = TurnInterruptSignal::new();
let mut runtime = ConversationRuntime::new(
Session::new(),
ToolUseApi {
interrupt: interrupt.clone(),
},
StaticToolExecutor::new()
.register("add", |_input| panic!("interrupted tool must not run")),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
)
.with_turn_interrupt_signal(interrupt);

// when
let summary = runtime
.run_turn("what is 2 + 2?", None)
.expect("interruption should not be reported as a failure");

// then
assert!(summary.interrupted);
assert_eq!(summary.iterations, 1);
assert_eq!(summary.assistant_messages.len(), 1);
assert_eq!(summary.tool_results.len(), 1);
assert!(matches!(
&summary.tool_results[0].blocks[0],
ContentBlock::ToolResult {
tool_use_id,
is_error: true,
output,
..
} if tool_use_id == "tool-1" && output.contains("Interrupted by user")
));
// user text, assistant tool_use, synthesized tool_result
assert_eq!(runtime.session().messages.len(), 3);
assert!(matches!(
runtime.session().messages[2].blocks[0],
ContentBlock::ToolResult { is_error: true, .. }
));
}

#[test]
fn stream_error_during_interrupt_is_reported_as_interruption() {
struct AbortedApi {
interrupt: TurnInterruptSignal,
}

impl ApiClient for AbortedApi {
fn stream(
&mut self,
_request: ApiRequest,
) -> Result<Vec<AssistantEvent>, RuntimeError> {
// Simulate the streaming client aborting the connection after
// observing the interrupt flag mid-stream.
self.interrupt.interrupt();
Err(RuntimeError::new("request aborted"))
}
}

// given
let sink = Arc::new(MemoryTelemetrySink::default());
let tracer = SessionTracer::new("session-interrupt", sink.clone());
let interrupt = TurnInterruptSignal::new();
let mut runtime = ConversationRuntime::new(
Session::new(),
AbortedApi {
interrupt: interrupt.clone(),
},
StaticToolExecutor::new(),
PermissionPolicy::new(PermissionMode::DangerFullAccess),
vec!["system".to_string()],
)
.with_turn_interrupt_signal(interrupt)
.with_session_tracer(tracer);

// when
let summary = runtime
.run_turn("hello", None)
.expect("interrupt-driven aborts should not surface as errors");

// then
assert!(summary.interrupted);
assert!(summary.assistant_messages.is_empty());
let trace_names = sink
.events()
.iter()
.filter_map(|event| match event {
TelemetryEvent::SessionTrace(trace) => Some(trace.name.clone()),
_ => None,
})
.collect::<Vec<_>>();
assert!(trace_names.iter().any(|name| name == "turn_interrupted"));
assert!(!trace_names.iter().any(|name| name == "turn_failed"));
assert!(!trace_names.iter().any(|name| name == "turn_completed"));
}
}
2 changes: 1 addition & 1 deletion rust/crates/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub use config_validate::{
pub use conversation::{
auto_compaction_threshold_from_env, ApiClient, ApiRequest, AssistantEvent, AutoCompactionEvent,
ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, ToolError,
ToolExecutor, TurnSummary,
ToolExecutor, TurnInterruptSignal, TurnSummary,
};
pub use file_ops::{
edit_file, edit_file_in_workspace, glob_search, glob_search_in_workspace, grep_search,
Expand Down
Loading