Skip to content

Commit c1f523d

Browse files
committed
- Added translate API
- bug fix during profile list
1 parent 121cfe6 commit c1f523d

File tree

5 files changed

+93
-29
lines changed

5 files changed

+93
-29
lines changed

src/select_ai/action.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ class Action(StrEnum):
2121
SHOWPROMPT = "showprompt"
2222
FEEDBACK = "feedback"
2323
SUMMARIZE = "summarize"
24+
TRANSLATE = "translate"

src/select_ai/async_profile.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import json
99
from contextlib import asynccontextmanager
10-
from dataclasses import replace as dataclass_replace
1110
from typing import (
1211
AsyncGenerator,
1312
List,
@@ -32,7 +31,6 @@
3231
from select_ai.db import async_cursor, async_get_connection
3332
from select_ai.errors import (
3433
ProfileAttributesEmptyError,
35-
ProfileExistsError,
3634
ProfileNotFoundError,
3735
)
3836
from select_ai.feedback import (
@@ -73,14 +71,14 @@ async def _init_profile(self):
7371
if self.profile_name:
7472
profile_exists = False
7573
try:
76-
saved_attributes = await self._get_attributes(
77-
profile_name=self.profile_name,
78-
raise_on_empty=True,
79-
)
8074
saved_description = await self._get_profile_description(
8175
profile_name=self.profile_name
8276
)
8377
profile_exists = True
78+
saved_attributes = await self._get_attributes(
79+
profile_name=self.profile_name,
80+
raise_on_empty=True,
81+
)
8482
self._raise_error_if_profile_exists()
8583
except ProfileAttributesEmptyError:
8684
if self.raise_error_on_empty_attributes:
@@ -624,6 +622,34 @@ async def run_pipeline(
624622
responses.append(result.error)
625623
return responses
626624

625+
async def translate(
626+
self, text: str, source_language: str, target_language: str
627+
) -> Union[str, None]:
628+
"""
629+
Translate a text using a source language and a target language
630+
631+
:param str text: Text to translate
632+
:param str source_language: Source language
633+
:param str target_language: Target language
634+
:return: str
635+
"""
636+
parameters = {
637+
"profile_name": self.profile_name,
638+
"text": text,
639+
"source_language": source_language,
640+
"target_language": target_language,
641+
}
642+
async with async_cursor() as cr:
643+
data = await cr.callfunc(
644+
"DBMS_CLOUD_AI.TRANSLATE",
645+
oracledb.DB_TYPE_CLOB,
646+
keyword_parameters=parameters,
647+
)
648+
if data is not None:
649+
result = await data.read()
650+
return result
651+
return None
652+
627653

628654
class AsyncSession:
629655
"""AsyncSession lets you persist request parameters across DBMS_CLOUD_AI

src/select_ai/profile.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import json
99
from contextlib import contextmanager
10-
from dataclasses import replace as dataclass_replace
1110
from typing import Generator, Mapping, Optional, Tuple, Union
1211

1312
import oracledb
@@ -25,7 +24,6 @@
2524
from select_ai.db import cursor
2625
from select_ai.errors import (
2726
ProfileAttributesEmptyError,
28-
ProfileExistsError,
2927
ProfileNotFoundError,
3028
)
3129
from select_ai.feedback import FeedbackOperation, FeedbackType
@@ -59,14 +57,14 @@ def _init_profile(self) -> None:
5957
if self.profile_name:
6058
profile_exists = False
6159
try:
62-
saved_attributes = self._get_attributes(
63-
profile_name=self.profile_name,
64-
raise_on_empty=True,
65-
)
6660
saved_description = self._get_profile_description(
6761
profile_name=self.profile_name
6862
)
6963
profile_exists = True
64+
saved_attributes = self._get_attributes(
65+
profile_name=self.profile_name,
66+
raise_on_empty=True,
67+
)
7068
self._raise_error_if_profile_exists()
7169
except ProfileAttributesEmptyError:
7270
if self.raise_error_on_empty_attributes:
@@ -551,6 +549,34 @@ def generate_synthetic_data(
551549
keyword_parameters=keyword_parameters,
552550
)
553551

552+
def translate(
553+
self, text: str, source_language: str, target_language: str
554+
) -> Union[str, None]:
555+
"""
556+
Translate a text using a source language and a target language
557+
558+
:param str text: Text to translate
559+
:param str source_language: Source language
560+
:param str target_language: Target language
561+
:return: str
562+
"""
563+
parameters = {
564+
"profile_name": self.profile_name,
565+
"text": text,
566+
"source_language": source_language,
567+
"target_language": target_language,
568+
}
569+
with cursor() as cr:
570+
data = cr.callfunc(
571+
"DBMS_CLOUD_AI.TRANSLATE",
572+
oracledb.DB_TYPE_CLOB,
573+
keyword_parameters=parameters,
574+
)
575+
if data is not None:
576+
result = data.read()
577+
return result
578+
return None
579+
554580

555581
class Session:
556582
"""Session lets you persist request parameters across DBMS_CLOUD_AI

tests/profiles/test_1200_profile.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,9 @@ def python_gen_ai_neg_feedback(cursor, python_gen_ai_profile):
7979
],
8080
)
8181
cursor.execute(
82-
f"""
83-
BEGIN
84-
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
85-
END;
86-
"""
82+
f"""BEGIN
83+
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
84+
END;"""
8785
)
8886
prompt = "Total points of each gymnasts"
8987
action = select_ai.Action.SHOWSQL
@@ -113,10 +111,9 @@ def python_gen_ai_pos_feedback(cursor, python_gen_ai_profile):
113111
["prompt", "action", "sql_text"],
114112
)
115113
cursor.execute(
116-
f"""
117-
BEGIN
118-
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
119-
END;
114+
f"""BEGIN
115+
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
116+
END;
120117
"""
121118
)
122119
prompt = "Lists the name of all people"
@@ -349,3 +346,11 @@ def test_1217(cursor, python_gen_ai_profile, python_gen_ai_pos_feedback):
349346
assert (
350347
feedback_attributes["sql_text"] == python_gen_ai_pos_feedback.sql_text
351348
)
349+
350+
351+
def test_1218(python_gen_ai_profile):
352+
"""Test translate"""
353+
response = python_gen_ai_profile.translate(
354+
text="Thank you", source_language="en", target_language="de"
355+
)
356+
assert response == "Danke"

tests/profiles/test_1300_profile_async.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ async def python_gen_ai_neg_feedback(async_cursor, python_gen_ai_profile):
8383
],
8484
)
8585
await async_cursor.execute(
86-
f"""
87-
BEGIN
88-
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
89-
END;
86+
f"""BEGIN
87+
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
88+
END;
9089
"""
9190
)
9291
prompt = "Total points of each gymnasts"
@@ -117,10 +116,9 @@ async def python_gen_ai_pos_feedback(async_cursor, python_gen_ai_profile):
117116
["prompt", "action", "sql_text"],
118117
)
119118
await async_cursor.execute(
120-
f"""
121-
BEGIN
122-
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
123-
END;
119+
f"""BEGIN
120+
dbms_cloud_ai.set_profile('{python_gen_ai_profile.profile_name}');
121+
END;
124122
"""
125123
)
126124
prompt = "Lists the name of all people"
@@ -370,3 +368,11 @@ async def test_1317(
370368
assert (
371369
feedback_attributes["sql_text"] == python_gen_ai_pos_feedback.sql_text
372370
)
371+
372+
373+
async def test_1318(python_gen_ai_profile):
374+
"""Test translate"""
375+
response = await python_gen_ai_profile.translate(
376+
text="Thank you", source_language="en", target_language="de"
377+
)
378+
assert response == "Danke"

0 commit comments

Comments
 (0)