diff --git a/forecasting_tools/ai_models/general_llm.py b/forecasting_tools/ai_models/general_llm.py
index 7782d705..581c0dc0 100644
--- a/forecasting_tools/ai_models/general_llm.py
+++ b/forecasting_tools/ai_models/general_llm.py
@@ -231,21 +231,23 @@ def __init__(
         ModelTracker.give_cost_tracking_warning_if_needed(self._litellm_model)
 
-    async def invoke(self, prompt: ModelInputType) -> str:
+    async def invoke(self, prompt: ModelInputType, system_prompt: str | None = None) -> str:
         response: TextTokenCostResponse = (
-            await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt)
+            await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt, system_prompt=system_prompt)
         )
         data = response.data
         return data
 
     @RetryableModel._retry_according_to_model_allowed_tries
     async def _invoke_with_request_cost_time_and_token_limits_and_retry(
-        self, prompt: ModelInputType
+        self, prompt: ModelInputType, system_prompt: str | None = None
     ) -> Any:
         logger.debug(f"Invoking model with prompt: {prompt}")
+        prompt = self.model_input_to_message(prompt, system_prompt)
+
         with track_generation(
-            input=self.model_input_to_message(prompt),
+            input=prompt,
             model=self.model,
         ) as span:
             direct_call_response = await self._mockable_direct_call_to_model(prompt)
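
For reviewers, a minimal caller-side sketch of the new optional system_prompt parameter. It assumes GeneralLlm is the class these methods live on and that it is constructed with a model name; the import path and construction arguments shown here are illustrative and are not part of this diff.

import asyncio

# Assumed import path; general_llm.py is the module this diff modifies.
from forecasting_tools.ai_models.general_llm import GeneralLlm


async def main() -> None:
    # Hypothetical construction; this diff does not touch __init__'s signature.
    llm = GeneralLlm(model="openai/gpt-4o-mini")

    # New with this change: a system prompt can ride along with the user prompt.
    # invoke() threads it into the retry wrapper, which folds both into the
    # message list via model_input_to_message() before the tracked model call.
    answer = await llm.invoke(
        "Will it rain in London tomorrow?",
        system_prompt="You are a careful, well-calibrated forecaster.",
    )
    print(answer)


asyncio.run(main())

Note that converting the prompt once at the top of the retry wrapper means the tracked span and the actual model call now see the same message list, where previously model_input_to_message() was applied only to the span's input.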