Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 106 additions & 2 deletions datafusion/physical-plan/src/async_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::coalesce::LimitedBatchCoalescer;
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::stream::RecordBatchStreamAdapter;
use crate::{
Expand All @@ -24,16 +25,19 @@ use arrow::array::RecordBatch;
use arrow_schema::{Fields, Schema, SchemaRef};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::{Result, assert_eq_or_internal_err};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use futures::Stream;
use futures::stream::StreamExt;
use log::trace;
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};

/// This structure evaluates a set of async expressions on a record
/// batch producing a new record batch
Expand Down Expand Up @@ -188,7 +192,16 @@ impl ExecutionPlan for AsyncFuncExec {
let schema_captured = self.schema();
let config_options_ref = Arc::clone(context.session_config().options());

let stream_with_async_functions = input_stream.then(move |batch| {
let coalesced_input_stream = CoalesceInputStream {
input_stream,
batch_coalescer: LimitedBatchCoalescer::new(
Arc::clone(&self.input.schema()),
config_options_ref.execution.batch_size,
None,
),
};

let stream_with_async_functions = coalesced_input_stream.then(move |batch| {
// need to clone *again* to capture the async_exprs and schema in the
// stream and satisfy lifetime requirements.
let async_exprs_captured = Arc::clone(&async_exprs_captured);
Expand Down Expand Up @@ -221,6 +234,49 @@ impl ExecutionPlan for AsyncFuncExec {
}
}

struct CoalesceInputStream {
input_stream: Pin<Box<dyn RecordBatchStream + Send>>,
batch_coalescer: LimitedBatchCoalescer,
}

impl Stream for CoalesceInputStream {
type Item = Result<RecordBatch>;

fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let mut completed = false;

loop {
if let Some(batch) = self.batch_coalescer.next_completed_batch() {
return Poll::Ready(Some(Ok(batch)));
}

if completed {
return Poll::Ready(None);
}

match ready!(self.input_stream.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
if let Err(err) = self.batch_coalescer.push_batch(batch) {
return Poll::Ready(Some(Err(err)));
}
}
Some(err) => {
return Poll::Ready(Some(err));
}
None => {
completed = true;
if let Err(err) = self.batch_coalescer.finish() {
return Poll::Ready(Some(Err(err)));
}
}
}
}
}
}

const ASYNC_FN_PREFIX: &str = "__async_fn_";

/// Maps async_expressions to new columns
Expand Down Expand Up @@ -307,3 +363,51 @@ impl AsyncMapper {
Arc::new(Column::new(async_expr.name(), output_idx))
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::array::{RecordBatch, UInt32Array};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_execution::{TaskContext, config::SessionConfig};
use futures::StreamExt;

use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec};

#[tokio::test]
async fn test_async_fn_with_coalescing() -> Result<()> {
let schema =
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))],
)?;

let batches: Vec<RecordBatch> = (0..50).map(|_| batch.clone()).collect();

let session_config = SessionConfig::new().with_batch_size(200);
let task_ctx = TaskContext::default().with_session_config(session_config);
let task_ctx = Arc::new(task_ctx);

let test_exec =
TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?;
let exec = AsyncFuncExec::try_new(vec![], test_exec)?;

let mut stream = exec.execute(0, Arc::clone(&task_ctx))?;
let batch = stream
.next()
.await
.expect("expected to get a record batch")?;
assert_eq!(200, batch.num_rows());
let batch = stream
.next()
.await
.expect("expected to get a record batch")?;
assert_eq!(100, batch.num_rows());

Ok(())
}
}