diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py
index b47e3a65a..b6abfa3ac 100644
--- a/ms_agent/agent/llm_agent.py
+++ b/ms_agent/agent/llm_agent.py
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import asyncio
 import importlib
 import inspect
 import os.path
@@ -26,6 +27,11 @@
 from ..config.config import Config, ConfigLifecycleHandler
 from .base import Agent
 
+# Shared within the current process
+TOTAL_PROMPT_TOKENS = 0
+TOTAL_COMPLETION_TOKENS = 0
+TOKEN_LOCK = asyncio.Lock()
+
 
 class LLMAgent(Agent):
     """
@@ -467,9 +473,26 @@ async def step(
             messages = await self.parallel_tool_call(messages)
             await self.after_tool_call(messages)
+
+            # usage
+
+            prompt_tokens = _response_message.prompt_tokens
+            completion_tokens = _response_message.completion_tokens
+
+            global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS, TOKEN_LOCK
+            async with TOKEN_LOCK:
+                TOTAL_PROMPT_TOKENS += prompt_tokens
+                TOTAL_COMPLETION_TOKENS += completion_tokens
+
+            # tokens in the current step
+            self.log_output(
+                f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}'
+            )
+            # total tokens for the process so far
             self.log_output(
-                f'[usage] prompt_tokens: {_response_message.prompt_tokens}, '
-                f'completion_tokens: {_response_message.completion_tokens}')
+                f'[usage_total] total_prompt_tokens: {TOTAL_PROMPT_TOKENS}, '
+                f'total_completion_tokens: {TOTAL_COMPLETION_TOKENS}')
+
 
         yield messages
 
     def prepare_llm(self):
diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py
index 3d012619d..ff5bb6744 100644
--- a/ms_agent/llm/openai_llm.py
+++ b/ms_agent/llm/openai_llm.py
@@ -132,6 +132,10 @@ def _call_llm(self,
         """
         messages = self._format_input_message(messages)
 
+        if kwargs.get('stream', False) and self.args.get(
+                'stream_options', {}).get('include_usage', True):
+            kwargs.setdefault('stream_options', {})['include_usage'] = True
+
         return self.client.chat.completions.create(
             model=self.model, messages=messages, tools=tools, **kwargs)
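---

Note on the llm_agent.py change: TOTAL_PROMPT_TOKENS and TOTAL_COMPLETION_TOKENS are module-level globals guarded by an asyncio.Lock, so they aggregate usage across every LLMAgent instance running in the same process and event loop. A minimal standalone sketch of the same pattern (record_usage and the simulated agents are illustrative, not part of the patch):

import asyncio

# Process-wide totals, mirroring the globals added in llm_agent.py.
TOTAL_PROMPT_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0
TOKEN_LOCK = asyncio.Lock()


async def record_usage(prompt_tokens: int, completion_tokens: int) -> None:
    """Add one step's usage to the process-wide totals."""
    global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS
    async with TOKEN_LOCK:
        # Serialise updates from concurrently running steps so the
        # two totals always move together.
        TOTAL_PROMPT_TOKENS += prompt_tokens
        TOTAL_COMPLETION_TOKENS += completion_tokens


async def main() -> None:
    # Five "agents" finish a step at roughly the same time.
    await asyncio.gather(*(record_usage(100, 20) for _ in range(5)))
    print(TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS)  # 500 100


asyncio.run(main())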
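Note on the openai_llm.py change: it relies on the OpenAI streaming API's stream_options={'include_usage': True}, which makes the server append one final chunk whose usage field is populated (and whose choices list is empty); without it, streamed responses report no token usage at all. A consumer sketch, assuming an OpenAI-compatible endpoint, OPENAI_API_KEY in the environment, and an illustrative model name:

from openai import OpenAI

client = OpenAI()

stream = client.chat.completions.create(
    model='gpt-4o-mini',  # illustrative
    messages=[{'role': 'user', 'content': 'Say hi'}],
    stream=True,
    stream_options={'include_usage': True},
)

for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or '', end='')
    if chunk.usage is not None:
        # Only the final chunk carries usage when include_usage is set.
        print(f'\n[usage] prompt_tokens: {chunk.usage.prompt_tokens}, '
              f'completion_tokens: {chunk.usage.completion_tokens}')

The guard in _call_llm keeps this behaviour opt-out-able: include_usage is only forced on for streamed calls when the config has not explicitly set stream_options.include_usage to false.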