3232from bigframes .core .logging import log_adapter
3333from bigframes .ml import base as ml_base
3434from bigframes .ml import core as ml_core
35- from bigframes .operations import ai_ops , output_schemas
35+ from bigframes .operations import ai_ops , googlesql , output_schemas
3636
3737PROMPT_TYPE = Union [
3838 str ,
@@ -114,9 +114,6 @@ def generate(
114114 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
115115 """
116116
117- prompt_context , series_list = _separate_context_and_series (prompt )
118- assert len (series_list ) > 0
119-
120117 if output_schema is None :
121118 output_schema_str = None
122119 else :
@@ -126,17 +123,21 @@ def generate(
126123 # Validate user input
127124 output_schemas .parse_sql_fields (output_schema_str )
128125
129- operator = ai_ops .AIGenerate (
130- prompt_context = tuple (prompt_context ),
131- connection_id = connection_id ,
132- endpoint = endpoint ,
133- request_type = _upper_optional (request_type ),
134- model_params = json .dumps (model_params ) if model_params else None ,
135- output_schema = output_schema_str ,
126+ prompt_struct = _construct_prompt_struct (prompt )
127+
128+ op = googlesql .AIGenerateOp (output_schema = output_schema_str )
129+ return googlesql .apply_op (
130+ op ,
131+ args = (prompt_struct ,),
132+ kwargs = {
133+ "connection_id" : connection_id ,
134+ "endpoint" : endpoint ,
135+ "request_type" : _upper_optional (request_type ),
136+ "model_params" : json .dumps (model_params ) if model_params else None ,
137+ "output_schema" : output_schema_str ,
138+ },
136139 )
137140
138- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
139-
140141
141142@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
142143def generate_bool (
@@ -201,19 +202,19 @@ def generate_bool(
201202 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
202203 """
203204
204- prompt_context , series_list = _separate_context_and_series (prompt )
205- assert len (series_list ) > 0
206-
207- operator = ai_ops .AIGenerateBool (
208- prompt_context = tuple (prompt_context ),
209- connection_id = connection_id ,
210- endpoint = endpoint ,
211- request_type = _upper_optional (request_type ),
212- model_params = json .dumps (model_params ) if model_params else None ,
205+ prompt_struct = _construct_prompt_struct (prompt )
206+
207+ return googlesql .apply_op (
208+ googlesql .AI_GENERATE_BOOL ,
209+ args = (prompt_struct ,),
210+ kwargs = {
211+ "connection_id" : connection_id ,
212+ "endpoint" : endpoint ,
213+ "request_type" : _upper_optional (request_type ),
214+ "model_params" : json .dumps (model_params ) if model_params else None ,
215+ },
213216 )
214217
215- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
216-
217218
218219@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
219220def generate_int (
@@ -275,19 +276,19 @@ def generate_int(
275276 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
276277 """
277278
278- prompt_context , series_list = _separate_context_and_series (prompt )
279- assert len (series_list ) > 0
280-
281- operator = ai_ops .AIGenerateInt (
282- prompt_context = tuple (prompt_context ),
283- connection_id = connection_id ,
284- endpoint = endpoint ,
285- request_type = _upper_optional (request_type ),
286- model_params = json .dumps (model_params ) if model_params else None ,
279+ prompt_struct = _construct_prompt_struct (prompt )
280+
281+ return googlesql .apply_op (
282+ googlesql .AI_GENERATE_INT ,
283+ args = (prompt_struct ,),
284+ kwargs = {
285+ "connection_id" : connection_id ,
286+ "endpoint" : endpoint ,
287+ "request_type" : _upper_optional (request_type ),
288+ "model_params" : json .dumps (model_params ) if model_params else None ,
289+ },
287290 )
288291
289- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
290-
291292
292293@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
293294def generate_double (
@@ -349,19 +350,19 @@ def generate_double(
349350 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
350351 """
351352
352- prompt_context , series_list = _separate_context_and_series (prompt )
353- assert len (series_list ) > 0
354-
355- operator = ai_ops .AIGenerateDouble (
356- prompt_context = tuple (prompt_context ),
357- connection_id = connection_id ,
358- endpoint = endpoint ,
359- request_type = _upper_optional (request_type ),
360- model_params = json .dumps (model_params ) if model_params else None ,
353+ prompt_struct = _construct_prompt_struct (prompt )
354+
355+ return googlesql .apply_op (
356+ googlesql .AI_GENERATE_DOUBLE ,
357+ args = (prompt_struct ,),
358+ kwargs = {
359+ "connection_id" : connection_id ,
360+ "endpoint" : endpoint ,
361+ "request_type" : _upper_optional (request_type ),
362+ "model_params" : json .dumps (model_params ) if model_params else None ,
363+ },
361364 )
362365
363- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
364-
365366
366367@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
367368def generate_embedding (
@@ -751,24 +752,19 @@ def embed(
751752 * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
752753 """
753754
754- operator = ai_ops .AIEmbed (
755- endpoint = endpoint ,
756- model = model ,
757- task_type = _upper_optional (task_type ),
758- title = title ,
759- model_params = json .dumps (model_params ) if model_params else None ,
760- connection_id = connection_id ,
755+ return googlesql .apply_op (
756+ googlesql .AI_EMBED ,
757+ args = (content ,),
758+ kwargs = {
759+ "endpoint" : endpoint ,
760+ "model" : model ,
761+ "task_type" : _upper_optional (task_type ),
762+ "title" : title ,
763+ "model_params" : json .dumps (model_params ) if model_params else None ,
764+ "connection_id" : connection_id ,
765+ },
761766 )
762767
763- if isinstance (content , str ):
764- return series .Series ([content ])._apply_unary_op (operator )
765- elif isinstance (content , pd .Series ):
766- return series .Series (content )._apply_unary_op (operator )
767- elif isinstance (content , series .Series ):
768- return content ._apply_unary_op (operator )
769- else :
770- raise ValueError (f"Unsupported 'content' parameter type: { type (content )} " )
771-
772768
773769@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
774770def if_ (
@@ -824,19 +820,19 @@ def if_(
824820 bigframes.series.Series: A new series of bools.
825821 """
826822
827- prompt_context , series_list = _separate_context_and_series (prompt )
828- assert len (series_list ) > 0
829-
830- operator = ai_ops .AIIf (
831- prompt_context = tuple (prompt_context ),
832- connection_id = connection_id ,
833- endpoint = endpoint ,
834- optimization_mode = _upper_optional (optimization_mode ),
835- max_error_ratio = max_error_ratio ,
823+ prompt_struct = _construct_prompt_struct (prompt )
824+
825+ return googlesql .apply_op (
826+ googlesql .AI_IF ,
827+ args = (prompt_struct ,),
828+ kwargs = {
829+ "connection_id" : connection_id ,
830+ "endpoint" : endpoint ,
831+ "optimization_mode" : _upper_optional (optimization_mode ),
832+ "max_error_ratio" : max_error_ratio ,
833+ },
836834 )
837835
838- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
839-
840836
841837@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
842838def classify (
@@ -901,30 +897,30 @@ def classify(
901897 bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified).
902898 """
903899
904- prompt_context , series_list = _separate_context_and_series (input )
905- assert len (series_list ) > 0
906-
907900 if examples is not None :
908- example_tuples : Any = tuple (
909- (ex [0 ], tuple (ex [1 ]) if isinstance (ex [1 ], (list , tuple )) else ex [1 ])
901+ formatted_examples = [
902+ {
903+ "input" : ex [0 ],
904+ "output" : list (ex [1 ]) if isinstance (ex [1 ], (list , tuple )) else ex [1 ],
905+ }
910906 for ex in examples
911- )
907+ ]
912908 else :
913- example_tuples = None
914-
915- operator = ai_ops .AIClassify (
916- prompt_context = tuple (prompt_context ),
917- categories = tuple (categories ),
918- examples = example_tuples ,
919- connection_id = connection_id ,
920- endpoint = endpoint ,
921- output_mode = output_mode ,
922- optimization_mode = _upper_optional (optimization_mode ),
923- max_error_ratio = max_error_ratio ,
909+ formatted_examples = None
910+
911+ return googlesql .apply_op (
912+ googlesql .AI_CLASSIFY ,
913+ args = (input , categories ),
914+ kwargs = {
915+ "examples" : formatted_examples ,
916+ "connection_id" : connection_id ,
917+ "endpoint" : endpoint ,
918+ "output_mode" : output_mode ,
919+ "optimization_mode" : _upper_optional (optimization_mode ),
920+ "max_error_ratio" : max_error_ratio ,
921+ },
924922 )
925923
926- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
927-
928924
929925@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
930926def score (
@@ -970,18 +966,18 @@ def score(
970966 bigframes.series.Series: A new series of double (float) values.
971967 """
972968
973- prompt_context , series_list = _separate_context_and_series (prompt )
974- assert len (series_list ) > 0
969+ prompt_struct = _construct_prompt_struct (prompt )
975970
976- operator = ai_ops .AIScore (
977- prompt_context = tuple (prompt_context ),
978- connection_id = connection_id ,
979- endpoint = endpoint ,
980- max_error_ratio = max_error_ratio ,
971+ return googlesql .apply_op (
972+ googlesql .AI_SCORE ,
973+ args = (prompt_struct ,),
974+ kwargs = {
975+ "connection_id" : connection_id ,
976+ "endpoint" : endpoint ,
977+ "max_error_ratio" : max_error_ratio ,
978+ },
981979 )
982980
983- return series_list [0 ]._apply_nary_op (operator , series_list [1 :])
984-
985981
986982@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
987983def similarity (
@@ -1026,36 +1022,18 @@ def similarity(
10261022 bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.
10271023 """
10281024
1029- operator = ai_ops .AISimilarity (
1030- endpoint = endpoint ,
1031- model = model ,
1032- model_params = json .dumps (model_params ) if model_params else None ,
1033- connection_id = connection_id ,
1025+ return googlesql .apply_op (
1026+ googlesql .AI_SIMILARITY ,
1027+ kwargs = {
1028+ "content1" : content1 ,
1029+ "content2" : content2 ,
1030+ "endpoint" : endpoint ,
1031+ "model" : model ,
1032+ "model_params" : json .dumps (model_params ) if model_params else None ,
1033+ "connection_id" : connection_id ,
1034+ },
10341035 )
10351036
1036- # Find a unifying session for the subsequent operations.
1037- bf_session = None
1038- if isinstance (content1 , series .Series ):
1039- bf_session = content1 ._session
1040- elif isinstance (content2 , series .Series ):
1041- bf_session = content2 ._session
1042-
1043- if isinstance (content1 , str ) and isinstance (content2 , str ):
1044- content1 = series .Series ([content1 ], session = bf_session )
1045- return content1 ._apply_binary_op (content2 , operator )
1046- elif isinstance (content1 , str ):
1047- # content2 must be a series
1048- content2 = convert .to_bf_series (
1049- content2 , default_index = None , session = bf_session
1050- )
1051- return content2 ._apply_binary_op (content1 , operator )
1052- else :
1053- # content1 must be a series.
1054- content1 = convert .to_bf_series (
1055- content1 , default_index = None , session = bf_session
1056- )
1057- return content1 ._apply_binary_op (content2 , operator )
1058-
10591037
10601038@log_adapter .method_logger (custom_base_name = "bigquery_ai" )
10611039def forecast (
@@ -1246,3 +1224,24 @@ def _upper_optional(value: str | None) -> str | None:
12461224 if value is None :
12471225 return None
12481226 return value .upper ()
1227+
1228+
def _construct_prompt_struct(prompt: PROMPT_TYPE) -> series.Series:
    """Pack a prompt into a single struct-valued operand.

    The prompt is split by ``_separate_context_and_series`` into a context
    sequence and a list of series. Each ``None`` entry in the context acts as
    a placeholder and is filled, in order, with the next series from that
    list; every other entry is passed through unchanged. The resulting
    elements are combined with ``StructOp`` into one struct whose fields are
    named ``_field_1`` ... ``_field_N`` in prompt order.

    Args:
        prompt: The prompt to convert, in any form accepted by
            ``_separate_context_and_series``.

    Returns:
        The value produced by ``googlesql.apply_op`` for the ``StructOp``
        (annotated as a series).

    Raises:
        ValueError: If the prompt contains no series.
    """
    # Kept as a function-scope import, as in the original code
    # (presumably to avoid an import cycle -- confirm before hoisting).
    import bigframes.operations.generic_ops as generic_ops

    prompt_context, series_list = _separate_context_and_series(prompt)
    # Explicit check instead of ``assert`` so it still runs under ``-O``.
    if not series_list:
        raise ValueError("The prompt must contain at least one series.")

    prompt_elements = []
    series_idx = 0
    for elem in prompt_context:
        if elem is None:
            # ``None`` marks the slot of the next series, in order.
            prompt_elements.append(series_list[series_idx])
            series_idx += 1
        else:
            prompt_elements.append(elem)

    # Field names are 1-based: _field_1, _field_2, ...
    struct_names = tuple(
        f"_field_{i}" for i in range(1, len(prompt_elements) + 1)
    )
    return googlesql.apply_op(
        generic_ops.StructOp(column_names=struct_names),
        # Tuple, for consistency with the other ``apply_op`` call sites.
        args=tuple(prompt_elements),
    )
0 commit comments