2121from bigframes .session import executor , semi_executor
2222import bigframes .core .rewrite .slices as slices_rewrite
2323from bigframes .core import nodes
24+ import asyncio
2425
2526if TYPE_CHECKING :
2627 import pyarrow as pa
@@ -62,25 +63,14 @@ def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table:
6263 "Install it with `pip install datafusion`."
6364 )
6465
65- # Create a DataFusion context
6666 ctx = datafusion .SessionContext ()
6767
6868 for name , table in tables .items ():
6969 df = ctx .from_arrow_table (table )
7070 ctx .register_table (name , df )
7171
72- # NOTE: The actual API for running Substrait in DataFusion python bindings may vary.
73- # Assuming something like ctx.from_substrait(plan) or ctx.execute_substrait(plan).
74- # We will need to verify this with the actual datafusion python package if available.
75- # For now, we raise NotImplementedError if we cannot find the method, or try a likely one.
76-
7772 import datafusion .substrait
7873
79- import substrait .plan_pb2 as plan_pb2
80- from google .protobuf import json_format
81- plan_obj = plan_pb2 .Plan .FromString (plan_proto )
82- print ("DEBUG PLAN JSON:" )
83- print (json_format .MessageToJson (plan_obj ))
8474 datafusion_substrait_plan = datafusion .substrait .Serde .deserialize_bytes (plan_proto )
8575 logical_plan = datafusion .substrait .Consumer .from_substrait_plan (ctx , datafusion_substrait_plan )
8676 df = ctx .create_dataframe_from_logical_plan (logical_plan )
@@ -98,62 +88,14 @@ def __init__(self, consumer: SubstraitConsumer):
9888 from bigframes .core .compile .substrait .compiler import SubstraitCompiler
9989 self ._compiler = SubstraitCompiler ()
10090
101- def execute (
91+ async def execute (
10292 self ,
10393 plan : bigframe_node .BigFrameNode ,
10494 ordered : bool ,
10595 peek : Optional [int ] = None ,
10696 ) -> Optional [executor .ExecuteResult ]:
107- def resolve_promote_offsets (node : bigframe_node .BigFrameNode ) -> bigframe_node .BigFrameNode :
108- if isinstance (node , nodes .PromoteOffsetsNode ):
109- res = self .execute (node .child , ordered = ordered )
110- if res is None :
111- return node
112- table = res .batches ().to_arrow_table ()
113- import pyarrow as pa
114- table = table .append_column (node .col_id .name , pa .array (range (len (table )), type = pa .int64 ()))
115-
116- from bigframes .core import local_data , identifiers
117- from bigframes .core .schema import ArraySchema , SchemaItem
118- import bigframes .dtypes
119-
120- schema_items = []
121- for col_name in table .column_names :
122- if col_name == node .col_id .name :
123- schema_items .append (SchemaItem (col_name , bigframes .dtypes .INT_DTYPE ))
124- else :
125- schema_items .append (SchemaItem (col_name , node .child .schema .get_type (col_name )))
126- new_schema = ArraySchema (tuple (schema_items ))
127-
128- scan_items = []
129- for col_name in table .column_names :
130- col_id = identifiers .ColumnId (col_name )
131- scan_items .append (nodes .ScanItem (col_id , col_name ))
132- scan_list = nodes .ScanList (tuple (scan_items ))
133-
134- session = None
135- for child_node in node .child .unique_nodes ():
136- if isinstance (child_node , nodes .ReadLocalNode ):
137- session = child_node .session
138- break
139-
140- managed_table = local_data .ManagedArrowTable .from_pyarrow (table , schema = new_schema )
141- new_node = nodes .ReadLocalNode (
142- local_data_source = managed_table ,
143- scan_list = scan_list ,
144- session = session ,
145- offsets_col = None ,
146- )
147- return new_node
148- return node
149-
150- # 1. Rewrite all SliceNodes to standard Selection/Filter/Projection/PromoteOffsetsNodes
15197 plan = plan .bottom_up (slices_rewrite .rewrite_slice )
15298
153- # 2. Resolve all PromoteOffsetsNodes to concrete local tables
154- plan = plan .bottom_up (resolve_promote_offsets )
155-
156- # 3. Wrap plan in a ResultNode to apply defer_order
15799 from bigframes .core import expression , rewrite
158100 output_cols = tuple ((expression .DerefOp (id ), id .name ) for id in plan .ids )
159101 result_node = nodes .ResultNode (
@@ -166,14 +108,12 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B
166108
167109 rewritten_plan = result_node .child
168110
169- # 4. Apply outermost sorting if ordered
170111 if ordered and result_node .order_by and result_node .order_by .all_ordering_columns :
171112 rewritten_plan = nodes .OrderByNode (
172113 rewritten_plan ,
173114 by = tuple (result_node .order_by .all_ordering_columns ),
174115 )
175116
176- # 5. Project only the original output columns to preserve correct result schema
177117 original_ids = tuple (id for id in plan .ids )
178118 if rewritten_plan .ids != original_ids :
179119 rewritten_plan = nodes .SelectionNode (
@@ -188,17 +128,6 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B
188128 if substrait_plan_proto is None :
189129 return None
190130
191- import google .protobuf .json_format as json_format
192- from substrait .plan_pb2 import Plan
193- plan_proto = Plan ()
194- plan_proto .ParseFromString (substrait_plan_proto )
195- import os
196- import uuid
197- os .makedirs ("/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch" , exist_ok = True )
198- filename = f"/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch/plan_{ rewritten_plan .__class__ .__name__ } _{ uuid .uuid4 ().hex [:8 ]} .json"
199- with open (filename , "w" ) as f :
200- f .write (json_format .MessageToJson (plan_proto ))
201-
202131 tables = {}
203132 for node in rewritten_plan .unique_nodes ():
204133 if isinstance (node , nodes .ReadLocalNode ):
@@ -211,52 +140,9 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B
211140 table = pyarrow_utils .append_offsets (table , node .offsets_col .sql )
212141 tables [table_name ] = table
213142
214- pa_table = self ._consumer .consume (substrait_plan_proto , tables )
215-
216- # Sanitize pa_table: replace inf/nan/is_inf with null for INT_DTYPE columns
217- import pyarrow .compute as pc
218- import bigframes .dtypes as dtypes
219- import pyarrow as pa
220- sanitized_columns = []
221- for col_name in pa_table .column_names :
222- col_data = pa_table .column (col_name )
223- try :
224- expected_dtype = rewritten_plan .schema .get_type (col_name )
225- except ValueError :
226- expected_dtype = None
227-
228- if expected_dtype == dtypes .INT_DTYPE and pa .types .is_floating (col_data .type ):
229- is_nan = pc .is_nan (col_data )
230- is_inf = pc .is_inf (col_data )
231- is_invalid = pc .or_ (is_nan , is_inf )
232- null_val = pa .scalar (None , type = col_data .type )
233- col_data = pc .if_else (is_invalid , null_val , col_data )
234- sanitized_columns .append (col_data )
235- pa_table = pa .Table .from_arrays (sanitized_columns , names = pa_table .column_names )
236-
237- # Handle SliceNode post-processing
238- for node in rewritten_plan .unique_nodes ():
239- if isinstance (node , nodes .SliceNode ):
240- is_simple = (node .start is None or node .start >= 0 ) and (node .stop is None or node .stop >= 0 ) and (node .step is None or node .step == 1 )
241- if not is_simple :
242- df = pa_table .to_pandas ()
243- df = df .iloc [node .start :node .stop :node .step ]
244- pa_table = pa .Table .from_pandas (df , schema = pa_table .schema )
245- offset_cols = set ()
246- for node in rewritten_plan .unique_nodes ():
247- if isinstance (node , nodes .PromoteOffsetsNode ):
248- offset_cols .add (node .col_id .name )
249-
250- for col_name in pa_table .column_names :
251- if col_name in offset_cols :
252- idx = pa_table .column_names .index (col_name )
253- pa_table = pa_table .set_column (idx , col_name , pa .array (range (len (pa_table )), type = pa .int64 ()))
254-
255- import sys
256- sys .stderr .write (f"PA_TABLE ON EXECUTE:\n { pa_table .to_pandas ()} \n " )
257- sys .stderr .flush ()
143+ pa_table = await asyncio .to_thread (self ._consumer .consume , substrait_plan_proto , tables )
258144
259- if peek is not None :
145+ if peek is not None :
260146 pa_table = pa_table .slice (0 , peek )
261147
262148 return executor .LocalExecuteResult (
0 commit comments