1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -29,6 +29,7 @@ env:
jobs:
test-matrix:
runs-on: ubuntu-latest
timeout-minutes: 30
Contributor Author

Right now things pass in 4 minutes, but I accidentally committed a change that deadlocked and would have run for 6 hours if I hadn't pushed new commits. Seems like a nice guard against footguns.

strategy:
fail-fast: false
matrix:
29 changes: 23 additions & 6 deletions conftest.py
@@ -17,12 +17,29 @@

"""Pytest configuration for doctest namespace injection."""

import datafusion as dfn
import numpy as np
import pyarrow as pa
import pytest
from datafusion import col, lit
from datafusion import functions as F
import sys
from pathlib import Path

# Ensure ``python/`` is reachable by ``import``. The ``tests`` package lives at
# ``python/tests`` and spawn-based multiprocessing tests need workers to be
# able to resolve ``tests._pickle_multiprocessing_helpers`` by its real dotted
# name when unpickling task args. Editable installs add this path via a .pth
# file, but the wheel install used in CI does not, which led to spawn workers
# dying with ``ModuleNotFoundError`` and ``Pool.map`` hanging.
#
# Append (don't prepend) so the wheel-installed ``datafusion`` in
# site-packages still wins over the source tree at ``python/datafusion`` —
# the source tree has no compiled ``_internal`` module on a fresh checkout.
_python_dir = str(Path(__file__).parent / "python")
if _python_dir not in sys.path:
sys.path.append(_python_dir)

import datafusion as dfn # noqa: E402
import numpy as np # noqa: E402
import pyarrow as pa # noqa: E402
import pytest # noqa: E402
from datafusion import col, lit # noqa: E402
from datafusion import functions as F # noqa: E402


@pytest.fixture(autouse=True)
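Since pickle serializes functions by reference — the defining module's dotted name plus the qualified name, not bytecode — a spawn worker must be able to import that module itself, which is exactly what the ``sys.path`` append above guarantees. A minimal stdlib-only sketch of that mechanic (illustrative, not part of this PR):

import multiprocessing as mp
import pickle

def main() -> None:
    # Pickle stores "builtins" + "len", not the function's code; loading
    # re-imports the module and looks the name up again.
    payload = pickle.dumps(len)
    assert pickle.loads(payload) is len
    # A spawn Pool pickles task functions the same way. `len` resolves in
    # any worker; a function from `tests.foo` would die with
    # ModuleNotFoundError unless `tests` is importable in the child.
    with mp.get_context("spawn").Pool(1) as pool:
        print(pool.map(len, [[1, 2], [3]]))  # [2, 1]

if __name__ == "__main__":
    main()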
15 changes: 13 additions & 2 deletions crates/core/src/context.rs
@@ -56,7 +56,7 @@ use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
use datafusion_python_util::{
create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx,
get_tokio_runtime, spawn_future, wait_for_future,
get_tokio_runtime, set_global_ctx, spawn_future, wait_for_future,
};
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
@@ -407,11 +407,22 @@ impl PySessionContext {
#[staticmethod]
#[pyo3(signature = ())]
pub fn global_ctx() -> PyResult<Self> {
let ctx = get_global_ctx().clone();
let ctx = get_global_ctx();
let logical_codec = Self::default_logical_codec(&ctx);
Ok(Self { ctx, logical_codec })
}

/// Replace the process-wide global `SessionContext` with this one.
///
/// All subsequent callers of `SessionContext.global_ctx()` (and Rust
/// helpers that fall back to the global context, such as the
/// `read_parquet` / `read_csv` / etc. module-level helpers) will see this
/// context. Existing references already obtained from `global_ctx()` are
/// not affected.
pub fn set_as_global(&self) {
set_global_ctx(self.ctx.clone());
}

/// Register an object store with the given name
#[pyo3(signature = (scheme, store, host=None))]
pub fn register_object_store(
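A hedged sketch of the intended Python-level call pattern for the new method (the Python wrapper appears later in this diff; register_record_batches and table_exist are existing SessionContext APIs, used here only for illustration):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], names=["a"])
ctx.register_record_batches("t", [[batch]])

ctx.set_as_global()  # install as the process-wide global context
# Contexts fetched from the global slot afterwards see the same registrations.
assert SessionContext.global_ctx().table_exist("t")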
2 changes: 1 addition & 1 deletion crates/core/src/dataframe.rs
@@ -851,7 +851,7 @@ impl PyDataFrame {
Some(f) => f
.parse::<datafusion::common::format::ExplainFormat>()
.map_err(|e| {
PyDataFusionError::Common(format!("Invalid explain format '{}': {}", f, e))
PyDataFusionError::Common(format!("Invalid explain format '{f}': {e}"))
})?,
None => datafusion::common::format::ExplainFormat::Indent,
};
26 changes: 26 additions & 0 deletions crates/core/src/expr.rs
@@ -31,9 +31,12 @@ use datafusion::logical_expr::{
Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder, ExprFunctionExt, Like, LogicalPlan,
Operator, TryCast, WindowFunctionDefinition, col, lit, lit_with_metadata,
};
use datafusion_proto::bytes::Serializeable;
use datafusion_python_util::get_global_ctx;
use pyo3::IntoPyObjectExt;
use pyo3::basic::CompareOp;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use window::PyWindowFrame;

use self::alias::PyAlias;
@@ -256,6 +259,29 @@ impl PyExpr {
Ok(format!("Expr({})", self.expr))
}

/// Serialize the underlying expression to bytes via the `datafusion-proto`
/// wire format. Used by the Python `Expr` wrapper to implement
/// `__getstate__` / `__setstate__`; also exposed directly so callers can
/// persist or transmit expressions without going through `pickle`.
fn to_bytes<'py>(&self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
let bytes = self.expr.to_bytes()?;
Ok(PyBytes::new(py, &bytes))
}

/// Reconstruct a `RawExpr` from bytes produced by [`PyExpr::to_bytes`].
///
/// Function references (built-ins, UDFs, UDAFs, UDWFs) are resolved by
/// name against the process-wide global `SessionContext`. Built-in
/// functions are registered on every fresh context, so they always
/// roundtrip. To roundtrip user-defined functions, register them on a
/// context and call `SessionContext.set_as_global()` before unpickling.
#[staticmethod]
fn from_bytes(bytes: &[u8]) -> PyDataFusionResult<PyExpr> {
let ctx = get_global_ctx();
let expr = Expr::from_bytes_with_registry(bytes, ctx.as_ref())?;
Ok(expr.into())
}

fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
Ok((self.expr.clone() + rhs.expr).into())
}
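For orientation, a sketch of the round trip these two methods enable, driven through the Python wrapper added later in this diff (built-in operators resolve against any fresh context, so no registration is needed here):

import pickle
from datafusion import col, lit

e = col("a") + lit(10)
# pickle routes through __getstate__/__setstate__, i.e. to_bytes/from_bytes.
restored = pickle.loads(pickle.dumps(e))
assert restored.to_bytes() == e.to_bytes()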
65 changes: 60 additions & 5 deletions crates/util/src/lib.rs
@@ -17,7 +17,7 @@

use std::future::Future;
use std::ptr::NonNull;
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::Duration;

use datafusion::datasource::TableProvider;
@@ -59,11 +59,29 @@ pub fn is_ipython_env(py: Python) -> &'static bool {
})
}

/// Utility to get the Global Datafussion CTX
fn global_ctx_slot() -> &'static RwLock<Arc<SessionContext>> {
static CTX: OnceLock<RwLock<Arc<SessionContext>>> = OnceLock::new();
CTX.get_or_init(|| RwLock::new(Arc::new(SessionContext::new())))
}

/// Utility to get the Global DataFusion CTX.
///
/// Returns an owned `Arc<SessionContext>` snapshot. The underlying slot can be
/// replaced via [`set_global_ctx`]; existing snapshots are unaffected.
#[inline]
pub fn get_global_ctx() -> &'static Arc<SessionContext> {
static CTX: OnceLock<Arc<SessionContext>> = OnceLock::new();
CTX.get_or_init(|| Arc::new(SessionContext::new()))
pub fn get_global_ctx() -> Arc<SessionContext> {
global_ctx_slot()
.read()
.expect("global SessionContext lock poisoned")
.clone()
}

/// Replace the Global DataFusion CTX. Subsequent calls to [`get_global_ctx`]
/// will return the new context. Already-cloned `Arc`s are not affected.
pub fn set_global_ctx(ctx: Arc<SessionContext>) {
*global_ctx_slot()
.write()
.expect("global SessionContext lock poisoned") = ctx;
}

/// Utility to collect rust futures with GIL released and respond to
Expand Down Expand Up @@ -224,3 +242,40 @@ pub fn ffi_logical_codec_from_pycapsule(obj: Bound<PyAny>) -> PyResult<FFI_Logic

Ok(codec.clone())
}

#[cfg(test)]
mod tests {
use super::*;

/// The global slot must round-trip a custom `SessionContext`. Since the
/// global is process-wide, this test only asserts identity through a
/// single set/get cycle and restores the prior value at the end so the
/// test is independent of ordering with other tests in the binary.
#[test]
fn set_global_ctx_replaces_default() {
let prior = get_global_ctx();
let custom = Arc::new(SessionContext::new());
let custom_ptr = Arc::as_ptr(&custom);

set_global_ctx(custom.clone());
let observed = get_global_ctx();
assert_eq!(
Arc::as_ptr(&observed),
custom_ptr,
"get_global_ctx should return the context installed by set_global_ctx",
);

// A snapshot taken before the swap should be unaffected after another
// set_global_ctx call, because get_global_ctx clones the Arc.
let snapshot = get_global_ctx();
let replacement = Arc::new(SessionContext::new());
set_global_ctx(replacement);
assert_eq!(
Arc::as_ptr(&snapshot),
custom_ptr,
"previously cloned snapshots must not be invalidated by set_global_ctx",
);

set_global_ctx(prior);
}
}
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -188,6 +188,7 @@ ignore-words-list = ["IST", "ans"]
dev = [
"arro3-core==0.6.5",
"codespell==2.4.1",
"dill>=0.3.8",
"maturin>=1.8.1",
"nanoarrow==0.8.0",
"numpy>1.25.0;python_version<'3.14'",
@@ -196,6 +197,7 @@ dev = [
"pyarrow>=19.0.0",
"pygithub==2.5.0",
"pytest-asyncio>=0.23.3",
"pytest-timeout>=2.3.1",
"pytest>=7.4.4",
"pyyaml>=6.0.3",
"ruff>=0.15.1",
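The two new dev dependencies pair with the test changes elsewhere in this PR: dill exercises the new Expr serialization through a second pickle implementation, and pytest-timeout turns a hung Pool.map into a prompt test failure rather than a multi-hour CI stall. A hedged sketch of how a test could use both (the test name and timeout value are illustrative, not from this diff):

import dill
import pytest
from datafusion import col, lit

@pytest.mark.timeout(30)  # pytest-timeout: abort instead of hanging forever
def test_expr_roundtrips_through_dill():
    e = col("a") + lit(1)
    restored = dill.loads(dill.dumps(e))  # dill honors __getstate__/__setstate__
    assert restored.to_bytes() == e.to_bytes()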
16 changes: 16 additions & 0 deletions python/datafusion/context.py
@@ -566,6 +566,22 @@ def global_ctx(cls) -> SessionContext:
wrapper.ctx = internal_ctx
return wrapper

def set_as_global(self) -> None:
"""Install this context as the process-wide global ``SessionContext``.

After this call, :meth:`SessionContext.global_ctx` (and the module-level
helpers in :mod:`datafusion.io` that fall back to the global context)
will return this context. Existing references already obtained from
``global_ctx()`` are not invalidated.

Example::

ctx = SessionContext()
ctx.register_udf(my_udf)
ctx.set_as_global()
"""
self.ctx.set_as_global()

def enable_url_table(self) -> SessionContext:
"""Control if local files can be queried as tables.

25 changes: 25 additions & 0 deletions python/datafusion/expr.py
@@ -410,6 +410,31 @@ def __init__(self, expr: expr_internal.RawExpr) -> None:
"""This constructor should not be called by the end user."""
self.expr = expr

def to_bytes(self) -> bytes:
"""Serialize this expression to bytes via the ``datafusion-proto`` wire format.

Function references (built-ins and UDFs/UDAFs/UDWFs) are encoded by
name; on :py:meth:`from_bytes` the names are resolved against the
process-wide global :py:class:`SessionContext`. Built-in functions
always roundtrip; for user-defined functions, register them on a
context and call :py:meth:`SessionContext.set_as_global` before
loading.
"""
return self.expr.to_bytes()

@classmethod
def from_bytes(cls, data: bytes) -> Expr:
"""Inverse of :py:meth:`to_bytes`. See that method for caveats."""
return cls(expr_internal.RawExpr.from_bytes(data))

def __getstate__(self) -> bytes:
"""Serialize for ``pickle`` / ``dill``. Delegates to :py:meth:`to_bytes`."""
return self.to_bytes()

def __setstate__(self, state: bytes) -> None:
"""Inverse of :py:meth:`__getstate__`."""
self.expr = expr_internal.RawExpr.from_bytes(state)

def to_variant(self) -> Any:
"""Convert this expression into a python object if possible."""
return self.expr.to_variant()
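A hedged sketch of the UDF caveat these docstrings describe; the add_one UDF is illustrative, not from this diff:

import pickle
import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext, col, udf

add_one = udf(
    lambda arr: pc.add(arr, 1),
    [pa.int64()],
    pa.int64(),
    volatility="immutable",
    name="add_one",
)
payload = pickle.dumps(add_one(col("a")))  # serializing needs no registry

ctx = SessionContext()
ctx.register_udf(add_one)
ctx.set_as_global()  # make "add_one" resolvable by name
restored = pickle.loads(payload)  # from_bytes can now look the UDF up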
85 changes: 85 additions & 0 deletions python/tests/_pickle_multiprocessing_helpers.py
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Helpers for :mod:`test_pickle_multiprocessing`.

Spawn workers re-import the module that defines a pickled function by the
function's ``__module__`` attribute. Pytest's ``--import-mode=importlib``
loads test modules under synthetic names that the worker cannot resolve via
the normal import machinery, which can cause ``Pool.map`` to hang waiting
for a worker that died during unpickling.

Keeping the helpers in this regular (non-test) module side-steps that: it
is importable under its real dotted name (``tests._pickle_multiprocessing_helpers``)
in both parent and worker, and the leading underscore keeps pytest from
collecting it as a test module.
"""

from __future__ import annotations

import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext, udf

UDF_NAME = "mp_pickle_add_ten"


def add_ten_impl(array: pa.Array) -> pa.Array:
return pc.add(array, 10)


def build_add_ten_udf():
return udf(
add_ten_impl,
[pa.int64()],
pa.int64(),
volatility="immutable",
name=UDF_NAME,
)


def register_udf_on_global_ctx() -> None:
"""Pool initializer: install a global ctx in the worker that knows the UDF.

``Expr.__setstate__`` resolves UDF references by name against the
*global* context, so the registration must happen before any task arg is
unpickled — i.e. in the Pool's ``initializer``, not in the task body.
"""
ctx = SessionContext()
ctx.register_udf(build_add_ten_udf())
ctx.set_as_global()


def apply_builtin_expr(args: tuple) -> list:
expr, values = args
ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array(values, type=pa.int64())], names=["a"])
df = ctx.create_dataframe([[batch]], name="t")
return df.select(expr.alias("out")).collect()[0].column(0).to_pylist()


def apply_udf_expr(args: tuple) -> list:
expr, values = args
# Reuse the worker's global ctx so the UDF registered by the initializer
# is visible during execution as well as during arg unpickling. Omit the
# table name so each call gets a fresh auto-generated one — a worker may
# process multiple tasks, and reusing a fixed name on the shared ctx would
# collide on the second call.
ctx = SessionContext.global_ctx()
batch = pa.RecordBatch.from_arrays([pa.array(values, type=pa.int64())], names=["a"])
df = ctx.create_dataframe([[batch]])
return df.select(expr.alias("out")).collect()[0].column(0).to_pylist()
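For context, a hedged sketch of how the driving test module (test_pickle_multiprocessing, not shown in this excerpt) would use these helpers, mirroring their docstrings:

import multiprocessing as mp

from datafusion import col
from tests._pickle_multiprocessing_helpers import (
    apply_udf_expr,
    build_add_ten_udf,
    register_udf_on_global_ctx,
)

def run() -> None:
    expr = build_add_ten_udf()(col("a"))  # Expr referencing the UDF by name
    spawn = mp.get_context("spawn")
    with spawn.Pool(2, initializer=register_udf_on_global_ctx) as pool:
        # Pool pickles each (expr, values) arg; the initializer has already
        # installed a worker-global ctx that can resolve "mp_pickle_add_ten".
        out = pool.map(apply_udf_expr, [(expr, [1, 2]), (expr, [3])])
    assert out == [[11, 12], [13]]

if __name__ == "__main__":
    run()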