Skip to content

Commit e34b4ab

Browse files
more work
1 parent a26aeba commit e34b4ab

4 files changed

Lines changed: 29 additions & 158 deletions

File tree

packages/bigframes/bigframes/core/compile/substrait/compiler.py

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -39,30 +39,13 @@ class SubstraitCompiler:
3939
"""
4040
Compiles BigFrameNode plans to Substrait schema (JSON representation).
4141
"""
42-
43-
def _print_node_tree(self, node: bigframe_node.BigFrameNode, indent: int = 0):
44-
import sys
45-
try:
46-
ids = list(node.ids)
47-
except Exception as e:
48-
ids = f"<error: {e}>"
49-
sys.stderr.write(" " * indent + f"- {type(node).__name__}: ids={ids}\n")
50-
sys.stderr.flush()
51-
for child in node.child_nodes:
52-
self._print_node_tree(child, indent + 1)
53-
5442
def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]:
5543
"""
5644
Compiles a BigFrameNode to Substrait bytes (JSON encoded via protobuf).
5745
"""
5846
if not self.can_compile(plan):
5947
return None
6048

61-
import sys
62-
sys.stderr.write("DEBUG TREE:\n")
63-
sys.stderr.flush()
64-
self._print_node_tree(plan)
65-
6649
pb_rel = self._compile_node(plan)
6750

6851
pb_plan = plan_pb2.Plan()
@@ -84,7 +67,6 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]:
8467
def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool:
8568
"""
8669
Checks if the plan can be compiled to Substrait.
87-
For the skeleton, we support ReadLocalNode, SelectionNode, and FilterNode.
8870
"""
8971
supported_nodes = (
9072
nodes.ReadLocalNode,
@@ -95,7 +77,6 @@ def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool:
9577
nodes.JoinNode,
9678
nodes.AggregateNode,
9779
nodes.OrderByNode,
98-
nodes.PromoteOffsetsNode,
9980
nodes.WindowOpNode,
10081
nodes.ConcatNode,
10182
)
@@ -123,8 +104,6 @@ def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel:
123104
return self._compile_orderby(node)
124105
elif isinstance(node, nodes.SliceNode):
125106
return self._compile_slice(node)
126-
elif isinstance(node, nodes.PromoteOffsetsNode):
127-
return self._compile_promote_offsets(node)
128107
elif isinstance(node, nodes.WindowOpNode):
129108
return self._compile_window(node)
130109
elif isinstance(node, nodes.ConcatNode):
@@ -180,22 +159,6 @@ def _compile_selection(self, node: nodes.SelectionNode) -> algebra_pb2.Rel:
180159

181160
return rel
182161

183-
def _compile_promote_offsets(self, node: nodes.PromoteOffsetsNode) -> algebra_pb2.Rel:
184-
input_rel = self._compile_node(node.child)
185-
186-
rel = algebra_pb2.Rel()
187-
project_rel = rel.project
188-
project_rel.input.CopyFrom(input_rel)
189-
190-
# Add a dummy literal i64 = 0 for the offsets column
191-
expr = project_rel.expressions.add()
192-
expr.literal.i64 = 0
193-
194-
child_ids = list(node.child.ids)
195-
project_rel.common.emit.output_mapping.extend(range(len(child_ids) + 1))
196-
197-
return rel
198-
199162
def _compile_filter(self, node: nodes.FilterNode) -> algebra_pb2.Rel:
200163
input_rel = self._compile_node(node.child)
201164

@@ -684,6 +647,10 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel:
684647
func_ref = self._EXTENSIONS["product"]
685648
elif isinstance(agg.op, agg_ops.MedianOp):
686649
func_ref = self._EXTENSIONS["median"]
650+
elif isinstance(agg.op, agg_ops.CovOp):
651+
func_ref = self._EXTENSIONS["cov"]
652+
elif isinstance(agg.op, agg_ops.CorrOp):
653+
func_ref = self._EXTENSIONS["corr"]
687654
else:
688655
raise NotImplementedError(f"Aggregation {type(agg.op)} not supported in Substrait compiler yet")
689656

@@ -846,6 +813,9 @@ def _compile_slice(self, node: nodes.SliceNode) -> algebra_pb2.Rel:
846813
"lead": 66,
847814
"struct": 67,
848815
"get_field": 68,
816+
"pow": 69,
817+
"cov": 70,
818+
"corr": 71,
849819
}
850820

851821
_OP_TO_EXTENSION = {
@@ -854,6 +824,8 @@ def _compile_slice(self, node: nodes.SliceNode) -> algebra_pb2.Rel:
854824
numeric_ops.MulOp: "multiply",
855825
numeric_ops.DivOp: "divide",
856826
numeric_ops.ModOp: "mod",
827+
numeric_ops.PowOp: "pow",
828+
numeric_ops.UnsafePowOp: "pow",
857829
comparison_ops.EqOp: "equal",
858830
comparison_ops.NeOp: "not_equal",
859831
comparison_ops.LtOp: "lt",
@@ -1156,6 +1128,8 @@ def _compile_fillna_op(self, op: generic_ops.FillNaOp, inputs: Sequence[ex.Expre
11561128
@_compile_op.register(numeric_ops.AddOp)
11571129
@_compile_op.register(numeric_ops.SubOp)
11581130
@_compile_op.register(numeric_ops.MulOp)
1131+
@_compile_op.register(numeric_ops.PowOp)
1132+
@_compile_op.register(numeric_ops.UnsafePowOp)
11591133
@_compile_op.register(comparison_ops.EqOp)
11601134
@_compile_op.register(comparison_ops.NeOp)
11611135
@_compile_op.register(comparison_ops.LtOp)

packages/bigframes/bigframes/session/substrait_executor.py

Lines changed: 4 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from bigframes.session import executor, semi_executor
2222
import bigframes.core.rewrite.slices as slices_rewrite
2323
from bigframes.core import nodes
24+
import asyncio
2425

2526
if TYPE_CHECKING:
2627
import pyarrow as pa
@@ -62,25 +63,14 @@ def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table:
6263
"Install it with `pip install datafusion`."
6364
)
6465

65-
# Create a DataFusion context
6666
ctx = datafusion.SessionContext()
6767

6868
for name, table in tables.items():
6969
df = ctx.from_arrow_table(table)
7070
ctx.register_table(name, df)
7171

72-
# NOTE: The actual API for running Substrait in DataFusion python bindings may vary.
73-
# Assuming something like ctx.from_substrait(plan) or ctx.execute_substrait(plan).
74-
# We will need to verify this with the actual datafusion python package if available.
75-
# For now, we raise NotImplementedError if we cannot find the method, or try a likely one.
76-
7772
import datafusion.substrait
7873

79-
import substrait.plan_pb2 as plan_pb2
80-
from google.protobuf import json_format
81-
plan_obj = plan_pb2.Plan.FromString(plan_proto)
82-
print("DEBUG PLAN JSON:")
83-
print(json_format.MessageToJson(plan_obj))
8474
datafusion_substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_proto)
8575
logical_plan = datafusion.substrait.Consumer.from_substrait_plan(ctx, datafusion_substrait_plan)
8676
df = ctx.create_dataframe_from_logical_plan(logical_plan)
@@ -98,62 +88,14 @@ def __init__(self, consumer: SubstraitConsumer):
9888
from bigframes.core.compile.substrait.compiler import SubstraitCompiler
9989
self._compiler = SubstraitCompiler()
10090

101-
def execute(
91+
async def execute(
10292
self,
10393
plan: bigframe_node.BigFrameNode,
10494
ordered: bool,
10595
peek: Optional[int] = None,
10696
) -> Optional[executor.ExecuteResult]:
107-
def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
108-
if isinstance(node, nodes.PromoteOffsetsNode):
109-
res = self.execute(node.child, ordered=ordered)
110-
if res is None:
111-
return node
112-
table = res.batches().to_arrow_table()
113-
import pyarrow as pa
114-
table = table.append_column(node.col_id.name, pa.array(range(len(table)), type=pa.int64()))
115-
116-
from bigframes.core import local_data, identifiers
117-
from bigframes.core.schema import ArraySchema, SchemaItem
118-
import bigframes.dtypes
119-
120-
schema_items = []
121-
for col_name in table.column_names:
122-
if col_name == node.col_id.name:
123-
schema_items.append(SchemaItem(col_name, bigframes.dtypes.INT_DTYPE))
124-
else:
125-
schema_items.append(SchemaItem(col_name, node.child.schema.get_type(col_name)))
126-
new_schema = ArraySchema(tuple(schema_items))
127-
128-
scan_items = []
129-
for col_name in table.column_names:
130-
col_id = identifiers.ColumnId(col_name)
131-
scan_items.append(nodes.ScanItem(col_id, col_name))
132-
scan_list = nodes.ScanList(tuple(scan_items))
133-
134-
session = None
135-
for child_node in node.child.unique_nodes():
136-
if isinstance(child_node, nodes.ReadLocalNode):
137-
session = child_node.session
138-
break
139-
140-
managed_table = local_data.ManagedArrowTable.from_pyarrow(table, schema=new_schema)
141-
new_node = nodes.ReadLocalNode(
142-
local_data_source=managed_table,
143-
scan_list=scan_list,
144-
session=session,
145-
offsets_col=None,
146-
)
147-
return new_node
148-
return node
149-
150-
# 1. Rewrite all SliceNodes to standard Selection/Filter/Projection/PromoteOffsetsNodes
15197
plan = plan.bottom_up(slices_rewrite.rewrite_slice)
15298

153-
# 2. Resolve all PromoteOffsetsNodes to concrete local tables
154-
plan = plan.bottom_up(resolve_promote_offsets)
155-
156-
# 3. Wrap plan in a ResultNode to apply defer_order
15799
from bigframes.core import expression, rewrite
158100
output_cols = tuple((expression.DerefOp(id), id.name) for id in plan.ids)
159101
result_node = nodes.ResultNode(
@@ -166,14 +108,12 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B
166108

167109
rewritten_plan = result_node.child
168110

169-
# 4. Apply outermost sorting if ordered
170111
if ordered and result_node.order_by and result_node.order_by.all_ordering_columns:
171112
rewritten_plan = nodes.OrderByNode(
172113
rewritten_plan,
173114
by=tuple(result_node.order_by.all_ordering_columns),
174115
)
175116

176-
# 5. Project only the original output columns to preserve correct result schema
177117
original_ids = tuple(id for id in plan.ids)
178118
if rewritten_plan.ids != original_ids:
179119
rewritten_plan = nodes.SelectionNode(
@@ -188,17 +128,6 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B
188128
if substrait_plan_proto is None:
189129
return None
190130

191-
import google.protobuf.json_format as json_format
192-
from substrait.plan_pb2 import Plan
193-
plan_proto = Plan()
194-
plan_proto.ParseFromString(substrait_plan_proto)
195-
import os
196-
import uuid
197-
os.makedirs("/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch", exist_ok=True)
198-
filename = f"/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch/plan_{rewritten_plan.__class__.__name__}_{uuid.uuid4().hex[:8]}.json"
199-
with open(filename, "w") as f:
200-
f.write(json_format.MessageToJson(plan_proto))
201-
202131
tables = {}
203132
for node in rewritten_plan.unique_nodes():
204133
if isinstance(node, nodes.ReadLocalNode):
@@ -211,52 +140,9 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B
211140
table = pyarrow_utils.append_offsets(table, node.offsets_col.sql)
212141
tables[table_name] = table
213142

214-
pa_table = self._consumer.consume(substrait_plan_proto, tables)
215-
216-
# Sanitize pa_table: replace inf/nan/is_inf with null for INT_DTYPE columns
217-
import pyarrow.compute as pc
218-
import bigframes.dtypes as dtypes
219-
import pyarrow as pa
220-
sanitized_columns = []
221-
for col_name in pa_table.column_names:
222-
col_data = pa_table.column(col_name)
223-
try:
224-
expected_dtype = rewritten_plan.schema.get_type(col_name)
225-
except ValueError:
226-
expected_dtype = None
227-
228-
if expected_dtype == dtypes.INT_DTYPE and pa.types.is_floating(col_data.type):
229-
is_nan = pc.is_nan(col_data)
230-
is_inf = pc.is_inf(col_data)
231-
is_invalid = pc.or_(is_nan, is_inf)
232-
null_val = pa.scalar(None, type=col_data.type)
233-
col_data = pc.if_else(is_invalid, null_val, col_data)
234-
sanitized_columns.append(col_data)
235-
pa_table = pa.Table.from_arrays(sanitized_columns, names=pa_table.column_names)
236-
237-
# Handle SliceNode post-processing
238-
for node in rewritten_plan.unique_nodes():
239-
if isinstance(node, nodes.SliceNode):
240-
is_simple = (node.start is None or node.start >= 0) and (node.stop is None or node.stop >= 0) and (node.step is None or node.step == 1)
241-
if not is_simple:
242-
df = pa_table.to_pandas()
243-
df = df.iloc[node.start:node.stop:node.step]
244-
pa_table = pa.Table.from_pandas(df, schema=pa_table.schema)
245-
offset_cols = set()
246-
for node in rewritten_plan.unique_nodes():
247-
if isinstance(node, nodes.PromoteOffsetsNode):
248-
offset_cols.add(node.col_id.name)
249-
250-
for col_name in pa_table.column_names:
251-
if col_name in offset_cols:
252-
idx = pa_table.column_names.index(col_name)
253-
pa_table = pa_table.set_column(idx, col_name, pa.array(range(len(pa_table)), type=pa.int64()))
254-
255-
import sys
256-
sys.stderr.write(f"PA_TABLE ON EXECUTE:\n{pa_table.to_pandas()}\n")
257-
sys.stderr.flush()
143+
pa_table = await asyncio.to_thread(self._consumer.consume, substrait_plan_proto, tables)
258144

259-
if peek is not None:
145+
if peek is not None:
260146
pa_table = pa_table.slice(0, peek)
261147

262148
return executor.LocalExecuteResult(

packages/bigframes/tests/system/small/engines/conftest.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
local_scan_executor,
2727
polars_executor,
2828
semi_executor,
29+
substrait_executor,
2930
)
3031

3132
CURRENT_DIR = pathlib.Path(__file__).parent
@@ -81,9 +82,17 @@ def sqlglot_engine(
8182
)
8283

8384

84-
@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"])
85+
@pytest.fixture(scope="session")
86+
def substrait_datafusion_engine(
87+
) -> semi_executor.SemiExecutor:
88+
return substrait_executor.SubstraitExecutor(
89+
consumer = substrait_executor.DataFusionSubstraitConsumer()
90+
)
91+
92+
93+
@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot", "substrait-datafusion"])
8594
def engine(
86-
request, pyarrow_engine, polars_engine, bq_engine, sqlglot_engine
95+
request, pyarrow_engine, polars_engine, bq_engine, sqlglot_engine, substrait_datafusion_engine
8796
) -> semi_executor.SemiExecutor:
8897
if request.param == "pyarrow":
8998
return pyarrow_engine
@@ -93,6 +102,8 @@ def engine(
93102
return bq_engine
94103
if request.param == "bq-sqlglot":
95104
return sqlglot_engine
105+
if request.param == "substrait-datafusion":
106+
return substrait_datafusion_engine
96107
raise ValueError(f"Unrecognized param: {request.param}")
97108

98109

packages/bigframes/tests/system/small/engines/test_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def apply_agg_to_all_valid(
5555
return new_arr
5656

5757

58-
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
58+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot", "substrait-datafusion"], indirect=True)
5959
def test_engines_aggregate_post_filter_size(
6060
scalars_array_value: array_value.ArrayValue,
6161
engine,

0 commit comments

Comments
 (0)