Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ harness = false
name = "replace"
required-features = ["string_expressions"]

[[bench]]
harness = false
name = "overlay"

[[bench]]
harness = false
name = "random"
Expand Down
68 changes: 68 additions & 0 deletions datafusion/functions/benches/overlay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

mod helper;

use arrow::datatypes::{DataType, Field};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use helper::gen_string_array;
use std::hint::black_box;
use std::sync::Arc;

/// Registers a Criterion benchmark that repeatedly invokes `overlay` on a
/// generated string argument combined with scalar replacement, start and
/// length operands.
fn criterion_benchmark(c: &mut Criterion) {
    const N_ROWS: usize = 8192;
    const STR_LEN: usize = 128;

    let config_options = Arc::new(ConfigOptions::default());
    let overlay = datafusion_functions::core::overlay();

    // Generated string argument(s) first, followed by the three scalar
    // operands: replacement text, start position and replacement length.
    let mut args = gen_string_array(N_ROWS, STR_LEN, 0.1, 0.5, false);
    args.extend([
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("DataFusion".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Int64(Some(32))),
        ColumnarValue::Scalar(ScalarValue::Int64(Some(8))),
    ]);

    // One nullable field per argument, named after its position.
    let mut arg_fields = Vec::with_capacity(args.len());
    for (idx, arg) in args.iter().enumerate() {
        arg_fields.push(Field::new(format!("arg_{idx}"), arg.data_type(), true).into());
    }
    let return_field = Arc::new(Field::new("f", DataType::Utf8, true));

    c.bench_function("overlay_StringArray_utf8_scalar_args", |b| {
        b.iter(|| {
            // The argument struct is rebuilt on every iteration because
            // `invoke_with_args` takes it by value.
            let call_args = ScalarFunctionArgs {
                args: args.clone(),
                arg_fields: arg_fields.clone(),
                number_rows: N_ROWS,
                return_field: Arc::clone(&return_field),
                config_options: Arc::clone(&config_options),
            };
            black_box(overlay.invoke_with_args(call_args).unwrap())
        })
    });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
252 changes: 142 additions & 110 deletions datafusion/functions/src/core/overlay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,140 +112,170 @@ impl ScalarUDFImpl for OverlayFunc {
}
}

/// Maps a 0-based character index to the byte offset at which that character
/// starts, so the result can be used to slice `string`.
///
/// When `is_ascii` is true every character is assumed to occupy one byte and
/// the index is simply clamped to the byte length. Otherwise the string is
/// walked character by character. In both cases an index at or beyond the end
/// of the string yields `string.len()`.
fn byte_index_for_char(string: &str, char_idx: usize, is_ascii: bool) -> usize {
    let total_bytes = string.len();

    // Fast path: one byte per character, so the index is already a byte offset.
    if is_ascii {
        return std::cmp::min(char_idx, total_bytes);
    }

    match string.char_indices().nth(char_idx) {
        Some((offset, _)) => offset,
        // Ran off the end of the string: clamp to its byte length.
        None => total_bytes,
    }
}

/// Builds the OVERLAY result for a single (non-null) row.
///
/// The result is the characters of `string` that precede `start_pos`,
/// followed by `characters`, followed by the characters of `string` starting
/// `replace_len` characters after `start_pos`.
///
/// # Arguments
///
/// * `string` - the value being overlaid.
/// * `characters` - the replacement text inserted at `start_pos`.
/// * `start_pos` - 1-based character position. Callers must reject values
///   below 1 before calling; this is only checked via `debug_assert!`.
/// * `replace_len` - number of characters of `string` to replace. If the
///   replaced span runs past the end of `string` the suffix is empty; if it is
///   zero or negative the suffix starts at or before `start_pos`, so part of
///   `string` appears on both sides of `characters`.
fn overlay_one(
    string: &str,
    characters: &str,
    start_pos: i64,
    replace_len: i64,
) -> String {
    debug_assert!(start_pos >= 1);

    let is_ascii = string.is_ascii();

    // Convert SQL's 1-based character position into 0-based character indexes.
    // `start_char_idx` is the first replaced character; `end_char_idx` is the
    // first character after the replaced span. The addition saturates so an
    // enormous `replace_len` cannot overflow.
    let start_char_idx = start_pos - 1;
    let end_char_idx = start_char_idx.saturating_add(replace_len);

    // No upper-bound check on either index: `byte_index_for_char` clamps an
    // index at or past the end to `string.len()`. For the prefix that yields
    // the whole string ("insert past end" semantics); for the suffix it yields
    // an empty slice. This removes the need for a separate character count of
    // `string`, which would be an extra full scan on non-ASCII input.
    let prefix_char_idx = usize::try_from(start_char_idx).unwrap_or(usize::MAX);
    let prefix_end_byte = byte_index_for_char(string, prefix_char_idx, is_ascii);

    // A negative `end_char_idx` (negative `replace_len`) means the suffix
    // starts at the very beginning of `string`.
    let suffix_char_idx = usize::try_from(end_char_idx.max(0)).unwrap_or(usize::MAX);
    let suffix_start_byte = byte_index_for_char(string, suffix_char_idx, is_ascii);

    let mut res = String::with_capacity(string.len() + characters.len());
    res.push_str(&string[..prefix_end_byte]);
    res.push_str(characters);
    res.push_str(&string[suffix_start_byte..]);
    res
}

macro_rules! process_overlay {
// For the three-argument case
($string_array:expr, $characters_array:expr, $pos_num:expr) => {{
// Three argument case
($string_array:expr, $characters_array:expr, $pos_array:expr) => {{
$string_array
.iter()
.zip($characters_array.iter())
.zip($pos_num.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = characters_len as i64;
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
.iter()
.zip($characters_array.iter())
.zip($pos_array.iter())
.map(|((string, characters), start_pos)| {
match (string, characters, start_pos) {
(Some(string), Some(characters), Some(start_pos)) => {
if start_pos < 1 {
return exec_err!("negative substring length not allowed");
}
let replace_len = characters.chars().count() as i64;
Ok(Some(overlay_one(
string,
characters,
start_pos,
replace_len,
)))
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
}
Ok(Some(res))
_ => Ok(None),
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()
})
.collect::<Result<GenericStringArray<T>>>()
}};

// For the four-argument case
($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{
// Four argument case
($string_array:expr, $characters_array:expr, $pos_array:expr, $len_array:expr) => {{
$string_array
.iter()
.zip($characters_array.iter())
.zip($pos_num.iter())
.zip($len_num.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
let string_len = string.chars().count();
let characters_len = characters.chars().count();
let replace_len = len.min(string_len as i64);
let mut res =
String::with_capacity(string_len.max(characters_len));

//as sql replace index start from 1 while string index start from 0
if start_pos > 1 && start_pos - 1 < string_len as i64 {
let start = (start_pos - 1) as usize;
res.push_str(&string[..start]);
}
res.push_str(characters);
// if start + replace_len - 1 >= string_length, just to string end
if start_pos + replace_len - 1 < string_len as i64 {
let end = (start_pos + replace_len - 1) as usize;
res.push_str(&string[end..]);
.iter()
.zip($characters_array.iter())
.zip($pos_array.iter())
.zip($len_array.iter())
.map(|(((string, characters), start_pos), len)| {
match (string, characters, start_pos, len) {
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
if start_pos < 1 {
return exec_err!("negative substring length not allowed");
}
let string_char_len = string.chars().count() as i64;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small optimization idea: the four-argument path calls string.chars().count() to clamp len, but overlay_one ends up computing the character length again internally. Since overlay_one already handles lengths that extend past the end of the string, it might be simpler to pass len through directly or move the clamp logic into overlay_one so we avoid the extra Unicode scan per row.

let replace_len = len.min(string_char_len);
Ok(Some(overlay_one(
string,
characters,
start_pos,
replace_len,
)))
}
Ok(Some(res))
_ => Ok(None),
}
_ => Ok(None),
}
})
.collect::<Result<GenericStringArray<T>>>()
})
.collect::<Result<GenericStringArray<T>>>()
}};
}

/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
/// Replaces a substring of string1 with string2 starting at the integer bit
/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead
/// `OVERLAY(string PLACING substring FROM start [FOR count])`
///
/// Replaces a region of `string` with `substring`, starting at the 1-based
/// character position `start`. If `count` is supplied, that many characters
/// of `string` are replaced; otherwise `count` defaults to the character
/// length of `substring`.
///
/// ```text
/// overlay('Txxxxas' placing 'hom' from 2 for 4) → 'Thomas'
/// overlay('Txxxxas' placing 'hom' from 2) → 'Thomxas'
/// ```
fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let use_string_view = args[0].data_type() == &DataType::Utf8View;
if use_string_view {
if !matches!(args.len(), 3 | 4) {
return exec_err!(
"overlay was called with {} arguments. It requires 3 or 4.",
args.len()
);
}
if args[0].data_type() == &DataType::Utf8View {
string_view_overlay::<T>(args)
} else {
string_overlay::<T>(args)
}
}

fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_array = as_int64_array(&args[2])?;

let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let characters_array = as_generic_string_array::<T>(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;

let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
}
}
let result = if args.len() == 4 {
let len_array = as_int64_array(&args[3])?;
process_overlay!(string_array, characters_array, pos_array, len_array)?
} else {
process_overlay!(string_array, characters_array, pos_array)?
};
Ok(Arc::new(result) as ArrayRef)
}

fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
3 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;

let result = process_overlay!(string_array, characters_array, pos_num)?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_num = as_int64_array(&args[2])?;
let len_num = as_int64_array(&args[3])?;
let string_array = as_string_view_array(&args[0])?;
let characters_array = as_string_view_array(&args[1])?;
let pos_array = as_int64_array(&args[2])?;

let result =
process_overlay!(string_array, characters_array, pos_num, len_num)?;
Ok(Arc::new(result) as ArrayRef)
}
other => {
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
}
}
let result = if args.len() == 4 {
let len_array = as_int64_array(&args[3])?;
process_overlay!(string_array, characters_array, pos_array, len_array)?
} else {
process_overlay!(string_array, characters_array, pos_array)?
};
Ok(Arc::new(result) as ArrayRef)
}

#[cfg(test)]
Expand All @@ -265,7 +295,9 @@ mod tests {

let res = overlay::<i32>(&[string, replace_string, start, end]).unwrap();
let result = as_generic_string_array::<i32>(&res).unwrap();
let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]);
// First row: start=4 is past the end of "123" (len 3). PostgreSQL
// takes the whole string as prefix and appends the replacement.
let expected = StringArray::from(vec!["123abc", "qwertyasdfg", "ijkz", "Thomas"]);
assert_eq!(&expected, result);

Ok(())
Expand Down
Loading
Loading