Skip to content

Commit 59abbd5

Browse files
refactor(bigframes): Move AI ops to googlesql op framework
1 parent f93911c commit 59abbd5

45 files changed

Lines changed: 851 additions & 909 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 128 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from bigframes.core.logging import log_adapter
3333
from bigframes.ml import base as ml_base
3434
from bigframes.ml import core as ml_core
35-
from bigframes.operations import ai_ops, output_schemas
35+
from bigframes.operations import ai_ops, googlesql, output_schemas
3636

3737
PROMPT_TYPE = Union[
3838
str,
@@ -114,9 +114,6 @@ def generate(
114114
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
115115
"""
116116

117-
prompt_context, series_list = _separate_context_and_series(prompt)
118-
assert len(series_list) > 0
119-
120117
if output_schema is None:
121118
output_schema_str = None
122119
else:
@@ -126,17 +123,21 @@ def generate(
126123
# Validate user input
127124
output_schemas.parse_sql_fields(output_schema_str)
128125

129-
operator = ai_ops.AIGenerate(
130-
prompt_context=tuple(prompt_context),
131-
connection_id=connection_id,
132-
endpoint=endpoint,
133-
request_type=_upper_optional(request_type),
134-
model_params=json.dumps(model_params) if model_params else None,
135-
output_schema=output_schema_str,
126+
prompt_struct = _construct_prompt_struct(prompt)
127+
128+
op = googlesql.AIGenerateOp(output_schema=output_schema_str)
129+
return googlesql.apply_op(
130+
op,
131+
args=(prompt_struct,),
132+
kwargs={
133+
"connection_id": connection_id,
134+
"endpoint": endpoint,
135+
"request_type": _upper_optional(request_type),
136+
"model_params": json.dumps(model_params) if model_params else None,
137+
"output_schema": output_schema_str,
138+
},
136139
)
137140

138-
return series_list[0]._apply_nary_op(operator, series_list[1:])
139-
140141

141142
@log_adapter.method_logger(custom_base_name="bigquery_ai")
142143
def generate_bool(
@@ -201,19 +202,19 @@ def generate_bool(
201202
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
202203
"""
203204

204-
prompt_context, series_list = _separate_context_and_series(prompt)
205-
assert len(series_list) > 0
206-
207-
operator = ai_ops.AIGenerateBool(
208-
prompt_context=tuple(prompt_context),
209-
connection_id=connection_id,
210-
endpoint=endpoint,
211-
request_type=_upper_optional(request_type),
212-
model_params=json.dumps(model_params) if model_params else None,
205+
prompt_struct = _construct_prompt_struct(prompt)
206+
207+
return googlesql.apply_op(
208+
googlesql.AI_GENERATE_BOOL,
209+
args=(prompt_struct,),
210+
kwargs={
211+
"connection_id": connection_id,
212+
"endpoint": endpoint,
213+
"request_type": _upper_optional(request_type),
214+
"model_params": json.dumps(model_params) if model_params else None,
215+
},
213216
)
214217

215-
return series_list[0]._apply_nary_op(operator, series_list[1:])
216-
217218

218219
@log_adapter.method_logger(custom_base_name="bigquery_ai")
219220
def generate_int(
@@ -275,19 +276,19 @@ def generate_int(
275276
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
276277
"""
277278

278-
prompt_context, series_list = _separate_context_and_series(prompt)
279-
assert len(series_list) > 0
280-
281-
operator = ai_ops.AIGenerateInt(
282-
prompt_context=tuple(prompt_context),
283-
connection_id=connection_id,
284-
endpoint=endpoint,
285-
request_type=_upper_optional(request_type),
286-
model_params=json.dumps(model_params) if model_params else None,
279+
prompt_struct = _construct_prompt_struct(prompt)
280+
281+
return googlesql.apply_op(
282+
googlesql.AI_GENERATE_INT,
283+
args=(prompt_struct,),
284+
kwargs={
285+
"connection_id": connection_id,
286+
"endpoint": endpoint,
287+
"request_type": _upper_optional(request_type),
288+
"model_params": json.dumps(model_params) if model_params else None,
289+
},
287290
)
288291

289-
return series_list[0]._apply_nary_op(operator, series_list[1:])
290-
291292

292293
@log_adapter.method_logger(custom_base_name="bigquery_ai")
293294
def generate_double(
@@ -349,19 +350,19 @@ def generate_double(
349350
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
350351
"""
351352

352-
prompt_context, series_list = _separate_context_and_series(prompt)
353-
assert len(series_list) > 0
354-
355-
operator = ai_ops.AIGenerateDouble(
356-
prompt_context=tuple(prompt_context),
357-
connection_id=connection_id,
358-
endpoint=endpoint,
359-
request_type=_upper_optional(request_type),
360-
model_params=json.dumps(model_params) if model_params else None,
353+
prompt_struct = _construct_prompt_struct(prompt)
354+
355+
return googlesql.apply_op(
356+
googlesql.AI_GENERATE_DOUBLE,
357+
args=(prompt_struct,),
358+
kwargs={
359+
"connection_id": connection_id,
360+
"endpoint": endpoint,
361+
"request_type": _upper_optional(request_type),
362+
"model_params": json.dumps(model_params) if model_params else None,
363+
},
361364
)
362365

363-
return series_list[0]._apply_nary_op(operator, series_list[1:])
364-
365366

366367
@log_adapter.method_logger(custom_base_name="bigquery_ai")
367368
def generate_embedding(
@@ -751,24 +752,19 @@ def embed(
751752
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
752753
"""
753754

754-
operator = ai_ops.AIEmbed(
755-
endpoint=endpoint,
756-
model=model,
757-
task_type=_upper_optional(task_type),
758-
title=title,
759-
model_params=json.dumps(model_params) if model_params else None,
760-
connection_id=connection_id,
755+
return googlesql.apply_op(
756+
googlesql.AI_EMBED,
757+
args=(content,),
758+
kwargs={
759+
"endpoint": endpoint,
760+
"model": model,
761+
"task_type": _upper_optional(task_type),
762+
"title": title,
763+
"model_params": json.dumps(model_params) if model_params else None,
764+
"connection_id": connection_id,
765+
},
761766
)
762767

763-
if isinstance(content, str):
764-
return series.Series([content])._apply_unary_op(operator)
765-
elif isinstance(content, pd.Series):
766-
return series.Series(content)._apply_unary_op(operator)
767-
elif isinstance(content, series.Series):
768-
return content._apply_unary_op(operator)
769-
else:
770-
raise ValueError(f"Unsupported 'content' parameter type: {type(content)}")
771-
772768

773769
@log_adapter.method_logger(custom_base_name="bigquery_ai")
774770
def if_(
@@ -824,19 +820,19 @@ def if_(
824820
bigframes.series.Series: A new series of bools.
825821
"""
826822

827-
prompt_context, series_list = _separate_context_and_series(prompt)
828-
assert len(series_list) > 0
829-
830-
operator = ai_ops.AIIf(
831-
prompt_context=tuple(prompt_context),
832-
connection_id=connection_id,
833-
endpoint=endpoint,
834-
optimization_mode=_upper_optional(optimization_mode),
835-
max_error_ratio=max_error_ratio,
823+
prompt_struct = _construct_prompt_struct(prompt)
824+
825+
return googlesql.apply_op(
826+
googlesql.AI_IF,
827+
args=(prompt_struct,),
828+
kwargs={
829+
"connection_id": connection_id,
830+
"endpoint": endpoint,
831+
"optimization_mode": _upper_optional(optimization_mode),
832+
"max_error_ratio": max_error_ratio,
833+
},
836834
)
837835

838-
return series_list[0]._apply_nary_op(operator, series_list[1:])
839-
840836

841837
@log_adapter.method_logger(custom_base_name="bigquery_ai")
842838
def classify(
@@ -901,30 +897,30 @@ def classify(
901897
bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified).
902898
"""
903899

904-
prompt_context, series_list = _separate_context_and_series(input)
905-
assert len(series_list) > 0
906-
907900
if examples is not None:
908-
example_tuples: Any = tuple(
909-
(ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1])
901+
formatted_examples = [
902+
{
903+
"input": ex[0],
904+
"output": list(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1],
905+
}
910906
for ex in examples
911-
)
907+
]
912908
else:
913-
example_tuples = None
914-
915-
operator = ai_ops.AIClassify(
916-
prompt_context=tuple(prompt_context),
917-
categories=tuple(categories),
918-
examples=example_tuples,
919-
connection_id=connection_id,
920-
endpoint=endpoint,
921-
output_mode=output_mode,
922-
optimization_mode=_upper_optional(optimization_mode),
923-
max_error_ratio=max_error_ratio,
909+
formatted_examples = None
910+
911+
return googlesql.apply_op(
912+
googlesql.AI_CLASSIFY,
913+
args=(input, categories),
914+
kwargs={
915+
"examples": formatted_examples,
916+
"connection_id": connection_id,
917+
"endpoint": endpoint,
918+
"output_mode": output_mode,
919+
"optimization_mode": _upper_optional(optimization_mode),
920+
"max_error_ratio": max_error_ratio,
921+
},
924922
)
925923

926-
return series_list[0]._apply_nary_op(operator, series_list[1:])
927-
928924

929925
@log_adapter.method_logger(custom_base_name="bigquery_ai")
930926
def score(
@@ -970,18 +966,18 @@ def score(
970966
bigframes.series.Series: A new series of double (float) values.
971967
"""
972968

973-
prompt_context, series_list = _separate_context_and_series(prompt)
974-
assert len(series_list) > 0
969+
prompt_struct = _construct_prompt_struct(prompt)
975970

976-
operator = ai_ops.AIScore(
977-
prompt_context=tuple(prompt_context),
978-
connection_id=connection_id,
979-
endpoint=endpoint,
980-
max_error_ratio=max_error_ratio,
971+
return googlesql.apply_op(
972+
googlesql.AI_SCORE,
973+
args=(prompt_struct,),
974+
kwargs={
975+
"connection_id": connection_id,
976+
"endpoint": endpoint,
977+
"max_error_ratio": max_error_ratio,
978+
},
981979
)
982980

983-
return series_list[0]._apply_nary_op(operator, series_list[1:])
984-
985981

986982
@log_adapter.method_logger(custom_base_name="bigquery_ai")
987983
def similarity(
@@ -1026,36 +1022,18 @@ def similarity(
10261022
bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.
10271023
"""
10281024

1029-
operator = ai_ops.AISimilarity(
1030-
endpoint=endpoint,
1031-
model=model,
1032-
model_params=json.dumps(model_params) if model_params else None,
1033-
connection_id=connection_id,
1025+
return googlesql.apply_op(
1026+
googlesql.AI_SIMILARITY,
1027+
kwargs={
1028+
"content1": content1,
1029+
"content2": content2,
1030+
"endpoint": endpoint,
1031+
"model": model,
1032+
"model_params": json.dumps(model_params) if model_params else None,
1033+
"connection_id": connection_id,
1034+
},
10341035
)
10351036

1036-
# Find a unifying session for the subsequent operations.
1037-
bf_session = None
1038-
if isinstance(content1, series.Series):
1039-
bf_session = content1._session
1040-
elif isinstance(content2, series.Series):
1041-
bf_session = content2._session
1042-
1043-
if isinstance(content1, str) and isinstance(content2, str):
1044-
content1 = series.Series([content1], session=bf_session)
1045-
return content1._apply_binary_op(content2, operator)
1046-
elif isinstance(content1, str):
1047-
# content2 must be a series
1048-
content2 = convert.to_bf_series(
1049-
content2, default_index=None, session=bf_session
1050-
)
1051-
return content2._apply_binary_op(content1, operator)
1052-
else:
1053-
# content1 must be a series.
1054-
content1 = convert.to_bf_series(
1055-
content1, default_index=None, session=bf_session
1056-
)
1057-
return content1._apply_binary_op(content2, operator)
1058-
10591037

10601038
@log_adapter.method_logger(custom_base_name="bigquery_ai")
10611039
def forecast(
@@ -1246,3 +1224,24 @@ def _upper_optional(value: str | None) -> str | None:
12461224
if value is None:
12471225
return None
12481226
return value.upper()
1227+
1228+
1229+
def _construct_prompt_struct(prompt: PROMPT_TYPE) -> series.Series:
    """Assemble a mixed prompt into a single STRUCT-valued series.

    Splits *prompt* into literal context strings and Series placeholders via
    ``_separate_context_and_series``, re-interleaves them in their original
    order, and packs the result into one STRUCT column whose fields are named
    ``_field_1``, ``_field_2``, ....

    Args:
        prompt: A string, Series, or sequence mixing both (``PROMPT_TYPE``).

    Returns:
        bigframes.series.Series: A series of STRUCT values with one field per
        prompt element.
    """
    # Function-scope import — NOTE(review): presumably kept local to avoid a
    # circular import at module load time; confirm before hoisting to the top
    # of the file.
    import bigframes.operations.generic_ops as generic_ops

    prompt_context, series_list = _separate_context_and_series(prompt)
    # Internal invariant: a prompt with no Series component has no row axis
    # to operate over, so the separator must yield at least one Series.
    assert len(series_list) > 0

    # In prompt_context, None marks the position of each Series; consume
    # series_list in order to rebuild the full, ordered element list.
    series_iter = iter(series_list)
    prompt_elements = [
        next(series_iter) if elem is None else elem for elem in prompt_context
    ]

    struct_names = tuple(f"_field_{i + 1}" for i in range(len(prompt_elements)))
    return googlesql.apply_op(
        generic_ops.StructOp(column_names=struct_names),
        args=prompt_elements,
    )

0 commit comments

Comments
 (0)