Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions plugin/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,17 @@ export const runPipeline = async ({
}),
})

const data = await readApiResponse<PipelineRunResponse | { detail?: string }>(response, url)
const data = await readApiResponse<PipelineRunResponse | { detail?: unknown }>(response, url)

if (!response.ok) {
if (response.status == 422) {
throw new Error(
'Plugin is not configured. Set the model, API key, and API URL in the plugin settings.',
)
}
const detail = 'detail' in data ? data.detail : undefined
throw new Error(
'detail' in data && data.detail ? data.detail : 'Pipeline execution failed.',
typeof detail === 'string' && detail ? detail : 'Pipeline execution failed.',
)
}

Expand Down
10 changes: 0 additions & 10 deletions service/config.template.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
llm_response_generation:
type: "openai-api"
model: "gpt-oss-120b"
api_url: https://llm.ai.e-infra.cz/v1/
api_key: "xxx"
workers: 4

dsw:
api_url: ""

auth:
allowed_project_urls:
- "https://your-dsw-instance.example.com"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from ai_document_plugin_service.ai.assignment.types import SectionAssignment
from ai_document_plugin_service.ai.common.config import Config
from ai_document_plugin_service.ai.common.llm_client import LLMClient
from ai_document_plugin_service.ai.common.progress import progress_percent
from ai_document_plugin_service.ai.common.types import AssignmentStats
from ai_document_plugin_service.ai.knowledgemodel.types import QuestionData
Expand All @@ -38,6 +39,17 @@ class AssignmentComponentResult(TypedDict):

@component
class AssignmentComponent:
"""
This component assigns questions from the questionnaire to the sections of the DMP template.
For example, it assigns the question 'When will the project start?' to the sections Introduction and Project Timeline.
"""

def __init__(self, llm_client: LLMClient, config: Config) -> None:
    """Store the shared LLM client and config and build the LLM-backed helpers.

    Both helpers (the section-id generator and the question-to-section
    matcher) receive the same ``llm_client`` and ``config`` objects that
    are kept on the component itself.
    """
    self.llm_client, self.config = llm_client, config
    shared_args = (llm_client, config)
    self.section_id_generator = OpenAISectionIdGenerator(*shared_args)
    self.section_matcher = OpenAILayerMatcher(*shared_args)

@staticmethod
def _add_chunk_mapping_to_result(
*,
Expand All @@ -53,16 +65,13 @@ def _add_chunk_mapping_to_result(
continue
result_mapping[question_path] = [section_formatter.record_id_for_sid(sid) for sid in section_ids]

@staticmethod
def _match_single_chunk(
*,
config: Config,
self,
sections_xml: str,
question_chunk: str,
stats: AssignmentStats,
) -> dict[str, list[str]]:
matcher = OpenAILayerMatcher(config)
return matcher.match_questions_to_sections(
return self.section_matcher.match_questions_to_sections(
sections_xml,
question_chunk,
stats,
Expand All @@ -73,7 +82,6 @@ def run(
self,
data: list[QuestionData],
template_data: dict[str, Any],
config: Config,
km: dict[str, Any],
on_progress: Callable[[str], None] | None = None,
) -> AssignmentComponentResult:
Expand All @@ -85,7 +93,7 @@ def run(
stats = AssignmentStats()

section_formatter = SectionFormatter(sections)
section_formatter.create_mappings(OpenAISectionIdGenerator(config), stats)
section_formatter.create_mappings(self.section_id_generator, stats)
sections_xml = section_formatter.get_sections_as_xml()

result_mapping = {}
Expand All @@ -94,7 +102,6 @@ def run(

def match_chunk(question_chunk: str) -> dict[str, list[str]]:
result = self._match_single_chunk(
config=config,
sections_xml=sections_xml,
question_chunk=question_chunk,
stats=stats,
Expand All @@ -109,8 +116,8 @@ def match_chunk(question_chunk: str) -> dict[str, list[str]]:
for question_to_section_ids in thread_map(
match_chunk,
question_chunks,
max_workers=config.parallel_workers,
desc=f'Assigning questions to sections ({config.parallel_workers} workers)',
max_workers=self.llm_client.get_max_workers(),
desc=f'Assigning questions to sections ({self.llm_client.get_max_workers()} workers)',
):
self._add_chunk_mapping_to_result(
result_mapping=result_mapping,
Expand Down
5 changes: 2 additions & 3 deletions service/src/ai_document_plugin_service/ai/assignment/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def match_questions_to_sections(


class OpenAILayerMatcher(LayerMatcher):
def __init__(self, config: Config) -> None:
def __init__(self, llm_client: LLMClient, config: Config) -> None:
self.client = llm_client
self.config = config
self.client = LLMClient(config)

def match_questions_to_sections(
self,
Expand All @@ -74,7 +74,6 @@ def match_questions_to_sections(

def call_and_parse() -> dict[str, list[str]]:
response = self.client.completion(
model=self.config.model,
messages=messages,
temperature=self.config.assignment.temperature,
max_tokens=self.config.assignment.max_tokens,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing
from abc import ABC, abstractmethod

from openai import OpenAI
from tqdm import tqdm

from ai_document_plugin_service.ai.assignment.types import LeafSection
Expand Down Expand Up @@ -29,9 +28,9 @@ def generate_leaf_section_ids(


class OpenAISectionIdGenerator(SectionIdGenerator):
def __init__(self, config: Config) -> None:
def __init__(self, llm_client: LLMClient, config: Config) -> None:
self.config = config
self.client = LLMClient(config)
self.client = llm_client

@typing.override
def generate_leaf_section_ids(
Expand Down Expand Up @@ -59,7 +58,6 @@ def generate_leaf_section_ids(
)
response = call_with_retry(
lambda um=user_msg: self.client.completion(
model=self.config.model,
messages=[
{'role': 'system', 'content': system_msg},
{'role': 'user', 'content': um},
Expand Down Expand Up @@ -92,7 +90,6 @@ def _normalize_section_id(raw: str) -> str:
class LoggingNoopSectionIdGenerator(SectionIdGenerator):
def __init__(self, config: Config) -> None:
self.config = config
self.client = OpenAI(api_key=config.api_key, base_url=config.api_url)

@typing.override
def generate_leaf_section_ids( # ty: ignore[invalid-method-override]
Expand Down
48 changes: 5 additions & 43 deletions service/src/ai_document_plugin_service/ai/common/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import pathlib
from dataclasses import dataclass, replace
from dataclasses import dataclass
from typing import Any

import yaml
Expand Down Expand Up @@ -41,26 +41,21 @@ class DatabaseConfig:

@dataclass(frozen=True)
class Config:
    """Immutable service configuration, as returned by ``load_config``."""

    # LLM endpoint settings, read from the ``llm_response_generation`` section.
    api_key: str
    api_url: str
    # Read from ``dsw.api_url``.
    dsw_api_url: str
    # Produced by ``_get_allowed_project_urls`` (a tuple of normalized URLs).
    allowed_project_urls: tuple[str, ...]
    # Model name, read from ``llm_response_generation.model``.
    model: str
    log_level: str
    database: DatabaseConfig
    files: FilePaths
    # Per-stage prompt settings (presumably one entry per pipeline stage — confirm
    # against the prompt dataclass definitions).
    assignment: SystemAndUserPrompt
    section_id: SystemAndUserPrompt
    dmp_generation: SystemPrompt
    dmp_polishing: SystemAndUserPrompt
    # Validated worker count produced by ``_get_parallel_workers`` (always >= 1).
    parallel_workers: int


@dataclass(frozen=True)
class LLMConfigOverride:
model: str | None = None
api_key: str | None = None
api_url: str | None = None
class LLMConfig:
model: str
api_key: str
api_url: str
parallel_workers: int | None = None


Expand Down Expand Up @@ -134,19 +129,6 @@ def _get_allowed_project_urls(config: dict[str, Any]) -> tuple[str, ...]:
return tuple(normalized)


def _get_parallel_workers(config: dict[str, Any]) -> int:
workers = config.get('llm_response_generation', {}).get('workers', 1)
try:
workers_int = int(workers)
except (TypeError, ValueError) as exc:
msg = "Invalid config value: 'parallelism.workers' must be an integer >= 1"
raise ValueError(msg) from exc
if workers_int < 1:
msg = "Invalid config value: 'parallelism.workers' must be >= 1"
raise ValueError(msg)
return workers_int


def _resolve_existing_path(path: str, *, base_dir: pathlib.Path | None = None) -> str:
normalized_path = _normalize_path(path)
path_obj = pathlib.Path(normalized_path)
Expand Down Expand Up @@ -199,12 +181,6 @@ def load_config(config_path: str | None = None) -> Config:
raise TypeError(msg)

return Config(
api_key=_expand_env_vars(
_get(config, 'llm_response_generation', 'api_key'),
),
api_url=_get(config, 'llm_response_generation', 'api_url'),
model=_get(config, 'llm_response_generation', 'model'),
dsw_api_url=_get(config, 'dsw', 'api_url'),
allowed_project_urls=_get_allowed_project_urls(config),
log_level=_get_log_level(config),
database=DatabaseConfig(
Expand Down Expand Up @@ -241,18 +217,4 @@ def load_config(config_path: str | None = None) -> Config:
system_message=_get(prompts, 'dmp_polishing', 'system_message'),
user_message=_get(prompts, 'dmp_polishing', 'user_message'),
),
parallel_workers=_get_parallel_workers(config),
)


def apply_llm_override(config: Config, override: LLMConfigOverride | None = None) -> Config:
    """Return ``config`` with the LLM settings from ``override`` applied.

    When ``override`` is ``None`` the original ``config`` object is returned
    unchanged. Otherwise a copy is produced in which each of ``model``,
    ``api_key``, ``api_url`` and ``parallel_workers`` takes the override's
    value if that value is truthy, and keeps the value from ``config``
    otherwise.
    """
    if override is not None:
        overridable_fields = ('model', 'api_key', 'api_url', 'parallel_workers')
        merged = {
            field_name: getattr(override, field_name) or getattr(config, field_name)
            for field_name in overridable_fields
        }
        return replace(config, **merged)
    return config
19 changes: 12 additions & 7 deletions service/src/ai_document_plugin_service/ai/common/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from openai import APIConnectionError, APITimeoutError, OpenAI, RateLimitError
from openai.types.chat import ChatCompletion

from ai_document_plugin_service.ai.common.config import Config
from ai_document_plugin_service.ai.common.dynamic_semaphore import DynamicSemaphore

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,27 +75,33 @@ def add_usage(stats: 'AssignmentStats | None', response: object) -> None:


class LLMClient:
def __init__(self, config: Config) -> None:
self.client = OpenAI(api_key=config.api_key, base_url=config.api_url, max_retries=0)
self.max_workers = config.parallel_workers or 1
def __init__(self, model: str, api_key: str, api_url: str, parallel_workers: int | None) -> None:
self.client = OpenAI(api_key=api_key, base_url=api_url, max_retries=0)
self.model = model
self.max_workers = parallel_workers or 1
logger.debug(
'Initializing LLM client, setting semaphore limit to %s',
self.max_workers,
)
semaphore.set_limit(self.max_workers)

def get_max_workers(self) -> int:
    """Return the worker limit stored on this client as ``max_workers``."""
    return self.max_workers

def get_model_name(self) -> str:
    """Return the model name stored on this client as ``model``."""
    return self.model

def completion(
self,
*args: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
) -> ChatCompletion:
req_id = uuid.uuid4().hex[:8]
model = kwargs.get('model', args[0] if args else '?')
wait_start = time.perf_counter()
logger.debug(
'[llm] req=%s model=%s queueing (semaphore active/limit unknown until acquire)',
req_id,
model,
self.model,
)
with semaphore:
wait_s = time.perf_counter() - wait_start
Expand All @@ -107,7 +112,7 @@ def completion(
semaphore.limit,
)
call_start = time.perf_counter()
result = self.client.chat.completions.create(*args, **kwargs)
result = self.client.chat.completions.create(*args, model=self.model, **kwargs)
logger.debug(
'[llm] req=%s completed in %.3fs (releasing semaphore)',
req_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
from tqdm import tqdm

from ai_document_plugin_service.ai.assignment.types import SerializedSectionAssignment
from ai_document_plugin_service.ai.common import Config
from ai_document_plugin_service.ai.common.progress import progress_percent
from ai_document_plugin_service.ai.common.types import AssignmentStats
from ai_document_plugin_service.ai.generation.llm import (
GenerationLLM,
OpenAIGenerationLLM,
)
from ai_document_plugin_service.ai.generation.parse_answers import parse_answer

Expand All @@ -42,16 +40,14 @@ class _ScheduledSection:

@component
class DmpGeneratorComponent:
def __init__(self, llm: GenerationLLM | None = None) -> None:
self.llm = llm
def __init__(self, dmp_generator_llm: GenerationLLM) -> None:
self.dmp_generator_llm = dmp_generator_llm

@component.output_types(markdown=str, debug_markdown=str, stats=AssignmentStats)
def run(
self,
replies: dict,
km: dict,
config: Config,
llm: GenerationLLM | None = None,
new_assignments: list[SerializedSectionAssignment] | None = None,
db_assignments: list[SerializedSectionAssignment] | None = None,
on_progress: Callable[[str], None] | None = None,
Expand All @@ -66,18 +62,15 @@ def run(
replies = self._filter_reachable_replies(replies, km)

stats = AssignmentStats()

active_llm = llm or self.llm or OpenAIGenerationLLM(config)

worker_count = max(1, config.parallel_workers)
with ThreadPoolExecutor(max_workers=worker_count) as executor:
max_workers = self.dmp_generator_llm.get_max_workers()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
scheduled_sections = [
self._schedule_section(
node=node,
depth=0,
replies=replies,
km=km,
llm=active_llm,
llm=self.dmp_generator_llm,
stats=stats,
executor=executor,
)
Expand All @@ -90,7 +83,7 @@ def run(
for i, future in tqdm(
enumerate(as_completed(leaf_futures), start=1),
total=total_sections,
desc=f'Generating sections ({worker_count} workers)',
desc=f'Generating sections ({max_workers} workers)',
):
future.result()
if on_progress is not None:
Expand Down
Loading