diff --git a/clarifai/runners/models/agentic_class.py b/clarifai/runners/models/agentic_class.py index abc347a1..96e99ba9 100644 --- a/clarifai/runners/models/agentic_class.py +++ b/clarifai/runners/models/agentic_class.py @@ -441,37 +441,97 @@ def get_pool(cls) -> MCPConnectionPool: # === Token Tracking === - def _init_tokens(self): - if not hasattr(self._thread_local, 'tokens'): - self._thread_local.tokens = {'prompt': 0, 'completion': 0} + def _drain_tokens(self) -> dict: + """Drain and reset token accumulator. Returns accumulated values.""" + tokens = getattr(self._thread_local, 'tokens', {'prompt': 0, 'completion': 0}) + result = {'prompt': tokens['prompt'], 'completion': tokens['completion']} + self._thread_local.tokens = {'prompt': 0, 'completion': 0} + return result def _add_tokens(self, resp): """Accumulate tokens from response.""" usage = getattr(resp, 'usage', None) or ( getattr(resp.response, 'usage', None) if hasattr(resp, 'response') else None ) - if usage: - self._init_tokens() - self._thread_local.tokens['prompt'] += ( - getattr(usage, 'prompt_tokens', 0) or getattr(usage, 'input_tokens', 0) or 0 - ) - self._thread_local.tokens['completion'] += ( + if not usage: + return + if not hasattr(self._thread_local, 'tokens'): + self._thread_local.tokens = {'prompt': 0, 'completion': 0} + prompt_tokens = ( + getattr(usage, 'prompt_tokens', None) or getattr(usage, 'input_tokens', None) or 0 + ) + total_tokens = getattr(usage, 'total_tokens', None) + # Prefer total_tokens - prompt_tokens over completion_tokens when available. + # For reasoning models (e.g. o1, o3), completion_tokens excludes reasoning tokens, + # but total_tokens includes them, so the subtraction gives the true billable output count. + completion_tokens = ( + (total_tokens - prompt_tokens) + if total_tokens is not None + else ( getattr(usage, 'completion_tokens', 0) or getattr(usage, 'output_tokens', 0) or 0 ) + ) + self._thread_local.tokens['prompt'] += prompt_tokens + self._thread_local.tokens['completion'] += completion_tokens + # logger.info(f"Adding tokens - prompt: {prompt_tokens}, completion+reasoning: {completion_tokens}, total: {total_tokens}") + # logger.info( + # f"Accumulated tokens - prompt: {self._thread_local.tokens['prompt']}, completion+reasoning: {self._thread_local.tokens['completion']}" + # ) def _finalize_tokens(self): - """Send accumulated tokens to output context.""" - if hasattr(self._thread_local, 'tokens'): - t = self._thread_local.tokens - if t['prompt'] > 0 or t['completion'] > 0: - self.set_output_context( - prompt_tokens=t['prompt'], completion_tokens=t['completion'] - ) - del self._thread_local.tokens + """Report accumulated tokens to output context.""" + t = self._drain_tokens() + if t['prompt'] > 0 or t['completion'] > 0: + # logger.info(f"Finalizing tokens - prompt: {t['prompt']}, completion: {t['completion']}") + self.set_output_context(prompt_tokens=t['prompt'], completion_tokens=t['completion']) + + async def _clear_bg_tokens(self) -> dict: + """Drain token accumulator from background thread and return values. + + MCP streaming runs _set_usage() calls inside the pool's background event loop + (via _sync_to_async_iter), so tokens land on that thread's _thread_local storage. + This coroutine is submitted to the same loop via pool._run_async() so it executes + in that thread context and can correctly read and reset the accumulator. + """ + return self._drain_tokens() def _set_usage(self, resp): self._add_tokens(resp) + # === Helpers === + + def _extract_tool_content(self, result, error: Optional[str]) -> Optional[str]: + """Extract text content from an MCP tool result.""" + if error: + return f"Error: {error}" + if ( + hasattr(result, 'content') + and result.content + and hasattr(result.content[0], 'text') + and result.content[0].text + ): + return result.content[0].text + if ( + isinstance(result, (list, tuple)) + and result + and hasattr(result[0], 'text') + and result[0].text + ): + return result[0].text + return None + + def _normalize_input_items(self, input_data) -> List[dict]: + """Normalize input_data to a list of response API input items.""" + if isinstance(input_data, str): + return [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input_data}], + } + ] + return input_data if isinstance(input_data, list) else [] + # === Tool Format Conversion === def _to_response_api_tools(self, tools: List[dict]) -> List[dict]: @@ -594,24 +654,17 @@ def _execute_chat_tools( tool_to_server: Dict[str, str], ): """Execute chat completion tool calls and append results to messages.""" - pool = self.get_pool() parsed = self._parse_chat_tool_calls(tool_calls) - results = pool.call_tools_batch(parsed, connections, tool_to_server) - - for call_id, result, error in results: - if error: - content = f"Error: {error}" - elif ( - hasattr(result, 'content') - and len(result.content) > 0 - and hasattr(result.content[0], 'text') - ): - content = result.content[0].text - elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: - content = result[0].text - else: - content = None - messages.append({"role": "tool", "tool_call_id": call_id, "content": content}) + for call_id, result, error in self.get_pool().call_tools_batch( + parsed, connections, tool_to_server + ): + messages.append( + { + "role": "tool", + "tool_call_id": call_id, + "content": self._extract_tool_content(result, error), + } + ) async def _execute_chat_tools_async( self, @@ -621,25 +674,17 @@ async def _execute_chat_tools_async( tool_to_server: Dict[str, str], ): """Async version of chat tool execution.""" - pool = self.get_pool() parsed = self._parse_chat_tool_calls(tool_calls) - results = await pool.call_tools_batch_async(parsed, connections, tool_to_server) - - for call_id, result, error in results: - if error: - content = f"Error: {error}" - elif ( - hasattr(result, 'content') - and len(result.content) > 0 - and hasattr(result.content[0], 'text') - and result.content[0].text - ): - content = result.content[0].text - elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: - content = result[0].text - else: - content = None - messages.append({"role": "tool", "tool_call_id": call_id, "content": content}) + for call_id, result, error in await self.get_pool().call_tools_batch_async( + parsed, connections, tool_to_server + ): + messages.append( + { + "role": "tool", + "tool_call_id": call_id, + "content": self._extract_tool_content(result, error), + } + ) def _execute_response_tools( self, @@ -649,25 +694,15 @@ def _execute_response_tools( tool_to_server: Dict[str, str], ): """Execute response API tool calls and append results to input_items.""" - pool = self.get_pool() - results = pool.call_tools_batch(tool_calls, connections, tool_to_server) - - for call_id, result, error in results: - if error: - output = f"Error: {error}" - elif ( - hasattr(result, 'content') - and len(result.content) > 0 - and hasattr(result.content[0], 'text') - and result.content[0].text - ): - output = result.content[0].text - elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: - output = result[0].text - else: - output = None + for call_id, result, error in self.get_pool().call_tools_batch( + tool_calls, connections, tool_to_server + ): input_items.append( - {"type": "function_call_output", "call_id": call_id, "output": output} + { + "type": "function_call_output", + "call_id": call_id, + "output": self._extract_tool_content(result, error), + } ) async def _execute_response_tools_async( @@ -678,25 +713,15 @@ async def _execute_response_tools_async( tool_to_server: Dict[str, str], ): """Async version of response API tool execution.""" - pool = self.get_pool() - results = await pool.call_tools_batch_async(tool_calls, connections, tool_to_server) - - for call_id, result, error in results: - if error: - output = f"Error: {error}" - elif ( - hasattr(result, 'content') - and len(result.content) > 0 - and hasattr(result.content[0], 'text') - and result.content[0].text - ): - output = result.content[0].text - elif len(result) > 0 and hasattr(result[0], 'text') and result[0].text: - output = result[0].text - else: - output = None + for call_id, result, error in await self.get_pool().call_tools_batch_async( + tool_calls, connections, tool_to_server + ): input_items.append( - {"type": "function_call_output", "call_id": call_id, "output": output} + { + "type": "function_call_output", + "call_id": call_id, + "output": self._extract_tool_content(result, error), + } ) # === Response Output Processing === @@ -750,16 +775,23 @@ def _route_request( # === Streaming Helpers === def _accumulate_tool_delta(self, delta, accumulated: dict): - """Accumulate streaming tool call deltas.""" + """Accumulate streaming tool call deltas, preserving all attributes.""" idx = delta.index if idx not in accumulated: accumulated[idx] = { - "id": delta.id, + "id": delta.id or "", "type": "function", "function": {"name": "", "arguments": ""}, } - if delta.id: - accumulated[idx]["id"] = delta.id + + # Capture all fields generically (handles thought_signature and any future extras) + if hasattr(delta, 'model_dump'): + for key, val in delta.model_dump(exclude_none=True).items(): + if key in ("index", "function") or val is None: + continue + accumulated[idx][key] = val + + # Handle function fields manually: arguments must be concatenated across chunks if delta.function: if delta.function.name: accumulated[idx]["function"]["name"] = delta.function.name @@ -767,13 +799,32 @@ def _accumulate_tool_delta(self, delta, accumulated: dict): accumulated[idx]["function"]["arguments"] += delta.function.arguments def _finalize_tool_calls(self, accumulated: dict) -> List[dict]: - """Convert accumulated tool calls to list.""" - return [ - {"id": v["id"], "type": "function", "function": v["function"]} - for v in (accumulated[k] for k in sorted(accumulated)) - ] + """Convert accumulated tool calls to list, preserving all attributes.""" + return [dict(accumulated[k]) for k in sorted(accumulated)] + + def _stream_with_nulled_usage(self, chunks): + """Yield chunks with usage=None on all but the last usage-bearing chunk. - def _create_stream_request(self, messages, tools, max_tokens, temperature, top_p): + Some providers(Gemini) send a usage object on every chunk; keeping only the last one + avoids accumulating duplicated token counts across the multi-turn tool loop. + We buffer the most recent usage chunk and null out all earlier ones before yielding. + """ + self.buffered_usage_chunk = None + for chunk in chunks: + if getattr(chunk, "usage", None) is not None: + if self.buffered_usage_chunk is not None: + # Null out earlier usage so _set_usage only counts the final summary. + self.buffered_usage_chunk.usage = None + yield self.buffered_usage_chunk + self.buffered_usage_chunk = chunk + else: + yield chunk + if self.buffered_usage_chunk is not None: + yield self.buffered_usage_chunk + + def _create_stream_request( + self, messages, tools, max_tokens, temperature, top_p, reasoning_effort=None + ): """Create streaming chat completion request.""" kwargs = { "model": self.model, @@ -784,6 +835,8 @@ def _create_stream_request(self, messages, tools, max_tokens, temperature, top_p "stream": True, "stream_options": {"include_usage": True}, } + if reasoning_effort is not None: + kwargs["reasoning_effort"] = reasoning_effort if tools: kwargs["tools"] = tools kwargs["tool_choice"] = "auto" @@ -816,10 +869,41 @@ async def producer(): break yield item + async def _sync_to_async_iter(self, sync_iter): + """Wrap a sync iterator so each next() runs in a thread pool executor. + + This prevents blocking the event loop thread during sync I/O waits + (e.g. OpenAI streaming responses), allowing queue operations and other + coroutines to interleave properly between chunks. + """ + _sentinel = object() + loop = asyncio.get_running_loop() + it = iter(sync_iter) + + def get_next(): + try: + return next(it) + except StopIteration: + return _sentinel + + while True: + chunk = await loop.run_in_executor(None, get_next) + if chunk is _sentinel: + break + yield chunk + # === Streaming with MCP === async def _stream_chat_with_tools( - self, messages, tools, connections, tool_to_server, max_tokens, temperature, top_p + self, + messages, + tools, + connections, + tool_to_server, + max_tokens, + temperature, + top_p, + reasoning_effort=None, ): """ Stream chat completions with MCP tool support, recursively handling tool calls. @@ -835,13 +919,20 @@ async def _stream_chat_with_tools( max_tokens: Maximum number of tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling parameter. + reasoning_effort: Optional reasoning effort level passed to the model. Yields: JSON-serialized chat completion chunks. """ accumulated_tools = {} assistant_content = "" - for chunk in self._create_stream_request(messages, tools, max_tokens, temperature, top_p): + async for chunk in self._sync_to_async_iter( + self._stream_with_nulled_usage( + self._create_stream_request( + messages, tools, max_tokens, temperature, top_p, reasoning_effort + ) + ) + ): self._set_usage(chunk) yield chunk.model_dump_json() @@ -865,7 +956,14 @@ async def _stream_chat_with_tools( await self._execute_chat_tools_async(tool_calls, connections, messages, tool_to_server) async for chunk in self._stream_chat_with_tools( - messages, tools, connections, tool_to_server, max_tokens, temperature, top_p + messages, + tools, + connections, + tool_to_server, + max_tokens, + temperature, + top_p, + reasoning_effort, ): yield chunk @@ -890,18 +988,7 @@ async def _stream_responses_with_tools(self, request_data, tools, connections, t Yields: str: JSON-encoded response chunks, streamed as they become available. """ - input_data = request_data.get("input", "") - input_items = ( - [ - { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": input_data}], - } - ] - if isinstance(input_data, str) - else (input_data if isinstance(input_data, list) else []) - ) + input_items = self._normalize_input_items(request_data.get("input", "")) response_args = {**request_data, "model": self.model} if tools: @@ -912,7 +999,7 @@ async def _stream_responses_with_tools(self, request_data, tools, connections, t tool_calls_by_id = {} msg_index_map = {} - for chunk in self.client.responses.create(**response_args): + async for chunk in self._sync_to_async_iter(self.client.responses.create(**response_args)): self._set_usage(chunk) chunk_type = getattr(chunk, 'type', '') or chunk.__class__.__name__ @@ -1089,18 +1176,7 @@ def openai_transport(self, msg: str) -> str: elif endpoint == self.ENDPOINT_RESPONSES: response = self._route_request(endpoint, data, mcp_servers, connections, tools) - input_data = data.get("input", "") - input_items = ( - [ - { - "type": "message", - "role": "user", - "content": [{"type": "input_text", "text": input_data}], - } - ] - if isinstance(input_data, str) - else (input_data if isinstance(input_data, list) else []) - ) + input_items = self._normalize_input_items(data.get("input", "")) output = response.output if hasattr(response, 'output') else [] tool_calls = self._parse_response_tool_calls(output) @@ -1161,6 +1237,7 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]: data.get("max_completion_tokens", 4096), data.get("temperature", 1.0), data.get("top_p", 1.0), + data.get("reasoning_effort"), ) ) else: @@ -1170,17 +1247,34 @@ def openai_stream_transport(self, msg: str) -> Iterator[str]: ) ) - self._finalize_tokens() + # Token accumulation for MCP streaming happens inside the pool's background + # event loop thread (via _sync_to_async_iter), so _thread_local.tokens lives + # there. We must collect them via _clear_bg_tokens() submitted to that same loop + # rather than calling _finalize_tokens() from this thread, which would read an + # empty accumulator. + bg_tokens = pool._run_async(self._clear_bg_tokens()) + if bg_tokens['prompt'] > 0 or bg_tokens['completion'] > 0: + logger.info( + f"Background tokens used: prompt_tokens={bg_tokens['prompt']}, completion_tokens={bg_tokens['completion']}" + ) + self.set_output_context( + prompt_tokens=bg_tokens['prompt'], + completion_tokens=bg_tokens['completion'], + ) return - # Non-MCP streaming + # Non-MCP streaming — wrap with _stream_with_nulled_usage so only the final + # usage-bearing chunk is forwarded, preventing double-counting when _set_usage + # is called once per chunk. if endpoint == self.ENDPOINT_RESPONSES: - for chunk in self.client.responses.create(**{**data, "model": self.model}): + for chunk in self._stream_with_nulled_usage( + self.client.responses.create(**{**data, "model": self.model}) + ): self._set_usage(chunk) yield chunk.model_dump_json() else: - for chunk in self.client.chat.completions.create( - **self._create_completion_args(data) + for chunk in self._stream_with_nulled_usage( + self.client.chat.completions.create(**self._create_completion_args(data)) ): self._set_usage(chunk) yield chunk.model_dump_json() diff --git a/tests/runners/test_agentic_class.py b/tests/runners/test_agentic_class.py index ef378320..b109474b 100644 --- a/tests/runners/test_agentic_class.py +++ b/tests/runners/test_agentic_class.py @@ -51,18 +51,25 @@ def mock_pool(self): # === Token Tracking Tests === - def test_init_tokens(self, model): - """Test token initialization.""" - model._init_tokens() - assert hasattr(model._thread_local, 'tokens') + def test_drain_tokens(self, model): + """Test draining and resetting token accumulator.""" + model._thread_local.tokens = {'prompt': 10, 'completion': 20} + result = model._drain_tokens() + assert result == {'prompt': 10, 'completion': 20} assert model._thread_local.tokens == {'prompt': 0, 'completion': 0} + def test_drain_tokens_with_no_prior_state(self, model): + """Test draining when no tokens have been accumulated.""" + result = model._drain_tokens() + assert result == {'prompt': 0, 'completion': 0} + def test_add_tokens_from_usage(self, model): """Test adding tokens from response with usage attribute.""" mock_response = MagicMock() mock_usage = MagicMock() mock_usage.prompt_tokens = 10 mock_usage.completion_tokens = 20 + mock_usage.total_tokens = None mock_response.usage = mock_usage model._add_tokens(mock_response) @@ -94,12 +101,14 @@ def test_add_tokens_accumulates(self, model): mock_usage1 = MagicMock() mock_usage1.prompt_tokens = 10 mock_usage1.completion_tokens = 20 + mock_usage1.total_tokens = None mock_response1.usage = mock_usage1 mock_response2 = MagicMock() mock_usage2 = MagicMock() mock_usage2.prompt_tokens = 5 mock_usage2.completion_tokens = 10 + mock_usage2.total_tokens = None mock_response2.usage = mock_usage2 model._add_tokens(mock_response1) @@ -110,14 +119,12 @@ def test_add_tokens_accumulates(self, model): def test_finalize_tokens(self, model): """Test finalizing tokens to output context.""" - model._init_tokens() - model._thread_local.tokens['prompt'] = 10 - model._thread_local.tokens['completion'] = 20 + model._thread_local.tokens = {'prompt': 10, 'completion': 20} with patch.object(model, 'set_output_context') as mock_set: model._finalize_tokens() mock_set.assert_called_once_with(prompt_tokens=10, completion_tokens=20) - assert not hasattr(model._thread_local, 'tokens') + assert model._thread_local.tokens == {'prompt': 0, 'completion': 0} def test_finalize_tokens_no_tokens(self, model): """Test finalizing when no tokens were tracked.""" @@ -125,6 +132,58 @@ def test_finalize_tokens_no_tokens(self, model): model._finalize_tokens() mock_set.assert_not_called() + # === Helper Method Tests === + + def test_extract_tool_content_error(self, model): + """Test that errors are returned as error strings.""" + result = model._extract_tool_content(None, "something went wrong") + assert result == "Error: something went wrong" + + def test_extract_tool_content_from_content_attr(self, model): + """Test extraction from result.content[0].text.""" + mock_result = MagicMock() + mock_result.content = [MagicMock(text="hello from content")] + result = model._extract_tool_content(mock_result, None) + assert result == "hello from content" + + def test_extract_tool_content_from_list(self, model): + """Test extraction from list result[0].text.""" + + class TextItem: + text = "hello from list" + + result = model._extract_tool_content([TextItem()], None) + assert result == "hello from list" + + def test_extract_tool_content_returns_none_when_empty(self, model): + """Test that None is returned when no text can be extracted.""" + mock_result = MagicMock() + mock_result.content = [] + result = model._extract_tool_content(mock_result, None) + assert result is None + + def test_normalize_input_items_string(self, model): + """Test normalizing a plain string to a message item list.""" + result = model._normalize_input_items("hello") + assert result == [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "hello"}], + } + ] + + def test_normalize_input_items_list(self, model): + """Test that an existing list is returned unchanged.""" + items = [{"type": "message", "role": "user", "content": "hi"}] + result = model._normalize_input_items(items) + assert result is items + + def test_normalize_input_items_other(self, model): + """Test that non-string, non-list input returns empty list.""" + assert model._normalize_input_items(None) == [] + assert model._normalize_input_items(42) == [] + # === Tool Format Conversion Tests === def test_to_response_api_tools_with_function(self, model): @@ -1081,6 +1140,7 @@ def test_openai_stream_transport_with_mcp(self, model): {"test_tool": "http://server"}, ) mock_pool._loop = asyncio.new_event_loop() + mock_pool._run_async.return_value = {'prompt': 0, 'completion': 0} mock_chunk = MagicMock() mock_chunk.model_dump_json.return_value = '{"id": "chunk1"}'