From 60dd5e95f961b309dea64ea850551a48ef7f1452 Mon Sep 17 00:00:00 2001
From: Stepka
Date: Tue, 13 Jan 2026 11:39:47 +0200
Subject: [PATCH] Add `system_prompt` param to `GeneralLLM.invoke` function

---
 forecasting_tools/ai_models/general_llm.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/forecasting_tools/ai_models/general_llm.py b/forecasting_tools/ai_models/general_llm.py
index 7782d705..581c0dc0 100644
--- a/forecasting_tools/ai_models/general_llm.py
+++ b/forecasting_tools/ai_models/general_llm.py
@@ -231,21 +231,23 @@ def __init__(
 
         ModelTracker.give_cost_tracking_warning_if_needed(self._litellm_model)
 
-    async def invoke(self, prompt: ModelInputType) -> str:
+    async def invoke(self, prompt: ModelInputType, system_prompt: str | None = None) -> str:
         response: TextTokenCostResponse = (
-            await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt)
+            await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt, system_prompt=system_prompt)
         )
         data = response.data
         return data
 
     @RetryableModel._retry_according_to_model_allowed_tries
     async def _invoke_with_request_cost_time_and_token_limits_and_retry(
-        self, prompt: ModelInputType
+        self, prompt: ModelInputType, system_prompt: str | None = None
     ) -> Any:
         logger.debug(f"Invoking model with prompt: {prompt}")
+        prompt = self.model_input_to_message(prompt, system_prompt)
+
         with track_generation(
-            input=self.model_input_to_message(prompt),
+            input=prompt,
             model=self.model,
         ) as span:
             direct_call_response = await self._mockable_direct_call_to_model(prompt)
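
Reviewer note (not part of the patch): a minimal usage sketch of the new parameter. The invoke() signature comes from the diff above; the import path, class casing (GeneralLlm), and constructor arguments are assumptions based on the library's documented usage, and the model name is purely illustrative.

# sketch.py -- hypothetical caller-side example, assuming forecasting_tools
# exports GeneralLlm from the package root.
import asyncio

from forecasting_tools import GeneralLlm


async def main() -> None:
    llm = GeneralLlm(model="gpt-4o-mini")  # model choice is an assumption
    # Before this patch the system message had to be baked into `prompt`;
    # now it is passed separately and merged via model_input_to_message()
    # inside the retry wrapper before the request is made.
    answer = await llm.invoke(
        "Estimate the probability that it rains in Paris tomorrow.",
        system_prompt="You are a careful forecasting assistant.",
    )
    print(answer)


asyncio.run(main())

One side effect of converting the prompt inside the retry wrapper: the `input` recorded by track_generation is now the same merged message list that is actually sent to the model, so the traced span reflects the system prompt as well.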