From dec93f68301270b0f866965fde7f26c6290ec18c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:11:43 +0200 Subject: [PATCH 01/13] add B rule --- pixi.lock | 4 ++-- pyproject.toml | 16 ++++++++++++---- ragna/__init__.py | 2 +- ragna/_cli/config.py | 6 +++--- ragna/_cli/corpus.py | 10 +++++----- ragna/core/_metadata_filter.py | 2 +- ragna/core/_rag.py | 2 +- ragna/core/_utils.py | 2 +- ragna/deploy/_ui/main_page.py | 4 +++- ragna/deploy/_ui/modal_configuration.py | 2 +- ragna/source_storages/_qdrant.py | 1 + ragna/source_storages/_vector_database.py | 2 +- tests/deploy/api/test_endpoints.py | 4 +++- tests/deploy/api/utils.py | 6 ++++-- 14 files changed, 39 insertions(+), 24 deletions(-) diff --git a/pixi.lock b/pixi.lock index 2a7c38e03..8e3fa43ba 100644 --- a/pixi.lock +++ b/pixi.lock @@ -11713,8 +11713,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: ragna - version: 0.4.0.dev40+g36d2640 - sha256: bbd81c4ebf5b4929da386781561b1df118bf8266b35858337294f670515199f6 + version: 0.4.0.dev38+gaf99c03.d20250620 + sha256: 26a22de2005bb9ece2d75fb16617c2c0abbe5787f5b438adc77fd6e6b202426c requires_dist: - aiofiles - fastapi diff --git a/pyproject.toml b/pyproject.toml index e2f27f48d..ca9ef64cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,15 +177,23 @@ depends-on = [ "build", ] +[tool.ruff] +target-version = "py310" + [tool.ruff.lint] select = [ "E", "F", - # import sorting - "I001" + "I001", + "B", +] + +ignore = [ + # Ignore line too long, because due to black, the error can only occur for strings + "E501", + # cache has its purpose + "B019", ] -# Ignore line too long, because due to black, the error can only occur for strings -ignore = ["E501"] [tool.ruff.lint.per-file-ignores] # ignore unused imports and imports not at the top of the file in __init__.py files diff --git a/ragna/__init__.py b/ragna/__init__.py index b7d5468a1..4887589e8 100644 --- a/ragna/__init__.py +++ b/ragna/__init__.py @@ -3,7 +3,7 @@ except ModuleNotFoundError: import warnings - warnings.warn("ragna was not properly installed!") + warnings.warn("ragna was not properly installed!", stacklevel=2) del warnings __version__ = "UNKNOWN" diff --git a/ragna/_cli/config.py b/ragna/_cli/config.py index 73044ed9b..32d77bf31 100644 --- a/ragna/_cli/config.py +++ b/ragna/_cli/config.py @@ -27,19 +27,19 @@ def parse_config(value: str) -> Config: try: config = Config.from_file(value) - except RagnaException: + except RagnaException as exc: rich.print(f"The configuration file {value} does not exist.") if value == "./ragna.toml": rich.print( "If you don't have a configuration file yet, " "run [bold]ragna init[/bold] to generate one." ) - raise typer.Exit(1) + raise typer.Exit(1) from exc except pydantic.ValidationError as validation: # FIXME: pretty formatting! for error in validation.errors(): rich.print(error) - raise typer.Exit(1) + raise typer.Exit(1) from validation # This stores the original value so we can pass it on to subprocesses that we might # start. config.__ragna_cli_config_path__ = value # type: ignore[attr-defined] diff --git a/ragna/_cli/corpus.py b/ragna/_cli/corpus.py index 661529810..54a4b4d57 100644 --- a/ragna/_cli/corpus.py +++ b/ragna/_cli/corpus.py @@ -69,12 +69,12 @@ def ingest( ] = False, ) -> None: try: - document_factory = getattr(config.document, "from_path") - except AttributeError: + document_factory = config.document.from_path + except AttributeError as exc: raise typer.BadParameter( f"{config.document.__name__} does not support creating documents from a" f"path. Please implement a `from_path` method." - ) + ) from exc database = Database(config.database_url) core_to_schema_document = CoreToSchemaConverter().document @@ -83,10 +83,10 @@ def ingest( try: with open(metadata_fields) as file: metadata = json.load(file) - except Exception: + except Exception as exc: raise typer.BadParameter( f"Could not read the metadata fields file: {metadata_fields}" - ) + ) from exc else: metadata = {} diff --git a/ragna/core/_metadata_filter.py b/ragna/core/_metadata_filter.py index 75c3cc57a..35d49eba5 100644 --- a/ragna/core/_metadata_filter.py +++ b/ragna/core/_metadata_filter.py @@ -64,7 +64,7 @@ def __eq__(self, other: Any) -> bool: if len(self.value) != len(other.value): return False - for self_child, other_child in zip(self.value, other.value): + for self_child, other_child in zip(self.value, other.value, strict=False): if self_child != other_child: return False diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index e966e4a9d..e9a57cdd8 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -489,7 +489,7 @@ def format_error( ] ) - raise RagnaException("\n".join(parts)) + raise RagnaException("\n".join(parts)) from None def _as_awaitable( self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any diff --git a/ragna/core/_utils.py b/ragna/core/_utils.py index 26fe1b83f..6360ec76f 100644 --- a/ragna/core/_utils.py +++ b/ragna/core/_utils.py @@ -141,7 +141,7 @@ def merge_models( field_definitions = {} for name, definitions in raw_field_definitions.items(): - types, defaults = zip(*definitions) + types, defaults = zip(*definitions, strict=False) types = set(types) if len(types) > 1: diff --git a/ragna/deploy/_ui/main_page.py b/ragna/deploy/_ui/main_page.py index e09cdd1e0..cac0b6615 100644 --- a/ragna/deploy/_ui/main_page.py +++ b/ragna/deploy/_ui/main_page.py @@ -106,7 +106,9 @@ def show_right_sidebar(self, title, content): self.right_sidebar.show() @param.depends("current_chat_id", watch=True) - def update_subviews_current_chat_id(self, avoid_senders=[]): + def update_subviews_current_chat_id(self, avoid_senders=None): + if avoid_senders is None: + avoid_senders = [] if self.left_sidebar is not None and self.left_sidebar not in avoid_senders: self.left_sidebar.current_chat_id = self.current_chat_id diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 78de6ba37..9a325ebe9 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -177,7 +177,7 @@ async def content_stream() -> AsyncIterator[bytes]: ids_and_streams=[ (document.id, make_content_stream(data)) for document, data in zip( - documents, self.document_uploader.value + documents, self.document_uploader.value, strict=False ) ], ) diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index a92f5ecc5..fa1ea54b4 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -158,6 +158,7 @@ async def list_metadata( await asyncio.gather( *[self._fetch_metadata(corpus_name) for corpus_name in corpus_names] ), + strict=False, ) ) diff --git a/ragna/source_storages/_vector_database.py b/ragna/source_storages/_vector_database.py index 0ac6ff13e..96186aa34 100644 --- a/ragna/source_storages/_vector_database.py +++ b/ragna/source_storages/_vector_database.py @@ -79,7 +79,7 @@ def _chunk_pages( n=chunk_size, step=chunk_size - chunk_overlap, ): - tokens, page_numbers = zip(*window) + tokens, page_numbers = zip(*window, strict=False) yield Chunk( text=self._tokenizer.decode(tokens), page_numbers=list(filter(lambda n: n is not None, page_numbers)) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 5f40e3e86..b025176e0 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -30,7 +30,9 @@ def test_get_documents(tmp_local_root, mime_type): document_paths = [ document_root / f"test{idx}.txt" for idx in range(len(_document_content_text)) ] - for content, document_path in zip(_document_content_text, document_paths): + for content, document_path in zip( + _document_content_text, document_paths, strict=False + ): with open(document_path, "w") as file: file.write(content) diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index 043529518..3e1859670 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -14,7 +14,9 @@ def upload_documents(*, client, document_paths, mime_types=None): "name": document_path.name, "mime_type": mime_type, } - for document_path, mime_type in zip(document_paths, mime_types) + for document_path, mime_type in zip( + document_paths, mime_types, strict=False + ) ], ) .raise_for_status() @@ -30,7 +32,7 @@ def upload_documents(*, client, document_paths, mime_types=None): "/api/documents", files=[ ("documents", (document["id"], file)) - for document, file in zip(documents, files) + for document, file in zip(documents, files, strict=False) ], ) From 87e636956c366a015205eca21988082393cf0bee Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:15:47 +0200 Subject: [PATCH 02/13] add C4 rule --- pyproject.toml | 1 + ragna/_cli/config.py | 4 ++-- ragna/_cli/corpus.py | 12 ++++-------- ragna/assistants/_google.py | 2 +- ragna/deploy/_database.py | 4 ++-- ragna/deploy/_engine.py | 2 +- ragna/source_storages/_qdrant.py | 4 ++-- tests/assistants/test_api.py | 6 +++--- tests/deploy/ui/test_ui.py | 14 +++++++------- 9 files changed, 23 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca9ef64cb..e019a130b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,7 @@ select = [ "F", "I001", "B", + "C4", ] ignore = [ diff --git a/ragna/_cli/config.py b/ragna/_cli/config.py index 32d77bf31..9e469a860 100644 --- a/ragna/_cli/config.py +++ b/ragna/_cli/config.py @@ -187,12 +187,12 @@ def _select_components( def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None: - unmet_requirements = set( + unmet_requirements = { requirement for component in components for requirement in component.requirements() if not requirement.is_available() - ) + } if not unmet_requirements: return diff --git a/ragna/_cli/corpus.py b/ragna/_cli/corpus.py index 54a4b4d57..0eacc4166 100644 --- a/ragna/_cli/corpus.py +++ b/ragna/_cli/corpus.py @@ -132,14 +132,10 @@ def ingest( document_instances = [] if source_storage.display_name() in ingestion_log: - batch_doc_set = set( - [ - str(doc) - for doc in documents[ - batch_number : batch_number + BATCH_SIZE - ] - ] - ) + batch_doc_set = { + str(doc) + for doc in documents[batch_number : batch_number + BATCH_SIZE] + } if batch_doc_set.issubset( ingestion_log[source_storage.display_name()] ): diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 627afd803..c263ca565 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -57,7 +57,7 @@ async def answer( "maxOutputTokens": max_new_tokens, }, }, - parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"), + parse_kwargs={"item": "item.candidates.item.content.parts.item.text"}, ) as stream: async for chunk in stream: yield chunk diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 6d7c69463..c721411e7 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -25,9 +25,9 @@ class Database: def __init__(self, url: str) -> None: components = urlsplit(url) if components.scheme == "sqlite": - connect_args = dict(check_same_thread=False) + connect_args = {"check_same_thread": False} else: - connect_args = dict() + connect_args = {} engine = create_engine(url, connect_args=connect_args) orm.Base.metadata.create_all(bind=engine) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index f52f2b823..9a45cf5f5 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -265,7 +265,7 @@ async def answer_stream( yield message # Avoid sending the sources multiple times - message_chunk = message.model_copy(update=dict(sources=None)) + message_chunk = message.model_copy(update={"sources": None}) async for content_chunk in content_stream: message_chunk.content = content_chunk yield message_chunk diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index fa1ea54b4..5fa1477f9 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -56,9 +56,9 @@ def __init__(self) -> None: from qdrant_client import AsyncQdrantClient if (url := os.environ.get("QDRANT_URL")) is not None: - kwargs = dict(url=url, api_key=os.environ.get("QDRANT_API_KEY")) + kwargs = {"url": url, "api_key": os.environ.get("QDRANT_API_KEY")} else: - kwargs = dict(path=str(ragna.local_root() / "qdrant")) + kwargs = {"path": str(ragna.local_root() / "qdrant")} self._client = AsyncQdrantClient(**kwargs) # type: ignore[arg-type] async def list_corpuses(self) -> list[str]: diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index af73ce86e..61f973f17 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -69,7 +69,7 @@ def new(base_url, streaming_protocol): cls = type( f"{streaming_protocol.name.title()}{HttpStreamingAssistant.__name__}", (HttpStreamingAssistant,), - dict(_STREAMING_PROTOCOL=streaming_protocol), + {"_STREAMING_PROTOCOL": streaming_protocol}, ) return cls(base_url) @@ -79,9 +79,9 @@ def __init__(self, base_url): async def answer(self, messages): if self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON: - parse_kwargs = dict(item="item") + parse_kwargs = {"item": "item"} else: - parse_kwargs = dict() + parse_kwargs = {} async with self._call_api( "POST", diff --git a/tests/deploy/ui/test_ui.py b/tests/deploy/ui/test_ui.py index e8fd48423..776d7b85b 100644 --- a/tests/deploy/ui/test_ui.py +++ b/tests/deploy/ui/test_ui.py @@ -16,13 +16,13 @@ def deploy(config): process = multiprocessing.Process( target=_deploy, - kwargs=dict( - config=config, - api=False, - ui=True, - ignore_unavailable_components=False, - open_browser=False, - ), + kwargs={ + "config": config, + "api": False, + "ui": True, + "ignore_unavailable_components": False, + "open_browser": False, + }, ) try: process.start() From 461310148f71d5bcc793618042439cf3f1078f3b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:16:18 +0200 Subject: [PATCH 03/13] add ISC rule --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e019a130b..eb7512d96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,6 +187,7 @@ select = [ "I001", "B", "C4", + "ISC", ] ignore = [ From a820c676e54cf00b38378b4df0551e985db8a6fb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:29:20 +0200 Subject: [PATCH 04/13] add RET rule --- pyproject.toml | 1 + ragna/core/_metadata_filter.py | 19 ++++---- ragna/deploy/_auth.py | 21 +++++---- ragna/deploy/_engine.py | 12 ++--- ragna/deploy/_schemas.py | 12 ++--- ragna/deploy/_ui/central_view.py | 12 ++--- .../components/metadata_filters_builder.py | 46 ++++++++----------- ragna/deploy/_ui/left_sidebar.py | 4 +- ragna/deploy/_ui/modal_configuration.py | 37 ++++++--------- ragna/source_storages/_chroma.py | 22 ++++----- ragna/source_storages/_demo.py | 30 ++++++------ ragna/source_storages/_lancedb.py | 30 ++++++------ ragna/source_storages/_qdrant.py | 20 ++++---- 13 files changed, 125 insertions(+), 141 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index eb7512d96..5f852720e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -188,6 +188,7 @@ select = [ "B", "C4", "ISC", + "RET", ] ignore = [ diff --git a/ragna/core/_metadata_filter.py b/ragna/core/_metadata_filter.py index 35d49eba5..3f894801a 100644 --- a/ragna/core/_metadata_filter.py +++ b/ragna/core/_metadata_filter.py @@ -39,7 +39,8 @@ def __init__(self, operator: MetadataOperator, key: str, value: Any) -> None: def __repr__(self) -> str: if self.operator is MetadataOperator.RAW: return f"{self.operator.name}({self.value!r})" - elif self.operator in {MetadataOperator.AND, MetadataOperator.OR}: + + if self.operator in {MetadataOperator.AND, MetadataOperator.OR}: return "\n".join( [ f"{self.operator.name}(", @@ -50,8 +51,8 @@ def __repr__(self) -> str: ")", ] ) - else: - return f"{self.operator.name}({self.key!r}, {self.value!r})" + + return f"{self.operator.name}({self.key!r}, {self.value!r})" def __eq__(self, other: Any) -> bool: if not isinstance(other, MetadataFilter): @@ -69,8 +70,8 @@ def __eq__(self, other: Any) -> bool: return False return True - else: - return (self.key == other.key) and (self.value == other.value) + + return (self.key == other.key) and (self.value == other.value) def to_primitive(self) -> dict[str, Any]: if self.operator is MetadataOperator.RAW: @@ -103,14 +104,14 @@ def __get_pydantic_core_schema__( def validate(value: Union[MetadataFilter, dict[str, Any]]) -> MetadataFilter: if isinstance(value, MetadataFilter): return value - else: - return cls.from_primitive(value) + + return cls.from_primitive(value) def serialize(value: Union[MetadataFilter, dict[str, Any]]) -> dict[str, Any]: if isinstance(value, MetadataFilter): return value.to_primitive() - else: - return value + + return value dict_schema = pydantic_core.core_schema.dict_schema( keys_schema=pydantic_core.core_schema.literal_schema( diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 756b14513..6009fb83d 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -64,19 +64,19 @@ async def dispatch(self, request: Request, call_next: CallNext) -> Response: return await self._api_token_dispatch( request, call_next, authorization=authorization ) - elif (cookie := request.cookies.get(self._COOKIE_NAME)) is not None: + if (cookie := request.cookies.get(self._COOKIE_NAME)) is not None: return await self._cookie_dispatch(request, call_next, cookie=cookie) - elif request.url.path in {"/login", "/oauth-callback"}: + if request.url.path in {"/login", "/oauth-callback"}: return await self._login_dispatch(request, call_next) - elif self._api and request.url.path.startswith("/api"): + if self._api and request.url.path.startswith("/api"): return self._unauthorized("Missing authorization header") - elif self._ui and request.url.path.startswith("/ui"): + if self._ui and request.url.path.startswith("/ui"): return redirect("/login") - else: - # Either an unknown route or something on the default router. In any case, - # this doesn't need a session and so we let it pass. - request.state.session = None - return await call_next(request) + + # Either an unknown route or something on the default router. In any case, + # this doesn't need a session, and so we let it pass. + request.state.session = None + return await call_next(request) async def _api_token_dispatch( self, request: Request, call_next: CallNext, authorization: str @@ -337,7 +337,8 @@ async def login(self, request: Request) -> Union[schemas.User, Response]: if not username: return self.login_page(request, fail_reason="Username cannot be empty") - elif (self._password is not None and password != self._password) or ( + + if (self._password is not None and password != self._password) or ( self._password is None and password != username ): return self.login_page( diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 9a45cf5f5..72c654348 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -123,12 +123,12 @@ def _get_source_storage_components( http_detail=RagnaException.MESSAGE, ) return [component] - else: - return [ - source_storage - for source_storage in self._rag._components.values() - if isinstance(source_storage, core.SourceStorage) - ] + + return [ + source_storage + for source_storage in self._rag._components.values() + if isinstance(source_storage, core.SourceStorage) + ] async def get_corpuses( self, source_storage: str | None = None diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index ce0bc6a4a..3c85b3264 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -38,8 +38,8 @@ class ApiKey(BaseModel): def _set_utc_timezone(cls, v: datetime) -> datetime: if v.tzinfo is None: return v.replace(tzinfo=timezone.utc) - else: - return v.astimezone(timezone.utc) + + return v.astimezone(timezone.utc) @computed_field # type: ignore[misc] @property @@ -55,15 +55,15 @@ def _maybe_obfuscate(cls, v: str, info: ValidationInfo) -> str: i = min(len(v) // 6, 3) if i > 0: return f"{v[:i]}***{v[-i:]}" - else: - return "***" + + return "***" def _set_utc_timezone(v: datetime) -> datetime: if v.tzinfo is None: return v.replace(tzinfo=timezone.utc) - else: - return v.astimezone(timezone.utc) + + return v.astimezone(timezone.utc) UtcDateTime = Annotated[datetime, AfterValidator(_set_utc_timezone)] diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 9980997a6..4c1b35701 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -314,12 +314,12 @@ def set_current_chat(self, chat): def get_user_from_role(self, role: Literal["system", "user", "assistant"]) -> str: if role == "system": return "Ragna" - elif role == "user": + if role == "user": return cast(str, pn.state.user) - elif role == "assistant": + if role == "assistant": return cast(str, self.current_chat.assistant) - else: - raise RuntimeError + + raise RuntimeError async def chat_callback( self, content: str, user: str, instance: pn.chat.ChatInterface @@ -370,7 +370,7 @@ async def chat_callback( @pn.depends("current_chat") def chat_interface(self): if self.current_chat is None: - return + return None return RagnaChatInterface( *[ @@ -412,7 +412,7 @@ def chat_interface(self): @pn.depends("current_chat") def header(self): if self.current_chat is None: - return + return None current_chat_name = "" if self.current_chat is not None: diff --git a/ragna/deploy/_ui/components/metadata_filters_builder.py b/ragna/deploy/_ui/components/metadata_filters_builder.py index 518717fb8..efb16cc66 100644 --- a/ragna/deploy/_ui/components/metadata_filters_builder.py +++ b/ragna/deploy/_ui/components/metadata_filters_builder.py @@ -86,47 +86,39 @@ def key_did_change(self): def compute_valid_operator_options(self, type_str): if type_str == "bool": return ["EQ", "NE"] - elif type_str == "str": + if type_str == "str": return ["EQ", "NE", "IN", "NOT_IN"] - else: - return ALLOWED_OPERATORS + + return ALLOWED_OPERATORS def construct_metadata_filter(self): if self.key_select.value == NO_FILTER_KEY: return None if self.operator_select.value in ["IN", "NOT_IN"]: - return MetadataFilter( - MetadataOperator[self.operator_select.value], - self.key_select.value, - self.multi_value_select.value, - ) + value = self.multi_value_select.value else: - return MetadataFilter( - MetadataOperator[self.operator_select.value], - self.key_select.value, - self.value_select.value, - ) + value = self.value_select.value + return MetadataFilter( + MetadataOperator[self.operator_select.value], + self.key_select.value, + value, + ) @param.depends("operator") def display(self): if self.operator == "IN" or self.operator == "NOT_IN": _, self.param.multi_value.objects = self.key_value_pairs[self.key] - return pn.Row( - self.key_select, - self.operator_select, - self.multi_value_select, - self.delete_button, - css_classes=["metadata-filter-row"], - ) + value_select = self.multi_value_select else: - return pn.Row( - self.key_select, - self.operator_select, - self.value_select, - self.delete_button, - css_classes=["metadata-filter-row"], - ) + value_select = self.value_select + return pn.Row( + self.key_select, + self.operator_select, + value_select, + self.delete_button, + css_classes=["metadata-filter-row"], + ) def __panel__(self): return self.display diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index af91f3d53..92547d28b 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -108,9 +108,7 @@ def __panel__(self): ] ) - result = pn.Column( + return pn.Column( *objects, css_classes=["left_sidebar_main_column"], ) - - return result diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 9a325ebe9..244ab9df5 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -15,11 +15,10 @@ def get_default_chat_name(timezone_offset=None): - if timezone_offset is None: - return f"Chat {datetime.now():%m/%d/%Y %I:%M %p}" - else: - tz = timezone(offset=timedelta(minutes=timezone_offset)) - return f"Chat {datetime.now().astimezone(tz=tz):%m/%d/%Y %I:%M %p}" + now = datetime.now() + if timezone_offset is not None: + now = now.astimezone(timezone(offset=timedelta(minutes=timezone_offset))) + return f"Chat {now:%m/%d/%Y %I:%M %p}" class ChatConfig(param.Parameterized): @@ -423,10 +422,7 @@ def toggle_card(event): "config.source_storage_name", ) def corpus_or_upload_config(self): - if self.corpus_or_upload == USE_CORPUS_LABEL: - return self.advanced_config(is_corpus=True) - else: - return self.advanced_config(is_corpus=False) + return self.advanced_config(is_corpus=self.corpus_or_upload == USE_CORPUS_LABEL) @pn.depends("advanced_config_collapsed", watch=True) def shrink_upload_container_height(self): @@ -467,24 +463,19 @@ def corpus_or_upload_row(self): corpus_names=corpus_names, corpus_metadata=corpus_metadata ) - if len(corpus_names) > 0: - data = pn.Column( - self.metadata_filter_rows_title, self.metadata_filter_rows - ) - else: - data = pn.Column(self.metadata_filter_rows) - self.error = False - return data - - else: return pn.Column( - pn.pane.HTML("Corpus Name"), - self.corpus_name_input, - self.upload_files_label, - self.upload_row, + *[self.metadata_filter_rows_title] if len(corpus_names) > 1 else [], + self.metadata_filter_rows, ) + return pn.Column( + pn.pane.HTML("Corpus Name"), + self.corpus_name_input, + self.upload_files_label, + self.upload_row, + ) + def __panel__(self): return pn.Column( pn.pane.HTML( diff --git a/ragna/source_storages/_chroma.py b/ragna/source_storages/_chroma.py index d80f248e5..99f524da2 100644 --- a/ragna/source_storages/_chroma.py +++ b/ragna/source_storages/_chroma.py @@ -162,9 +162,9 @@ def _translate_metadata_filter( ) -> Optional[dict[str, Any]]: if metadata_filter is None: return None - elif metadata_filter.operator is MetadataOperator.RAW: + if metadata_filter.operator is MetadataOperator.RAW: return cast(dict[str, Any], metadata_filter.value) - elif metadata_filter.operator in {MetadataOperator.AND, MetadataOperator.OR}: + if metadata_filter.operator in {MetadataOperator.AND, MetadataOperator.OR}: child_filters = [ self._translate_metadata_filter(child) for child in metadata_filter.value @@ -172,16 +172,16 @@ def _translate_metadata_filter( if len(child_filters) > 1: operator = self._METADATA_OPERATOR_MAP[metadata_filter.operator] return {operator: child_filters} - else: - return child_filters[0] - else: - return { - metadata_filter.key: { - self._METADATA_OPERATOR_MAP[ - metadata_filter.operator - ]: metadata_filter.value - } + + return child_filters[0] + + return { + metadata_filter.key: { + self._METADATA_OPERATOR_MAP[ + metadata_filter.operator + ]: metadata_filter.value } + } def retrieve( self, diff --git a/ragna/source_storages/_demo.py b/ragna/source_storages/_demo.py index 23f0435c8..2ea2e7330 100644 --- a/ragna/source_storages/_demo.py +++ b/ragna/source_storages/_demo.py @@ -118,9 +118,9 @@ def _apply_filter( ) -> list[tuple[int, dict[str, Any]]]: if metadata_filter is None: return list(enumerate(corpus)) - elif metadata_filter.operator is MetadataOperator.RAW: + if metadata_filter.operator is MetadataOperator.RAW: raise RagnaException - elif metadata_filter.operator in {MetadataOperator.AND, MetadataOperator.OR}: + if metadata_filter.operator in {MetadataOperator.AND, MetadataOperator.OR}: idcs_groups = [] rows_map = {} for child in metadata_filter.value: @@ -142,19 +142,19 @@ def _apply_filter( idcs_groups, ) return [(idx, rows_map[idx]) for idx in sorted(idcs)] - else: - rows_with_idx = [] - for idx, row in enumerate(corpus): - value = row.get(metadata_filter.key) - if value is None: - continue - - if self._METADATA_OPERATOR_MAP[metadata_filter.operator]( - value, metadata_filter.value - ): - rows_with_idx.append((idx, row)) - - return rows_with_idx + + rows_with_idx = [] + for idx, row in enumerate(corpus): + value = row.get(metadata_filter.key) + if value is None: + continue + + if self._METADATA_OPERATOR_MAP[metadata_filter.operator]( + value, metadata_filter.value + ): + rows_with_idx.append((idx, row)) + + return rows_with_idx def retrieve( self, corpus_name: str, metadata_filter: MetadataFilter, prompt: str diff --git a/ragna/source_storages/_lancedb.py b/ragna/source_storages/_lancedb.py index 4ada51873..5d70c5bca 100644 --- a/ragna/source_storages/_lancedb.py +++ b/ragna/source_storages/_lancedb.py @@ -83,12 +83,12 @@ def _get_table( ] ), ) - elif no_corpuses: + if no_corpuses: raise_no_corpuses_available(self) - elif non_existing_corpus: + if non_existing_corpus: raise_non_existing_corpus(self, corpus_name) - else: - return self._db.open_table(corpus_name) + + return self._db.open_table(corpus_name) def list_metadata( self, corpus_name: Optional[str] = None @@ -213,7 +213,7 @@ def store( def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str: if metadata_filter.operator is MetadataOperator.RAW: return cast(str, metadata_filter.value) - elif metadata_filter.operator in { + if metadata_filter.operator in { MetadataOperator.AND, MetadataOperator.OR, }: @@ -222,20 +222,20 @@ def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str: f"({self._translate_metadata_filter(child)})" for child in metadata_filter.value ) - elif metadata_filter.operator is MetadataOperator.NOT_IN: + if metadata_filter.operator is MetadataOperator.NOT_IN: in_ = self._translate_metadata_filter( MetadataFilter.in_(metadata_filter.key, metadata_filter.value) ) return f"NOT ({in_})" - else: - key = metadata_filter.key - operator = self._METADATA_OPERATOR_MAP[metadata_filter.operator] - value = ( - tuple(metadata_filter.value) - if metadata_filter.operator is MetadataOperator.IN - else metadata_filter.value - ) - return f"{key} {operator} {value!r}" + + key = metadata_filter.key + operator = self._METADATA_OPERATOR_MAP[metadata_filter.operator] + value = ( + tuple(metadata_filter.value) + if metadata_filter.operator is MetadataOperator.IN + else metadata_filter.value + ) + return f"{key} {operator} {value!r}" def retrieve( self, diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index 5fa1477f9..00d0d78c3 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -211,23 +211,23 @@ def _build_condition( # See https://qdrant.tech/documentation/concepts/filtering/#range if operator == MetadataOperator.EQ: return models.FieldCondition(key=key, match=models.MatchValue(value=value)) - elif operator == MetadataOperator.LT: + if operator == MetadataOperator.LT: return models.FieldCondition(key=key, range=models.Range(lt=value)) - elif operator == MetadataOperator.LE: + if operator == MetadataOperator.LE: return models.FieldCondition(key=key, range=models.Range(lte=value)) - elif operator == MetadataOperator.GT: + if operator == MetadataOperator.GT: return models.FieldCondition(key=key, range=models.Range(gt=value)) - elif operator == MetadataOperator.GE: + if operator == MetadataOperator.GE: return models.FieldCondition(key=key, range=models.Range(gte=value)) - elif operator == MetadataOperator.IN: + if operator == MetadataOperator.IN: return models.FieldCondition(key=key, match=models.MatchAny(any=value)) - elif operator in {MetadataOperator.NE, MetadataOperator.NOT_IN}: + if operator in {MetadataOperator.NE, MetadataOperator.NOT_IN}: except_value = [value] if operator == MetadataOperator.NE else value return models.FieldCondition( key=key, match=models.MatchExcept(**{"except": except_value}) ) - else: - raise ValueError(f"Unsupported operator: {operator}") + + raise ValueError(f"Unsupported operator: {operator}") def _translate_metadata_filter( self, metadata_filter: MetadataFilter @@ -236,14 +236,14 @@ def _translate_metadata_filter( if metadata_filter.operator is MetadataOperator.RAW: return cast(models.Filter, metadata_filter.value) - elif metadata_filter.operator == MetadataOperator.AND: + if metadata_filter.operator == MetadataOperator.AND: return models.Filter( must=[ self._translate_metadata_filter(child) for child in metadata_filter.value ] ) - elif metadata_filter.operator == MetadataOperator.OR: + if metadata_filter.operator == MetadataOperator.OR: return models.Filter( should=[ self._translate_metadata_filter(child) From c7de4e09add89af64665650faeb4e568cb43d4b5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:38:24 +0200 Subject: [PATCH 05/13] add SIM rule --- docs/tutorials/gallery_rest_api.py | 9 +++++---- pyproject.toml | 1 + ragna/_cli/corpus.py | 7 +------ ragna/deploy/_ui/modal_configuration.py | 13 ++++--------- ragna/source_storages/_chroma.py | 5 +---- ragna/source_storages/_demo.py | 5 +---- ragna/source_storages/_lancedb.py | 5 +---- tests/test_importable.py | 6 +----- 8 files changed, 15 insertions(+), 36 deletions(-) diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index c9352a6b2..6df61c55f 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -117,10 +117,11 @@ # - The field name is the ID of the document returned by step 1. # - The field value is the binary content of the document. -client.put( - "/api/documents", - files=[("documents", (documents[0]["id"], open(document_path, "rb")))], -) +with open(document_path, "rb") as f: + client.put( + "/api/documents", + files=[("documents", (documents[0]["id"], f))], + ) # %% # ## Step 4: Select a source storage and assistant diff --git a/pyproject.toml b/pyproject.toml index 5f852720e..f4dd27ba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,6 +189,7 @@ select = [ "C4", "ISC", "RET", + "SIM", ] ignore = [ diff --git a/ragna/_cli/corpus.py b/ragna/_cli/corpus.py index 0eacc4166..7137d97ce 100644 --- a/ragna/_cli/corpus.py +++ b/ragna/_cli/corpus.py @@ -146,12 +146,7 @@ def ingest( try: document_instances.append( document_factory( - document, - metadata=( - metadata[str(document)] - if str(document) in metadata - else None - ), + document, metadata=metadata.get(str(document)) ) ) except Exception: diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 244ab9df5..aea7d88f2 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -449,15 +449,10 @@ def add_error_message(self): ) def corpus_or_upload_row(self): if self.corpus_or_upload == USE_CORPUS_LABEL: - if self.config.source_storage_name in self.corpus_names: - corpus_names = self.corpus_names[self.config.source_storage_name] - else: - corpus_names = [] - - if self.config.source_storage_name in self.corpus_metadata: - corpus_metadata = self.corpus_metadata[self.config.source_storage_name] - else: - corpus_metadata = {} + corpus_names = self.corpus_names.get(self.config.source_storage_name, []) + corpus_metadata = self.corpus_metadata.get( + self.config.source_storage_name, {} + ) self.metadata_filter_rows = MetadataFiltersBuilder( corpus_names=corpus_names, corpus_metadata=corpus_metadata diff --git a/ragna/source_storages/_chroma.py b/ragna/source_storages/_chroma.py index 99f524da2..38ca3ffa6 100644 --- a/ragna/source_storages/_chroma.py +++ b/ragna/source_storages/_chroma.py @@ -77,10 +77,7 @@ def _get_collection( def list_metadata( self, corpus_name: Optional[str] = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: - if corpus_name is None: - corpus_names = self.list_corpuses() - else: - corpus_names = [corpus_name] + corpus_names = self.list_corpuses() if corpus_name is None else [corpus_name] metadata = {} for corpus_name in corpus_names: diff --git a/ragna/source_storages/_demo.py b/ragna/source_storages/_demo.py index 2ea2e7330..94843e581 100644 --- a/ragna/source_storages/_demo.py +++ b/ragna/source_storages/_demo.py @@ -55,10 +55,7 @@ def _get_corpus( def list_metadata( self, corpus_name: Optional[str] = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: - if corpus_name is None: - corpus_names = self.list_corpuses() - else: - corpus_names = [corpus_name] + corpus_names = self.list_corpuses() if corpus_name is None else [corpus_name] metadata = {} for corpus_name in corpus_names: diff --git a/ragna/source_storages/_lancedb.py b/ragna/source_storages/_lancedb.py index 5d70c5bca..fc921f280 100644 --- a/ragna/source_storages/_lancedb.py +++ b/ragna/source_storages/_lancedb.py @@ -93,10 +93,7 @@ def _get_table( def list_metadata( self, corpus_name: Optional[str] = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: - if corpus_name is None: - corpus_names = self.list_corpuses() - else: - corpus_names = [corpus_name] + corpus_names = self.list_corpuses() if corpus_name is None else [corpus_name] metadata = {} for corpus_name in corpus_names: diff --git a/tests/test_importable.py b/tests/test_importable.py index de09a4bb9..3a1a71c5f 100644 --- a/tests/test_importable.py +++ b/tests/test_importable.py @@ -8,11 +8,7 @@ def main(): for path in PACKAGE_ROOT.rglob("*.py"): - if path.name == "__init__.py": - path = path.parent - else: - path = path.with_suffix("") - + path = path.parent if path.name == "__init__.py" else path.with_suffix("") path = path.relative_to(PROJECT_ROOT) if any(part.startswith("_") for part in path.parts): From 810512842fc67b424e3c5595fdd4e42a94242786 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:39:29 +0200 Subject: [PATCH 06/13] exclude ISC001 --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f4dd27ba8..b2e40efb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,6 +193,8 @@ select = [ ] ignore = [ + # Conflicts with ruff formatter + "ISC001", # Ignore line too long, because due to black, the error can only occur for strings "E501", # cache has its purpose From fc88e8415dc6412b76722435599b6ddd0953ab05 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:41:30 +0200 Subject: [PATCH 07/13] add PTH rule --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b2e40efb1..89d0d1dd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,6 +190,7 @@ select = [ "ISC", "RET", "SIM", + "PTH", ] ignore = [ @@ -199,6 +200,8 @@ ignore = [ "E501", # cache has its purpose "B019", + # built-in open() is well understood + "PTH123", ] [tool.ruff.lint.per-file-ignores] From a86931fefb1f49d0f1f471e0718e5413ae56b850 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:49:01 +0200 Subject: [PATCH 08/13] add D2 rule --- docs/examples/gallery_streaming.py | 3 +- docs/tutorials/gallery_custom_components.py | 3 +- docs/tutorials/gallery_python_api.py | 3 +- docs/tutorials/gallery_rest_api.py | 3 +- docs/tutorials/gallery_web_ui.py | 3 +- pyproject.toml | 11 +++++-- ragna/_utils.py | 3 ++ ragna/assistants/_cohere.py | 6 ++-- ragna/core/_components.py | 32 +++++++++++++++------ ragna/core/_document.py | 24 +++++++++------- ragna/core/_metadata_filter.py | 8 ++---- ragna/core/_rag.py | 18 ++++++++---- ragna/deploy/_auth.py | 8 ++---- ragna/deploy/_key_value_store.py | 4 +-- ragna/deploy/_ui/styles.py | 4 +-- tests/test_dependencies.py | 8 ++---- 16 files changed, 78 insertions(+), 63 deletions(-) diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 2ba90e76b..e8447f5ec 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -1,5 +1,4 @@ -""" -# Streaming messages +"""# Streaming messages Ragna supports streaming responses from the assistant. This example showcases how this is performed using the Python and REST API. diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 82476b3a7..2143727ec 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -1,5 +1,4 @@ -""" -# Custom Components +"""# Custom Components While Ragna has builtin support for a few [source storages][ragna.source_storages] and [assistants][ragna.assistants], its real strength lies in allowing users diff --git a/docs/tutorials/gallery_python_api.py b/docs/tutorials/gallery_python_api.py index 3b667c94e..eece8f784 100644 --- a/docs/tutorials/gallery_python_api.py +++ b/docs/tutorials/gallery_python_api.py @@ -1,5 +1,4 @@ -""" -# Python API +"""# Python API The [Python API](../../references/python-api.md) is the best place to get started with Ragna and understand its key components. It's also the best way to continue diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index 6df61c55f..e69adf885 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -1,5 +1,4 @@ -""" -# REST API +"""# REST API Ragna was designed to help you quickly build custom RAG powered web applications. For this you can leverage the built-in diff --git a/docs/tutorials/gallery_web_ui.py b/docs/tutorials/gallery_web_ui.py index f51ce7e70..018b85be8 100644 --- a/docs/tutorials/gallery_web_ui.py +++ b/docs/tutorials/gallery_web_ui.py @@ -1,5 +1,4 @@ -""" -# Web UI +"""# Web UI Ragna comes with a web UI where you can try out all of the features! diff --git a/pyproject.toml b/pyproject.toml index 89d0d1dd1..017bca8de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,6 +191,7 @@ select = [ "RET", "SIM", "PTH", + "D2", ] ignore = [ @@ -202,14 +203,18 @@ ignore = [ "B019", # built-in open() is well understood "PTH123", + # mutually-exclusive with D211 + "D203", + # mutually-exclusive with D212 + "D213", ] [tool.ruff.lint.per-file-ignores] # ignore unused imports and imports not at the top of the file in __init__.py files "__init__.py" = ["F401", "E402"] -# The examples often have imports below the top of the file to follow the narrative -"docs/examples/**/*.py" = ["E402", "F704", "I001"] -"docs/tutorials/**/*.py" = ["E402", "F704", "I001"] +# The examples and tutorials need to have a good the narrative rather than follow our code-style rules +"docs/examples/**/*.py" = ["E402", "F704", "I001", "D"] +"docs/tutorials/**/*.py" = ["E402", "F704", "I001", "D"] [tool.pytest.ini_options] minversion = "6.0" diff --git a/ragna/_utils.py b/ragna/_utils.py index 952b4d20e..e5b28059b 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -45,10 +45,13 @@ def local_root(path: Optional[Union[str, Path]] = None) -> Path: `~/.cache/ragna`. Args: + ---- path: If passed, this is set as new local root directory. Returns: + ------- Ragna's local root directory. + """ global _LOCAL_ROOT if path is not None: diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index e32acb824..3bb561f0b 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -60,8 +60,7 @@ async def answer( class Command(CohereAssistant): - """ - [Cohere Command](https://docs.cohere.com/docs/models#command) + """[Cohere Command](https://docs.cohere.com/docs/models#command) !!! info "Required environment variables" @@ -72,8 +71,7 @@ class Command(CohereAssistant): class CommandLight(CohereAssistant): - """ - [Cohere Command-Light](https://docs.cohere.com/docs/models#command) + """[Cohere Command-Light](https://docs.cohere.com/docs/models#command) !!! info "Required environment variables" diff --git a/ragna/core/_components.py b/ragna/core/_components.py index e1238e9a2..9d56b5711 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -37,10 +37,7 @@ class Component(RequirementsMixin): @classmethod def display_name(cls) -> str: - """ - Returns: - Component name. - """ + """Returns Component name.""" return cls.__name__ def __repr__(self) -> str: @@ -110,12 +107,14 @@ def _protocol_model(cls) -> Type[pydantic.BaseModel]: class Source(pydantic.BaseModel): """Data class for sources stored inside a source storage. - Attributes: + Attributes + ---------- id: Unique ID of the source. document: Document this source belongs to. location: Location of the source inside the document. content: Content of the source. num_tokens: Number of tokens of the content. + """ id: str @@ -137,8 +136,10 @@ def store(self, corpus_name: str, documents: list[Document]) -> None: """Store content of documents. Args: + ---- corpus_name: Name of the corpus to store the documents in. documents: Documents to store. + """ ... @@ -149,20 +150,25 @@ def retrieve( """Retrieve sources for a given prompt. Args: + ---- corpus_name: Name of the corpus to retrieve sources from. metadata_filter: Filter to select available sources. prompt: Prompt to retrieve sources for. Returns: + ------- Matching sources for the given prompt ordered by relevance. + """ ... def list_corpuses(self) -> list[str]: """List available corpuses. - Returns: + Returns + ------- List of available corpuses. + """ raise RagnaException( "list_corpuses is not implemented", @@ -177,10 +183,13 @@ def list_metadata( """List available metadata for corpuses. Args: + ---- corpus_name: Only return metadata for this corpus. Returns: + ------- List of available metadata. + """ raise RagnaException( "list_metadata is not implemented", @@ -193,12 +202,14 @@ def list_metadata( class MessageRole(str, enum.Enum): """Message role - Attributes: + Attributes + ---------- SYSTEM: The message was produced by the system. This includes the welcome message when [preparing a new chat][ragna.core.Chat.prepare] as well as error messages. USER: The message was produced by the user. ASSISTANT: The message was produced by an assistant. + """ SYSTEM = "system" @@ -209,7 +220,8 @@ class MessageRole(str, enum.Enum): class Message: """Data class for messages. - Attributes: + Attributes + ---------- role: The message producer. sources: The sources used to produce the message. @@ -217,6 +229,7 @@ class Message: - [ragna.core.Chat.prepare][] - [ragna.core.Chat.answer][] + """ def __init__( @@ -295,11 +308,14 @@ def answer(self, messages: list[Message]) -> Iterator[str]: """Answer a prompt given the chat history. Args: + ---- messages: List of messages in the chat history. The last item is the current user prompt and has the relevant sources attached to it. Returns: + ------- Answer. + """ ... diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 27d6a8fdc..fc58c25ca 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -38,11 +38,7 @@ def __init__( @staticmethod def supported_suffixes() -> set[str]: - """ - Returns: - Suffixes, i.e. `".txt"`, that can be handled by the builtin - [ragna.core.DocumentHandler][]s. - """ + """Returns Suffixes, i.e. `".txt"`, that can be handled by the builtin [ragna.core.DocumentHandler][]s.""" return set(DOCUMENT_HANDLERS.keys()) @staticmethod @@ -50,7 +46,9 @@ def get_handler(name: str) -> DocumentHandler: """Get a document handler based on a suffix. Args: + ---- name: Name of the document. + """ handler = DOCUMENT_HANDLERS.get(Path(name).suffix) if handler is None: @@ -102,6 +100,7 @@ def from_path( """Create a [ragna.core.LocalDocument][] from a path. Args: + ---- path: Local path to the file. id: ID of the document. If omitted, one is generated. name: Name of the document. If omitted, defaults to the name of the `path`. @@ -110,7 +109,9 @@ def from_path( on the suffix of the `path`. Raises: + ------ RagnaException: If `metadata` is passed and contains a `"path"` key. + """ if metadata is None: metadata = {} @@ -156,9 +157,11 @@ def read(self) -> bytes: class Page(BaseModel): """Dataclass for pages of a document - Attributes: + Attributes + ---------- text: Text included in the page. number: Page number. + """ text: str @@ -171,10 +174,7 @@ class DocumentHandler(RequirementsMixin, abc.ABC): @classmethod @abc.abstractmethod def supported_suffixes(cls) -> list[str]: - """ - Returns: - Suffixes supported by this document handler. - """ + """Returns Suffixes supported by this document handler.""" pass @abc.abstractmethod @@ -182,10 +182,13 @@ def extract_pages(self, document: Document) -> Iterator[Page]: """Extract pages from a document. Args: + ---- document: Document to extract pages from. Returns: + ------- Extracted pages. + """ ... @@ -209,6 +212,7 @@ def load_if_available(self, cls: Type[T]) -> Type[T]: @DOCUMENT_HANDLERS.load_if_available class PlainTextDocumentHandler(DocumentHandler): """Document handler for plain-text documents. + Currently supports `.txt` and `.md` extensions. """ diff --git a/ragna/core/_metadata_filter.py b/ragna/core/_metadata_filter.py index 3f894801a..b796b8174 100644 --- a/ragna/core/_metadata_filter.py +++ b/ragna/core/_metadata_filter.py @@ -9,9 +9,7 @@ class MetadataOperator(enum.Enum): - """ - ADDME - """ + """ADDME""" RAW = enum.auto() AND = enum.auto() @@ -27,9 +25,7 @@ class MetadataOperator(enum.Enum): class MetadataFilter: - """ - ADDME - """ + """ADDME""" def __init__(self, operator: MetadataOperator, key: str, value: Any) -> None: self.operator = operator diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index e9a57cdd8..193c62c29 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -154,6 +154,7 @@ def chat( """Create a new [ragna.core.Chat][]. Args: + ---- input: Subject of the chat. Available options: - `None`: Use the full corpus of documents specified by `corpus_name`. @@ -166,6 +167,7 @@ def chat( assistant: Assistant to use. corpus_name: Corpus of documents to use. **params: Additional parameters passed to the source storage and assistant. + """ return Chat( self, @@ -191,8 +193,7 @@ class SpecialChatParams(pydantic.BaseModel): class Chat: - """ - !!! tip + """!!! tip This object is usually not instantiated manually, but rather through [ragna.core.Rag.chat][]. @@ -215,6 +216,7 @@ class Chat: ``` Args: + ---- rag: The RAG workflow this chat is associated with. input: Subject of the chat. Available options: @@ -228,6 +230,7 @@ class Chat: assistant: Assistant to use. corpus_name: Corpus of documents to use. **params: Additional parameters passed to the source storage and assistant. + """ def __init__( @@ -265,8 +268,10 @@ async def prepare(self) -> Message: This [`store`][ragna.core.SourceStorage.store]s the documents in the selected source storage. Afterwards prompts can be [`answer`][ragna.core.Chat.answer]ed. - Returns: + Returns + ------- Welcome message. + """ welcome = Message( content="How can I help you with the documents?", @@ -287,12 +292,15 @@ async def prepare(self) -> Message: async def answer(self, prompt: str, *, stream: bool = False) -> Message: """Answer a prompt. - Returns: + Returns + ------- Answer. - Raises: + Raises + ------ ragna.core.RagnaException: If chat is not [`prepare`][ragna.core.Chat.prepare]d. + """ if not self._prepared: raise RagnaException( diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 6009fb83d..4d103fa95 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -223,9 +223,7 @@ async def _get_user(session: SessionDependency) -> schemas.User: class Auth(abc.ABC): - """ - ADDME - """ + """ADDME""" @classmethod def _add_to_app( @@ -285,9 +283,7 @@ def login_page(self, request: Request) -> Response: class NoAuth(_AutomaticLoginAuthBase): - """ - ADDME - """ + """ADDME""" def login(self, request: Request) -> schemas.User: return schemas.User( diff --git a/ragna/deploy/_key_value_store.py b/ragna/deploy/_key_value_store.py index c6be1597c..17083bf89 100644 --- a/ragna/deploy/_key_value_store.py +++ b/ragna/deploy/_key_value_store.py @@ -26,9 +26,7 @@ def to_model(self) -> M: class KeyValueStore(abc.ABC, RequirementsMixin, Generic[M]): - """ - ADDME - """ + """ADDME""" def serialize(self, model: M) -> str: return SerializableModel.from_model(model).model_dump_json() diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index 10ddf3097..a22bbc8c6 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -1,6 +1,4 @@ -""" -UI Helpers -""" +"""UI Helpers""" from typing import Any, Dict, Iterable, Union diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 4fe56a0ff..6e4e81439 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,8 +1,6 @@ -""" -The tests in this module are only here as a reminder to clean up our code if an issue is -fixed upstream. If you see a test failing here, i.e. an unexpected success, feel free to -remove the offending test after you have cleaned up our code. -""" +# The tests in this module are only here as a reminder to clean up our code if an issue is +# fixed upstream. If you see a test failing here, i.e. an unexpected success, feel free to +# remove the offending test after you have cleaned up our code. from importlib.metadata import packages_distributions From 3377d862542311d420e6a5486d904a9168d935c0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:50:25 +0200 Subject: [PATCH 09/13] add UP rule --- docs/tutorials/gallery_custom_components.py | 4 +- pyproject.toml | 1 + ragna/_cli/config.py | 15 ++--- ragna/_cli/core.py | 6 +- ragna/_cli/corpus.py | 6 +- ragna/_docs.py | 4 +- ragna/_utils.py | 17 ++---- ragna/assistants/_ai21labs.py | 3 +- ragna/assistants/_anthropic.py | 3 +- ragna/assistants/_cohere.py | 3 +- ragna/assistants/_demo.py | 2 +- ragna/assistants/_google.py | 2 +- ragna/assistants/_http_api.py | 15 ++--- ragna/assistants/_ollama.py | 3 +- ragna/assistants/_openai.py | 8 ++- ragna/core/_components.py | 21 +++---- ragna/core/_document.py | 25 ++++---- ragna/core/_metadata_filter.py | 7 ++- ragna/core/_rag.py | 68 +++++++++------------ ragna/core/_utils.py | 13 ++-- ragna/deploy/_api.py | 3 +- ragna/deploy/_auth.py | 17 +++--- ragna/deploy/_config.py | 8 +-- ragna/deploy/_core.py | 5 +- ragna/deploy/_database.py | 15 +++-- ragna/deploy/_engine.py | 7 ++- ragna/deploy/_key_value_store.py | 27 ++++---- ragna/deploy/_orm.py | 14 ++--- ragna/deploy/_ui/central_view.py | 9 +-- ragna/deploy/_ui/modal_configuration.py | 2 +- ragna/deploy/_ui/styles.py | 7 ++- ragna/deploy/_utils.py | 3 +- ragna/source_storages/_chroma.py | 10 +-- ragna/source_storages/_demo.py | 7 ++- ragna/source_storages/_lancedb.py | 6 +- ragna/source_storages/_qdrant.py | 7 ++- ragna/source_storages/_vector_database.py | 9 +-- 37 files changed, 189 insertions(+), 193 deletions(-) diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 2143727ec..8be576777 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -83,7 +83,7 @@ def retrieve( # [streaming example](../../generated/examples/gallery_streaming.md) for more # information. -from typing import Iterator +from collections.abc import Iterator from ragna.core import Assistant, Source @@ -401,7 +401,7 @@ def answer( import asyncio import time -from typing import AsyncIterator +from collections.abc import AsyncIterator class AsyncAssistant(Assistant): diff --git a/pyproject.toml b/pyproject.toml index 017bca8de..9666966f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -192,6 +192,7 @@ select = [ "SIM", "PTH", "D2", + "UP", ] ignore = [ diff --git a/ragna/_cli/config.py b/ragna/_cli/config.py index 9e469a860..27daef4dc 100644 --- a/ragna/_cli/config.py +++ b/ragna/_cli/config.py @@ -1,8 +1,9 @@ import itertools from collections import defaultdict +from collections.abc import Iterable from pathlib import Path from types import ModuleType -from typing import Annotated, Iterable, Type, TypeVar, cast +from typing import Annotated, TypeVar, cast import pydantic import questionary @@ -157,8 +158,8 @@ def _wizard_builtin() -> Config: def _select_components( title: str, module: ModuleType, - base_cls: Type[T], -) -> list[Type[T]]: + base_cls: type[T], +) -> list[type[T]]: components = sorted( ( obj @@ -170,7 +171,7 @@ def _select_components( key=lambda component: component.display_name(), ) return cast( - list[Type[T]], + list[type[T]], questionary.checkbox( f"Which {title} do you want to use?", choices=[ @@ -186,7 +187,7 @@ def _select_components( ) -def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None: +def _handle_unmet_requirements(components: Iterable[type[Component]]) -> None: unmet_requirements = { requirement for component in components @@ -331,7 +332,7 @@ def check_config(config: Config) -> bool: ("source storages", config.source_storages), ("assistants", config.assistants), ]: - components = cast(list[Type[Component]], components) + components = cast(list[type[Component]], components) table = Table( "", @@ -361,7 +362,7 @@ def check_config(config: Config) -> bool: def _split_requirements( requirements: Iterable[Requirement], -) -> dict[Type[Requirement], list[Requirement]]: +) -> dict[type[Requirement], list[Requirement]]: split_reqs = defaultdict(list) for req in requirements: split_reqs[type(req)].append(req) diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py index 2190045eb..d07632128 100644 --- a/ragna/_cli/core.py +++ b/ragna/_cli/core.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated import rich import typer @@ -30,7 +30,7 @@ def version_callback(value: bool) -> None: @app.callback() def _main( version: Annotated[ - Optional[bool], + bool | None, typer.Option( "--version", callback=version_callback, help="Show version and exit." ), @@ -97,7 +97,7 @@ def deploy( ), ] = False, open_browser: Annotated[ - Optional[bool], + bool | None, typer.Option( help="Open a browser when Ragna is deployed.", show_default="value of ui / no-ui", diff --git a/ragna/_cli/corpus.py b/ragna/_cli/corpus.py index 7137d97ce..7ebee9ae2 100644 --- a/ragna/_cli/corpus.py +++ b/ragna/_cli/corpus.py @@ -1,7 +1,7 @@ import json import sys from pathlib import Path -from typing import Annotated, Optional +from typing import Annotated import rich import typer @@ -45,7 +45,7 @@ def experimental_warning() -> None: def ingest( documents: list[Path], metadata_fields: Annotated[ - Optional[Path], + Path | None, typer.Option( help="JSON file that contains mappings from document name " "to metadata fields associated with a document.", @@ -57,7 +57,7 @@ def ingest( ] = "default", config: ConfigOption = "./ragna.toml", # type: ignore[assignment] user: Annotated[ - Optional[str], + str | None, typer.Option(help="User to link the documents to in the ragna database."), ] = None, report_failures: Annotated[ diff --git a/ragna/_docs.py b/ragna/_docs.py index a9a3977cf..adec905aa 100644 --- a/ragna/_docs.py +++ b/ragna/_docs.py @@ -7,7 +7,7 @@ import tempfile import textwrap from pathlib import Path -from typing import Any, Optional, cast +from typing import Any, cast import httpx @@ -114,7 +114,7 @@ def get_http_client( *, authenticate: bool = False, upload_sample_document: bool = False, - ) -> tuple[httpx.Client, Optional[dict[str, Any]]]: + ) -> tuple[httpx.Client, dict[str, Any] | None]: if upload_sample_document and not authenticate: raise RagnaException( "Cannot upload a document without authenticating first. " diff --git a/ragna/_utils.py b/ragna/_utils.py index e5b28059b..435e106d2 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -10,16 +10,11 @@ import sys import threading import time +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator from pathlib import Path from typing import ( Any, - AsyncIterator, - Awaitable, - Callable, - Iterator, - Optional, TypeVar, - Union, cast, ) @@ -32,13 +27,13 @@ ) -def make_directory(path: Union[str, Path]) -> Path: +def make_directory(path: str | Path) -> Path: path = Path(path).expanduser().resolve() path.mkdir(parents=True, exist_ok=True) return path -def local_root(path: Optional[Union[str, Path]] = None) -> Path: +def local_root(path: str | Path | None = None) -> Path: """Get or set the local root directory Ragna uses for storing files. Defaults to the value of the `RAGNA_LOCAL_ROOT` environment variable or otherwise to @@ -137,7 +132,7 @@ def is_debugging() -> bool: def as_awaitable( - fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any, **kwargs: Any + fn: Callable[..., T] | Callable[..., Awaitable[T]], *args: Any, **kwargs: Any ) -> Awaitable[T]: if inspect.iscoroutinefunction(fn): fn = cast(Callable[..., Awaitable[T]], fn) @@ -150,7 +145,7 @@ def as_awaitable( def as_async_iterator( - fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], + fn: Callable[..., Iterator[T]] | Callable[..., AsyncIterator[T]], *args: Any, **kwargs: Any, ) -> AsyncIterator[T]: @@ -178,7 +173,7 @@ def __init__( *cmd: str, stdout: Any = sys.stdout, stderr: Any = sys.stdout, - startup_fn: Optional[Callable[[], bool]] = None, + startup_fn: Callable[[], bool] | None = None, startup_timeout: float = 10, terminate_timeout: float = 10, text: bool = True, diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 79230dde5..8b41a13dc 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -1,4 +1,5 @@ -from typing import AsyncIterator, cast +from collections.abc import AsyncIterator +from typing import cast from ragna.core import Message, Source diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index cede0e30a..20a6d9e47 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,4 +1,5 @@ -from typing import AsyncIterator, cast +from collections.abc import AsyncIterator +from typing import cast from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index 3bb561f0b..4671d7245 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,4 +1,5 @@ -from typing import AsyncIterator, cast +from collections.abc import AsyncIterator +from typing import cast from ragna.core import Message, RagnaException, Source diff --git a/ragna/assistants/_demo.py b/ragna/assistants/_demo.py index aa1301f48..99e55a779 100644 --- a/ragna/assistants/_demo.py +++ b/ragna/assistants/_demo.py @@ -1,5 +1,5 @@ import textwrap -from typing import Iterator +from collections.abc import Iterator from ragna.core import Assistant, Message, MessageRole diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index c263ca565..863076132 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator +from collections.abc import AsyncIterator from ragna.core import Message, Source diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py index adc794b89..363a06b48 100644 --- a/ragna/assistants/_http_api.py +++ b/ragna/assistants/_http_api.py @@ -2,7 +2,8 @@ import enum import json import os -from typing import Any, AsyncContextManager, AsyncIterator, Optional +from collections.abc import AsyncIterator +from typing import Any import httpx @@ -34,7 +35,7 @@ def requirements(cls, protocol: HttpStreamingProtocol) -> list[Requirement]: def __init__( self, client: httpx.AsyncClient, - protocol: Optional[HttpStreamingProtocol] = None, + protocol: HttpStreamingProtocol | None = None, ) -> None: self._client = client self._protocol = protocol @@ -44,9 +45,9 @@ def __call__( method: str, url: str, *, - parse_kwargs: Optional[dict[str, Any]] = None, + parse_kwargs: dict[str, Any] | None = None, **kwargs: Any, - ) -> AsyncContextManager[AsyncIterator[Any]]: + ) -> contextlib.AbstractAsyncContextManager[AsyncIterator[Any]]: if self._protocol is None: call_method = self._no_stream else: @@ -176,8 +177,8 @@ async def _assert_api_call_is_success(self, response: httpx.Response) -> None: class HttpApiAssistant(Assistant): - _API_KEY_ENV_VAR: Optional[str] - _STREAMING_PROTOCOL: Optional[HttpStreamingProtocol] + _API_KEY_ENV_VAR: str | None + _STREAMING_PROTOCOL: HttpStreamingProtocol | None @classmethod def requirements(cls) -> list[Requirement]: @@ -200,7 +201,7 @@ def __init__(self) -> None: headers={"User-Agent": f"{ragna.__version__}/{self}"}, timeout=60, ) - self._api_key: Optional[str] = ( + self._api_key: str | None = ( os.environ[self._API_KEY_ENV_VAR] if self._API_KEY_ENV_VAR is not None else None diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py index 591c7ed14..3015ae2c6 100644 --- a/ragna/assistants/_ollama.py +++ b/ragna/assistants/_ollama.py @@ -1,6 +1,7 @@ import os +from collections.abc import AsyncIterator from functools import cached_property -from typing import AsyncIterator, cast +from typing import cast from ragna.core import Message, RagnaException diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 1867a26be..547362e44 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,6 +1,8 @@ import abc +import contextlib +from collections.abc import AsyncIterator from functools import cached_property -from typing import Any, AsyncContextManager, AsyncIterator, Optional, cast +from typing import Any, cast from ragna.core import Message, Source @@ -8,7 +10,7 @@ class OpenaiLikeHttpApiAssistant(HttpApiAssistant): - _MODEL: Optional[str] + _MODEL: str | None @property @abc.abstractmethod @@ -25,7 +27,7 @@ def _make_system_content(self, sources: list[Source]) -> str: def _call_openai_api( self, prompt: str, sources: list[Source], *, max_new_tokens: int - ) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]: + ) -> contextlib.AbstractAsyncContextManager[AsyncIterator[dict[str, Any]]]: # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming headers = { diff --git a/ragna/core/_components.py b/ragna/core/_components.py index 9d56b5711..b7a02ed25 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -5,15 +5,10 @@ import functools import inspect import uuid +from collections.abc import AsyncIterable, AsyncIterator, Iterator from datetime import datetime, timezone from typing import ( Any, - AsyncIterable, - AsyncIterator, - Iterator, - Optional, - Type, - Union, get_type_hints, ) @@ -51,7 +46,7 @@ def __repr__(self) -> str: @functools.cache def _protocol_models( cls, - ) -> dict[tuple[Type[Component], str], Type[pydantic.BaseModel]]: + ) -> dict[tuple[type[Component], str], type[pydantic.BaseModel]]: # This method dynamically builds a pydantic.BaseModel for the extra parameters # of each method that is listed in the __ragna_protocol_methods__ class # variable. These models are used by ragna.core.Chat._unpack_chat_params to @@ -100,7 +95,7 @@ def _protocol_models( @classmethod @functools.cache - def _protocol_model(cls) -> Type[pydantic.BaseModel]: + def _protocol_model(cls) -> type[pydantic.BaseModel]: return merge_models(cls.display_name(), *cls._protocol_models().values()) @@ -178,7 +173,7 @@ def list_corpuses(self) -> list[str]: ) def list_metadata( - self, corpus_name: Optional[str] = None + self, corpus_name: str | None = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: """List available metadata for corpuses. @@ -234,12 +229,12 @@ class Message: def __init__( self, - content: Union[str, AsyncIterable[str]], + content: str | AsyncIterable[str], *, role: MessageRole = MessageRole.SYSTEM, - sources: Optional[list[Source]] = None, - id: Optional[uuid.UUID] = None, - timestamp: Optional[datetime] = None, + sources: list[Source] | None = None, + id: uuid.UUID | None = None, + timestamp: datetime | None = None, ) -> None: if isinstance(content, str): self._content: str = content diff --git a/ragna/core/_document.py b/ragna/core/_document.py index fc58c25ca..0d51f7597 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -4,9 +4,10 @@ import io import mimetypes import uuid +from collections.abc import AsyncIterator, Iterator from functools import cached_property from pathlib import Path -from typing import Any, AsyncIterator, Iterator, Optional, Type, TypeVar, Union +from typing import Any, TypeVar import aiofiles from pydantic import BaseModel @@ -22,10 +23,10 @@ class Document(RequirementsMixin, abc.ABC): def __init__( self, *, - id: Optional[uuid.UUID] = None, + id: uuid.UUID | None = None, name: str, metadata: dict[str, Any], - handler: Optional[DocumentHandler] = None, + handler: DocumentHandler | None = None, mime_type: str | None = None, ): self.id = id or uuid.uuid4() @@ -75,10 +76,10 @@ class LocalDocument(Document): def __init__( self, *, - id: Optional[uuid.UUID] = None, + id: uuid.UUID | None = None, name: str, metadata: dict[str, Any], - handler: Optional[DocumentHandler] = None, + handler: DocumentHandler | None = None, mime_type: str | None = None, ): super().__init__( @@ -90,12 +91,12 @@ def __init__( @classmethod def from_path( cls, - path: Union[str, Path], + path: str | Path, *, - id: Optional[uuid.UUID] = None, - name: Optional[str] = None, - metadata: Optional[dict[str, Any]] = None, - handler: Optional[DocumentHandler] = None, + id: uuid.UUID | None = None, + name: str | None = None, + metadata: dict[str, Any] | None = None, + handler: DocumentHandler | None = None, ) -> LocalDocument: """Create a [ragna.core.LocalDocument][] from a path. @@ -165,7 +166,7 @@ class Page(BaseModel): """ text: str - number: Optional[int] = None + number: int | None = None class DocumentHandler(RequirementsMixin, abc.ABC): @@ -197,7 +198,7 @@ def extract_pages(self, document: Document) -> Iterator[Page]: class DocumentHandlerRegistry(dict[str, DocumentHandler]): - def load_if_available(self, cls: Type[T]) -> Type[T]: + def load_if_available(self, cls: type[T]) -> type[T]: if cls.is_available(): instance = cls() for suffix in cls.supported_suffixes(): diff --git a/ragna/core/_metadata_filter.py b/ragna/core/_metadata_filter.py index b796b8174..84985bffd 100644 --- a/ragna/core/_metadata_filter.py +++ b/ragna/core/_metadata_filter.py @@ -2,7 +2,8 @@ import enum import textwrap -from typing import Any, Literal, Sequence, Union, cast +from collections.abc import Sequence +from typing import Any, Literal, cast import pydantic import pydantic_core @@ -97,13 +98,13 @@ def from_primitive(cls, obj: dict[str, Any]) -> MetadataFilter: def __get_pydantic_core_schema__( cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler ) -> pydantic_core.CoreSchema: - def validate(value: Union[MetadataFilter, dict[str, Any]]) -> MetadataFilter: + def validate(value: MetadataFilter | dict[str, Any]) -> MetadataFilter: if isinstance(value, MetadataFilter): return value return cls.from_primitive(value) - def serialize(value: Union[MetadataFilter, dict[str, Any]]) -> dict[str, Any]: + def serialize(value: MetadataFilter | dict[str, Any]) -> dict[str, Any]: if isinstance(value, MetadataFilter): return value.to_primitive() diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 193c62c29..1c15c6b9b 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -5,20 +5,14 @@ import itertools import uuid from collections import defaultdict +from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterator from datetime import datetime, timezone from pathlib import Path from typing import ( TYPE_CHECKING, Any, - AsyncIterator, - Awaitable, - Callable, - Collection, Generic, - Iterator, - Optional, TypeVar, - Union, cast, ) @@ -55,7 +49,7 @@ class Rag(Generic[C]): def __init__( self, *, - config: Optional[Config] = None, + config: Config | None = None, ignore_unavailable_components: bool = False, ) -> None: self._components: dict[type[C], C] = {} @@ -92,10 +86,10 @@ def _preload_components( ) def _load_component( - self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False - ) -> Optional[C]: + self, component: C | type[C] | str, *, ignore_unavailable: bool = False + ) -> C | None: cls: type[C] - instance: Optional[C] + instance: C | None if isinstance(component, Component): instance = cast(C, component) @@ -137,17 +131,15 @@ def _load_component( def chat( self, - input: Union[ - None, - MetadataFilter, - Document, - str, - Path, - Collection[Union[Document, str, Path]], - ] = None, + input: None + | MetadataFilter + | Document + | str + | Path + | Collection[Document | str | Path] = None, *, - source_storage: Union[SourceStorage, type[SourceStorage], str], - assistant: Union[Assistant, type[Assistant], str], + source_storage: SourceStorage | type[SourceStorage] | str, + assistant: Assistant | type[Assistant] | str, corpus_name: str = "default", **params: Any, ) -> Chat: @@ -236,14 +228,12 @@ class Chat: def __init__( self, rag: Rag, - input: Union[ - None, - MetadataFilter, - Document, - str, - Path, - Collection[Union[Document, str, Path]], - ] = None, + input: None + | MetadataFilter + | Document + | str + | Path + | Collection[Document | str | Path] = None, *, source_storage: SourceStorage, assistant: Assistant, @@ -350,15 +340,13 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: def _parse_input( self, - input: Union[ - MetadataFilter, - None, - Document, - str, - Path, - Collection[Union[Document, str, Path]], - ], - ) -> tuple[Optional[list[Document]], Optional[MetadataFilter], bool]: + input: MetadataFilter + | None + | Document + | str + | Path + | Collection[Document | str | Path], + ) -> tuple[list[Document] | None, MetadataFilter | None, bool]: if isinstance(input, MetadataFilter) or input is None: return None, input, True @@ -500,13 +488,13 @@ def format_error( raise RagnaException("\n".join(parts)) from None def _as_awaitable( - self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any + self, fn: Callable[..., T] | Callable[..., Awaitable[T]], *args: Any ) -> Awaitable[T]: return as_awaitable(fn, *args, **self._unpacked_params[fn]) def _as_async_iterator( self, - fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]], + fn: Callable[..., Iterator[T]] | Callable[..., AsyncIterator[T]], *args: Any, ) -> AsyncIterator[T]: return as_async_iterator(fn, *args, **self._unpacked_params[fn]) diff --git a/ragna/core/_utils.py b/ragna/core/_utils.py index 6360ec76f..750143533 100644 --- a/ragna/core/_utils.py +++ b/ragna/core/_utils.py @@ -7,7 +7,8 @@ import importlib.metadata import os from collections import defaultdict -from typing import Any, Collection, Optional, Type, Union, cast +from collections.abc import Collection +from typing import Any, cast import packaging.requirements import pydantic @@ -34,7 +35,7 @@ def __init__( # FIXME: remove default value for event event: str = "", http_status_code: int = 500, - http_detail: Union[str, RagnaExceptionHttpDetail] = "", + http_detail: str | RagnaExceptionHttpDetail = "", **extra: Any, ) -> None: self.event = event @@ -121,9 +122,9 @@ def __repr__(self) -> str: def merge_models( model_name: str, - *models: Type[pydantic.BaseModel], - config: Optional[pydantic.ConfigDict] = None, -) -> Type[pydantic.BaseModel]: + *models: type[pydantic.BaseModel], + config: pydantic.ConfigDict | None = None, +) -> type[pydantic.BaseModel]: raw_field_definitions = defaultdict(list) for model_cls in models: for name, field in model_cls.__pydantic_fields__.items(): @@ -160,7 +161,7 @@ def merge_models( field_definitions[name] = (type_, pydantic.Field(**kwargs)) return cast( - Type[pydantic.BaseModel], + type[pydantic.BaseModel], pydantic.create_model( # type: ignore[call-overload] model_name, **field_definitions, __config__=config ), diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index ae4f2004e..2bd5e3c8f 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -1,6 +1,7 @@ import io import uuid -from typing import Annotated, Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Annotated, Any import pydantic from fastapi import APIRouter, Body, UploadFile diff --git a/ragna/deploy/_auth.py b/ragna/deploy/_auth.py index 4d103fa95..3a2b4fe42 100644 --- a/ragna/deploy/_auth.py +++ b/ragna/deploy/_auth.py @@ -7,7 +7,8 @@ import os import re import uuid -from typing import TYPE_CHECKING, Annotated, Awaitable, Callable, Optional, Union, cast +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Annotated, cast import httpx import panel as pn @@ -124,7 +125,7 @@ async def _cookie_dispatch( # just for panel, we just inject them into the scope here, which will be # parsed by panel down the line. After this initial request, the values are # tied to the active session and don't have to be set again. - extra_cookies: dict[str, Union[str, bytes]] = { + extra_cookies: dict[str, str | bytes] = { "user": session.user.name, "id_token": base64.b64encode(json.dumps(session.user.data).encode()), } @@ -202,7 +203,7 @@ def _delete_cookie(self, response: Response) -> None: async def _get_session(request: Request) -> Session: - session = cast(Optional[Session], request.state.session) + session = cast(Session | None, request.state.session) if session is None: raise RagnaException( "Not authenticated", @@ -268,7 +269,7 @@ async def logout() -> RedirectResponse: def login_page(self, request: Request) -> Response: ... @abc.abstractmethod - def login(self, request: Request) -> Union[schemas.User, Response]: ... + def login(self, request: Request) -> schemas.User | Response: ... class _AutomaticLoginAuthBase(Auth): @@ -307,8 +308,8 @@ def login_page( self, request: Request, *, - username: Optional[str] = None, - fail_reason: Optional[str] = None, + username: str | None = None, + fail_reason: str | None = None, ) -> HTMLResponse: return HTMLResponse( templates.render( @@ -316,7 +317,7 @@ def login_page( ) ) - async def login(self, request: Request) -> Union[schemas.User, Response]: + async def login(self, request: Request) -> schemas.User | Response: async with request.form() as form: username = cast(str, form.get("username")) password = cast(str, form.get("password")) @@ -359,7 +360,7 @@ def login_page(self, request: Request) -> HTMLResponse: ) ) - async def login(self, request: Request) -> Union[schemas.User, Response]: + async def login(self, request: Request) -> schemas.User | Response: async with httpx.AsyncClient(headers={"Accept": "application/json"}) as client: response = await client.post( "https://github.com/login/oauth/access_token", diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index 6cd42ab9e..51cde6894 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -2,7 +2,7 @@ import itertools from pathlib import Path -from typing import Annotated, ClassVar, Type, Union, cast +from typing import Annotated, ClassVar, cast import tomlkit import tomlkit.container @@ -33,7 +33,7 @@ class Config(BaseSettings): @classmethod def settings_customise_sources( cls, - settings_cls: Type[BaseSettings], + settings_cls: type[BaseSettings], init_settings: PydanticBaseSettingsSource, env_settings: PydanticBaseSettingsSource, dotenv_settings: PydanticBaseSettingsSource, @@ -104,7 +104,7 @@ def _set_multiline_array(self, item: tomlkit.items.Item) -> None: self._set_multiline_array(child) @classmethod - def from_file(cls, path: Union[str, Path]) -> Config: + def from_file(cls, path: str | Path) -> Config: path = Path(path).expanduser().resolve() if not path.is_file(): raise RagnaException(f"{path} does not exist.") @@ -113,7 +113,7 @@ def from_file(cls, path: Union[str, Path]) -> Config: type[Config], type(cls.__name__, (cls,), {"__config_path__": path}) )() - def to_file(self, path: Union[str, Path], *, force: bool = False) -> None: + def to_file(self, path: str | Path, *, force: bool = False) -> None: path = Path(path).expanduser().resolve() if path.exists() and not force: raise RagnaException(f"{path} already exists.") diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 47dcb13ab..ce2fb9c1a 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -3,8 +3,9 @@ import time import uuid import webbrowser +from collections.abc import AsyncIterator, Callable from pathlib import Path -from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast +from typing import cast import httpx import panel.io.fastapi @@ -35,7 +36,7 @@ def make_app( ) -> FastAPI: set_redirect_root_path(config.root_path) - lifespan: Optional[Callable[[FastAPI], AsyncContextManager]] + lifespan: Callable[[FastAPI], contextlib.AbstractAsyncContextManager] | None if open_browser: @contextlib.asynccontextmanager diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index c721411e7..5a6c9880c 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,7 +1,8 @@ from __future__ import annotations import uuid -from typing import Any, Collection, Optional, cast +from collections.abc import Collection +from typing import Any, cast from urllib.parse import urlsplit from sqlalchemy import create_engine, select @@ -14,9 +15,7 @@ class UnknownUser(Exception): - def __init__( - self, name: Optional[str] = None, api_key: Optional[str] = None - ) -> None: + def __init__(self, name: str | None = None, api_key: str | None = None) -> None: self.name = name self.api_key = api_key @@ -38,7 +37,7 @@ def __init__(self, url: str) -> None: def _get_orm_user_by_name(self, session: Session, *, name: str) -> orm.User: user = cast( - Optional[orm.User], + orm.User | None, session.execute( select(orm.User).where(orm.User.name == name) ).scalar_one_or_none(), @@ -103,7 +102,7 @@ def delete_api_key(self, session: Session, *, user: str, id: uuid.UUID) -> None: def get_user_by_api_key( self, session: Session, api_key_value: str - ) -> Optional[tuple[schemas.User, schemas.ApiKey]]: + ) -> tuple[schemas.User, schemas.ApiKey] | None: orm_api_key = session.execute( select(orm.ApiKey) # type: ignore[attr-defined] .options(joinedload(orm.ApiKey.user)) @@ -232,7 +231,7 @@ def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: def _get_orm_chat( self, session: Session, *, user: str, id: uuid.UUID, eager: bool = False ) -> orm.Chat: - chat: Optional[orm.Chat] = ( + chat: orm.Chat | None = ( session.execute( self._select_chat(eager=eager).where( (orm.Chat.id == id) @@ -251,7 +250,7 @@ def _get_orm_chat( def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: return self._to_schema.chat( - (self._get_orm_chat(session, user=user, id=id, eager=True)) + self._get_orm_chat(session, user=user, id=id, eager=True) ) def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 72c654348..8566e2136 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,6 +1,7 @@ import secrets import uuid -from typing import Any, AsyncIterator, Collection, Optional, cast +from collections.abc import AsyncIterator, Collection +from typing import Any, cast from fastapi import status as http_status_code @@ -40,7 +41,7 @@ def maybe_add_user(self, user: schemas.User) -> None: def get_user_by_api_key( self, api_key_value: str - ) -> tuple[Optional[schemas.User], bool]: + ) -> tuple[schemas.User | None, bool]: with self._database.get_session() as session: data = self._database.get_user_by_api_key( session, api_key_value=api_key_value @@ -354,7 +355,7 @@ def source(self, source: core.Source) -> schemas.Source: ) def message( - self, message: core.Message, *, content_override: Optional[str] = None + self, message: core.Message, *, content_override: str | None = None ) -> schemas.Message: return schemas.Message( id=message.id, diff --git a/ragna/deploy/_key_value_store.py b/ragna/deploy/_key_value_store.py index 17083bf89..564d7b789 100644 --- a/ragna/deploy/_key_value_store.py +++ b/ragna/deploy/_key_value_store.py @@ -3,7 +3,8 @@ import abc import os import time -from typing import Any, Callable, Generic, Optional, TypeVar, Union, cast +from collections.abc import Callable +from typing import Any, Generic, TypeVar, cast import pydantic @@ -31,37 +32,35 @@ class KeyValueStore(abc.ABC, RequirementsMixin, Generic[M]): def serialize(self, model: M) -> str: return SerializableModel.from_model(model).model_dump_json() - def deserialize(self, data: Union[str, bytes]) -> M: + def deserialize(self, data: str | bytes) -> M: return SerializableModel.model_validate_json(data).to_model() @abc.abstractmethod - def set( - self, key: str, model: M, *, expires_after: Optional[int] = None - ) -> None: ... + def set(self, key: str, model: M, *, expires_after: int | None = None) -> None: ... @abc.abstractmethod - def get(self, key: str) -> Optional[M]: ... + def get(self, key: str) -> M | None: ... @abc.abstractmethod def delete(self, key: str) -> None: ... @abc.abstractmethod - def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: ... + def refresh(self, key: str, *, expires_after: int | None = None) -> None: ... class InMemoryKeyValueStore(KeyValueStore[M]): def __init__(self) -> None: - self._store: dict[str, tuple[M, Optional[float]]] = {} + self._store: dict[str, tuple[M, float | None]] = {} self._timer: Callable[[], float] = time.monotonic - def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + def set(self, key: str, model: M, *, expires_after: int | None = None) -> None: if expires_after is not None: expires_at = self._timer() + expires_after else: expires_at = None self._store[key] = (model, expires_at) - def get(self, key: str) -> Optional[M]: + def get(self, key: str) -> M | None: value = self._store.get(key) if value is None: return None @@ -79,7 +78,7 @@ def delete(self, key: str) -> None: del self._store[key] - def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + def refresh(self, key: str, *, expires_after: int | None = None) -> None: value = self._store.get(key) if value is None: return @@ -101,10 +100,10 @@ def __init__(self) -> None: port=int(os.environ.get("RAGNA_REDIS_PORT", 6379)), ) - def set(self, key: str, model: M, *, expires_after: Optional[int] = None) -> None: + def set(self, key: str, model: M, *, expires_after: int | None = None) -> None: self._r.set(key, self.serialize(model), ex=expires_after) - def get(self, key: str) -> Optional[M]: + def get(self, key: str) -> M | None: value = cast(bytes, self._r.get(key)) if value is None: return None @@ -113,7 +112,7 @@ def get(self, key: str) -> Optional[M]: def delete(self, key: str) -> None: self._r.delete(key) - def refresh(self, key: str, *, expires_after: Optional[int] = None) -> None: + def refresh(self, key: str, *, expires_after: int | None = None) -> None: if expires_after is None: self._r.persist(key) else: diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index 543649a17..50149f8e3 100644 --- a/ragna/deploy/_orm.py +++ b/ragna/deploy/_orm.py @@ -1,6 +1,6 @@ import json from datetime import datetime, timezone -from typing import Any, Optional +from typing import Any from sqlalchemy import Column, ForeignKey, Table, types from sqlalchemy.engine import Dialect @@ -20,13 +20,13 @@ class Json(types.TypeDecorator): cache_ok = True - def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: + def process_bind_param(self, value: Any, dialect: Dialect) -> str | None: if value is None: return value return json.dumps(value) - def process_result_value(self, value: Optional[str], dialect: Dialect) -> Any: + def process_result_value(self, value: str | None, dialect: Dialect) -> Any: if value is None: return value @@ -45,16 +45,16 @@ class UtcAwareDateTime(types.TypeDecorator): cache_ok = True def process_bind_param( # type: ignore[override] - self, value: Optional[datetime], dialect: Dialect - ) -> Optional[datetime]: + self, value: datetime | None, dialect: Dialect + ) -> datetime | None: if value is not None: assert value.tzinfo == timezone.utc return value def process_result_value( - self, value: Optional[datetime], dialect: Dialect - ) -> Optional[datetime]: + self, value: datetime | None, dialect: Dialect + ) -> datetime | None: if value is None: return None diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 4c1b35701..b5f025146 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -1,7 +1,8 @@ from __future__ import annotations import functools -from typing import Callable, Literal, Optional, cast +from collections.abc import Callable +from typing import Literal, cast import panel as pn import param @@ -52,8 +53,8 @@ def __init__( *, role: Literal["system", "user", "assistant"], user: str, - sources: Optional[list[dict]] = None, - on_click_source_info_callback: Optional[Callable] = None, + sources: list[dict] | None = None, + on_click_source_info_callback: Callable | None = None, timestamp=None, show_timestamp=True, assistant_toolbar_visible=True, # hide the toolbar during streaming @@ -138,7 +139,7 @@ def _update_placeholder(self): avatar_lookup=self.avatar_lookup, ) - def _build_message(self, *args, **kwargs) -> Optional[RagnaChatMessage]: + def _build_message(self, *args, **kwargs) -> RagnaChatMessage | None: message = super()._build_message(*args, **kwargs) if message is None: return None diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index aea7d88f2..0cfbebfff 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -1,5 +1,5 @@ +from collections.abc import AsyncIterator from datetime import datetime, timedelta, timezone -from typing import AsyncIterator import panel as pn import param diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index a22bbc8c6..53f443b04 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -1,6 +1,7 @@ """UI Helpers""" -from typing import Any, Dict, Iterable, Union +from collections.abc import Iterable +from typing import Any import panel as pn @@ -70,7 +71,7 @@ def apply_design_modifiers(): def add_modifier( modifier_class: pn.viewable.Viewable, - modifications: Dict[str, Any], + modifications: dict[str, Any], property: str = "stylesheets", ): properties = pn.theme.fast.Fast.modifiers.setdefault(modifier_class, {}) @@ -87,7 +88,7 @@ def divider(): return pn.layout.Divider(css_classes=["default_divider"]) -def css(selector: Union[str, Iterable[str]], declarations: dict[str, str]) -> str: +def css(selector: str | Iterable[str], declarations: dict[str, str]) -> str: return "\n".join( [ f"{selector if isinstance(selector, str) else ', '.join(selector)} {{", diff --git a/ragna/deploy/_utils.py b/ragna/deploy/_utils.py index 4f369a523..d04cf082d 100644 --- a/ragna/deploy/_utils.py +++ b/ragna/deploy/_utils.py @@ -1,4 +1,3 @@ -from typing import Optional from urllib.parse import SplitResult, urlsplit, urlunsplit from fastapi import status @@ -6,7 +5,7 @@ from ragna.core import RagnaException -_REDIRECT_ROOT_PATH: Optional[str] = None +_REDIRECT_ROOT_PATH: str | None = None def set_redirect_root_path(root_path: str) -> None: diff --git a/ragna/source_storages/_chroma.py b/ragna/source_storages/_chroma.py index 38ca3ffa6..1ef372fd7 100644 --- a/ragna/source_storages/_chroma.py +++ b/ragna/source_storages/_chroma.py @@ -2,7 +2,7 @@ import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast import ragna from ragna.core import Document, MetadataFilter, MetadataOperator, Source @@ -75,7 +75,7 @@ def _get_collection( raise_non_existing_corpus(self, corpus_name) def list_metadata( - self, corpus_name: Optional[str] = None + self, corpus_name: str | None = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: corpus_names = self.list_corpuses() if corpus_name is None else [corpus_name] @@ -155,8 +155,8 @@ def store( } def _translate_metadata_filter( - self, metadata_filter: Optional[MetadataFilter] - ) -> Optional[dict[str, Any]]: + self, metadata_filter: MetadataFilter | None + ) -> dict[str, Any] | None: if metadata_filter is None: return None if metadata_filter.operator is MetadataOperator.RAW: @@ -183,7 +183,7 @@ def _translate_metadata_filter( def retrieve( self, corpus_name: str, - metadata_filter: Optional[MetadataFilter], + metadata_filter: MetadataFilter | None, prompt: str, *, chunk_size: int = 500, diff --git a/ragna/source_storages/_demo.py b/ragna/source_storages/_demo.py index 94843e581..bb8648298 100644 --- a/ragna/source_storages/_demo.py +++ b/ragna/source_storages/_demo.py @@ -2,7 +2,8 @@ import textwrap import uuid from collections import defaultdict -from typing import Any, Callable, Optional, cast +from collections.abc import Callable +from typing import Any, cast from ragna.core import ( Document, @@ -53,7 +54,7 @@ def _get_corpus( return corpus def list_metadata( - self, corpus_name: Optional[str] = None + self, corpus_name: str | None = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: corpus_names = self.list_corpuses() if corpus_name is None else [corpus_name] @@ -111,7 +112,7 @@ def store(self, corpus_name: str, documents: list[Document]) -> None: } def _apply_filter( - self, corpus: list[dict[str, Any]], metadata_filter: Optional[MetadataFilter] + self, corpus: list[dict[str, Any]], metadata_filter: MetadataFilter | None ) -> list[tuple[int, dict[str, Any]]]: if metadata_filter is None: return list(enumerate(corpus)) diff --git a/ragna/source_storages/_lancedb.py b/ragna/source_storages/_lancedb.py index fc921f280..80cd4b12c 100644 --- a/ragna/source_storages/_lancedb.py +++ b/ragna/source_storages/_lancedb.py @@ -2,7 +2,7 @@ import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast import ragna from ragna.core import ( @@ -91,7 +91,7 @@ def _get_table( return self._db.open_table(corpus_name) def list_metadata( - self, corpus_name: Optional[str] = None + self, corpus_name: str | None = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: corpus_names = self.list_corpuses() if corpus_name is None else [corpus_name] @@ -237,7 +237,7 @@ def _translate_metadata_filter(self, metadata_filter: MetadataFilter) -> str: def retrieve( self, corpus_name: str, - metadata_filter: Optional[MetadataFilter], + metadata_filter: MetadataFilter | None, prompt: str, *, chunk_size: int = 500, diff --git a/ragna/source_storages/_qdrant.py b/ragna/source_storages/_qdrant.py index 00d0d78c3..424ef8415 100644 --- a/ragna/source_storages/_qdrant.py +++ b/ragna/source_storages/_qdrant.py @@ -4,7 +4,8 @@ import os import uuid from collections import defaultdict -from typing import TYPE_CHECKING, Any, AsyncIterator, Optional, cast +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Any, cast import ragna from ragna.core import ( @@ -144,7 +145,7 @@ async def _fetch_metadata(self, corpus_name: str) -> dict[str, Any]: } async def list_metadata( - self, corpus_name: Optional[str] = None + self, corpus_name: str | None = None ) -> dict[str, dict[str, tuple[str, list[Any]]]]: if corpus_name is None: corpus_names = await self.list_corpuses() @@ -258,7 +259,7 @@ def _translate_metadata_filter( async def retrieve( self, corpus_name: str, - metadata_filter: Optional[MetadataFilter], + metadata_filter: MetadataFilter | None, prompt: str, *, chunk_size: int = 500, diff --git a/ragna/source_storages/_vector_database.py b/ragna/source_storages/_vector_database.py index 96186aa34..cad36ec20 100644 --- a/ragna/source_storages/_vector_database.py +++ b/ragna/source_storages/_vector_database.py @@ -2,7 +2,8 @@ import hashlib import itertools from collections import deque -from typing import Deque, Iterable, Iterator, Optional, TypeVar, cast +from collections.abc import Iterable, Iterator +from typing import TypeVar, cast from ragna.core import PackageRequirement, Page, Requirement, Source, SourceStorage @@ -14,7 +15,7 @@ def _windowed_ragged( iterable: Iterable[T], *, n: int, step: int ) -> Iterator[tuple[T, ...]]: - window: Deque[T] = deque(maxlen=n) + window: deque[T] = deque(maxlen=n) i = n for _ in map(window.append, iterable): i -= 1 @@ -31,7 +32,7 @@ def _windowed_ragged( @dataclasses.dataclass class Chunk: text: str - page_numbers: Optional[list[int]] + page_numbers: list[int] | None num_tokens: int @@ -87,7 +88,7 @@ def _chunk_pages( num_tokens=len(tokens), ) - def _page_numbers_to_str(self, page_numbers: Optional[Iterable[int]]) -> str: + def _page_numbers_to_str(self, page_numbers: Iterable[int] | None) -> str: if not page_numbers: return "" From dd2c5b2237678eef476cef8235c4a2d71641245f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 10:53:07 +0200 Subject: [PATCH 10/13] add ASYNC rule --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9666966f0..5a1531bd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,6 +193,7 @@ select = [ "PTH", "D2", "UP", + "ASYNC", ] ignore = [ @@ -216,6 +217,8 @@ ignore = [ # The examples and tutorials need to have a good the narrative rather than follow our code-style rules "docs/examples/**/*.py" = ["E402", "F704", "I001", "D"] "docs/tutorials/**/*.py" = ["E402", "F704", "I001", "D"] +# blocking code in async tests is not an issue +"tests/**/*.py" = ["ASYNC101"] [tool.pytest.ini_options] minversion = "6.0" From a28b2c2e103b5f250f1b24626b87e8d9aef3e915 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 11:00:12 +0200 Subject: [PATCH 11/13] cleanup --- pixi.lock | 4 ++-- ragna/_cli/corpus.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pixi.lock b/pixi.lock index 8e3fa43ba..760b96347 100644 --- a/pixi.lock +++ b/pixi.lock @@ -11713,8 +11713,8 @@ packages: requires_python: '>=3.8' - pypi: ./ name: ragna - version: 0.4.0.dev38+gaf99c03.d20250620 - sha256: 26a22de2005bb9ece2d75fb16617c2c0abbe5787f5b438adc77fd6e6b202426c + version: 0.4.0.dev48+gdd2c5b2 + sha256: 4741f46a0227ab0caba7fad390ca56a01424724582c3141c8acc76bb142fa7cc requires_dist: - aiofiles - fastapi diff --git a/ragna/_cli/corpus.py b/ragna/_cli/corpus.py index 7ebee9ae2..b67a64745 100644 --- a/ragna/_cli/corpus.py +++ b/ragna/_cli/corpus.py @@ -69,7 +69,7 @@ def ingest( ] = False, ) -> None: try: - document_factory = config.document.from_path + document_factory = config.document.from_path # type: ignore[attr-defined] except AttributeError as exc: raise typer.BadParameter( f"{config.document.__name__} does not support creating documents from a" From 41c0de8f4d806d7781c8a3a32d0684447b3c4ade Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 11:12:10 +0200 Subject: [PATCH 12/13] cleanup docs --- docs/tutorials/gallery_rest_api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index e69adf885..d9e057c89 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -117,9 +117,11 @@ # - The field value is the binary content of the document. with open(document_path, "rb") as f: - client.put( - "/api/documents", - files=[("documents", (documents[0]["id"], f))], + print( + client.put( + "/api/documents", + files=[("documents", (documents[0]["id"], f))], + ) ) # %% From 75bcdd131503702492030eb391895b03b39ec67c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 20 Jun 2025 11:35:53 +0200 Subject: [PATCH 13/13] select correct docstring convention --- pyproject.toml | 3 +++ ragna/_utils.py | 2 -- ragna/core/_components.py | 11 ----------- ragna/core/_document.py | 6 ------ ragna/core/_rag.py | 5 ----- 5 files changed, 3 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a1531bd5..50e55d1f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,6 +181,9 @@ depends-on = [ target-version = "py310" [tool.ruff.lint] + +pydocstyle = { convention = "google" } + select = [ "E", "F", diff --git a/ragna/_utils.py b/ragna/_utils.py index 435e106d2..37c94a645 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -40,11 +40,9 @@ def local_root(path: str | Path | None = None) -> Path: `~/.cache/ragna`. Args: - ---- path: If passed, this is set as new local root directory. Returns: - ------- Ragna's local root directory. """ diff --git a/ragna/core/_components.py b/ragna/core/_components.py index b7a02ed25..6add32144 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -103,7 +103,6 @@ class Source(pydantic.BaseModel): """Data class for sources stored inside a source storage. Attributes - ---------- id: Unique ID of the source. document: Document this source belongs to. location: Location of the source inside the document. @@ -131,7 +130,6 @@ def store(self, corpus_name: str, documents: list[Document]) -> None: """Store content of documents. Args: - ---- corpus_name: Name of the corpus to store the documents in. documents: Documents to store. @@ -145,13 +143,11 @@ def retrieve( """Retrieve sources for a given prompt. Args: - ---- corpus_name: Name of the corpus to retrieve sources from. metadata_filter: Filter to select available sources. prompt: Prompt to retrieve sources for. Returns: - ------- Matching sources for the given prompt ordered by relevance. """ @@ -161,7 +157,6 @@ def list_corpuses(self) -> list[str]: """List available corpuses. Returns - ------- List of available corpuses. """ @@ -178,11 +173,9 @@ def list_metadata( """List available metadata for corpuses. Args: - ---- corpus_name: Only return metadata for this corpus. Returns: - ------- List of available metadata. """ @@ -198,7 +191,6 @@ class MessageRole(str, enum.Enum): """Message role Attributes - ---------- SYSTEM: The message was produced by the system. This includes the welcome message when [preparing a new chat][ragna.core.Chat.prepare] as well as error messages. @@ -216,7 +208,6 @@ class Message: """Data class for messages. Attributes - ---------- role: The message producer. sources: The sources used to produce the message. @@ -303,12 +294,10 @@ def answer(self, messages: list[Message]) -> Iterator[str]: """Answer a prompt given the chat history. Args: - ---- messages: List of messages in the chat history. The last item is the current user prompt and has the relevant sources attached to it. Returns: - ------- Answer. """ diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 0d51f7597..ada68b58d 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -47,7 +47,6 @@ def get_handler(name: str) -> DocumentHandler: """Get a document handler based on a suffix. Args: - ---- name: Name of the document. """ @@ -101,7 +100,6 @@ def from_path( """Create a [ragna.core.LocalDocument][] from a path. Args: - ---- path: Local path to the file. id: ID of the document. If omitted, one is generated. name: Name of the document. If omitted, defaults to the name of the `path`. @@ -110,7 +108,6 @@ def from_path( on the suffix of the `path`. Raises: - ------ RagnaException: If `metadata` is passed and contains a `"path"` key. """ @@ -159,7 +156,6 @@ class Page(BaseModel): """Dataclass for pages of a document Attributes - ---------- text: Text included in the page. number: Page number. @@ -183,11 +179,9 @@ def extract_pages(self, document: Document) -> Iterator[Page]: """Extract pages from a document. Args: - ---- document: Document to extract pages from. Returns: - ------- Extracted pages. """ diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 1c15c6b9b..809bbb508 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -146,7 +146,6 @@ def chat( """Create a new [ragna.core.Chat][]. Args: - ---- input: Subject of the chat. Available options: - `None`: Use the full corpus of documents specified by `corpus_name`. @@ -208,7 +207,6 @@ class Chat: ``` Args: - ---- rag: The RAG workflow this chat is associated with. input: Subject of the chat. Available options: @@ -259,7 +257,6 @@ async def prepare(self) -> Message: source storage. Afterwards prompts can be [`answer`][ragna.core.Chat.answer]ed. Returns - ------- Welcome message. """ @@ -283,11 +280,9 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: """Answer a prompt. Returns - ------- Answer. Raises - ------ ragna.core.RagnaException: If chat is not [`prepare`][ragna.core.Chat.prepare]d.