diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index eeffb28..537f6ef 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -10,7 +10,7 @@ jobs:
       fail-fast: false
       matrix:
         os:
-          - ubuntu-20.04
+          - ubuntu-24.04
         toolchain:
           - 1.68.0

diff --git a/.gitignore b/.gitignore
index 96ef6c0..77147e2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 /target
 Cargo.lock
+.idea/
diff --git a/convergence-arrow/Cargo.toml b/convergence-arrow/Cargo.toml
index 3d483b5..8b5fa2a 100644
--- a/convergence-arrow/Cargo.toml
+++ b/convergence-arrow/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "convergence-arrow"
-version = "0.16.0"
+version = "0.17.0"
 authors = ["Ruan Pearce-Authers "]
 edition = "2018"
 description = "Utils for bridging Apache Arrow and PostgreSQL's wire protocol"
@@ -10,9 +10,8 @@ repository = "https://github.com/returnString/convergence"
 [dependencies]
 tokio = { version = "1" }
 async-trait = "0.1"
-datafusion = "38"
-convergence = { path = "../convergence", version = "0.16.0" }
-chrono = "0.4"
-
-[dev-dependencies]
+datafusion = "43"
+convergence = { path = "../convergence", version = "0.17.0" }
+chrono = "=0.4.39"
 tokio-postgres = { version = "0.7", features = [ "with-chrono-0_4" ] }
+rust_decimal = { version = "1.36.0", features = ["default", "db-postgres"] }
diff --git a/convergence-arrow/examples/datafusion.rs b/convergence-arrow/examples/datafusion.rs
index aa4591e..d1c40f3 100644
--- a/convergence-arrow/examples/datafusion.rs
+++ b/convergence-arrow/examples/datafusion.rs
@@ -2,8 +2,9 @@ use convergence::server::{self, BindOptions};
 use convergence_arrow::datafusion::DataFusionEngine;
 use convergence_arrow::metadata::Catalog;
 use datafusion::arrow::datatypes::DataType;
-use datafusion::catalog::schema::MemorySchemaProvider;
-use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider};
+use datafusion::catalog_common::memory::MemorySchemaProvider;
+use datafusion::catalog::CatalogProvider;
+use datafusion::catalog_common::MemoryCatalogProvider;
 use datafusion::logical_expr::Volatility;
 use datafusion::physical_plan::ColumnarValue;
 use datafusion::prelude::*;
@@ -35,7 +36,7 @@ async fn new_engine() -> DataFusionEngine {
	ctx.register_udf(create_udf(
		"pg_backend_pid",
		vec![],
-		Arc::new(DataType::Int32),
+		DataType::Int32,
		Volatility::Stable,
		Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(0))))),
	));
@@ -43,7 +44,7 @@ async fn new_engine() -> DataFusionEngine {
	ctx.register_udf(create_udf(
		"current_schema",
		vec![],
-		Arc::new(DataType::Utf8),
+		DataType::Utf8,
		Volatility::Stable,
		Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some("public".to_owned()))))),
	));
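For context on the example changes above: DataFusion 43 moved the in-memory catalog types out of `datafusion::catalog::schema` into `datafusion::catalog_common`, and `create_udf` now takes its return type as a plain `DataType` rather than `Arc<DataType>`. A minimal sketch of catalog registration under that module layout (the function name and catalog name here are illustrative, not part of the diff):

```rust
use std::sync::Arc;

use datafusion::catalog::CatalogProvider;
use datafusion::catalog_common::memory::MemorySchemaProvider;
use datafusion::catalog_common::MemoryCatalogProvider;
use datafusion::prelude::SessionContext;

// Register an empty in-memory catalog/schema pair on a session context.
fn register_memory_catalog(ctx: &SessionContext) {
	let catalog = MemoryCatalogProvider::new();
	catalog
		.register_schema("public", Arc::new(MemorySchemaProvider::new()))
		.expect("failed to register schema");
	ctx.register_catalog("example_catalog", Arc::new(catalog));
}
```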
diff --git a/convergence-arrow/src/metadata.rs b/convergence-arrow/src/metadata.rs
index ad63207..6417711 100644
--- a/convergence-arrow/src/metadata.rs
+++ b/convergence-arrow/src/metadata.rs
@@ -3,8 +3,9 @@
 use datafusion::arrow::array::{ArrayRef, Int32Builder, StringBuilder, UInt32Builder};
 use datafusion::arrow::datatypes::{Field, Schema};
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::catalog::schema::{MemorySchemaProvider, SchemaProvider};
 use datafusion::catalog::CatalogProvider;
+use datafusion::catalog::SchemaProvider;
+use datafusion::catalog_common::memory::MemorySchemaProvider;
 use datafusion::datasource::{MemTable, TableProvider};
 use datafusion::error::DataFusionError;
 use std::convert::TryInto;
@@ -153,6 +154,7 @@ impl MetadataBuilder {
 }

 /// Wrapper catalog supporting generation of pg metadata (e.g. pg_catalog schema).
+#[derive(Debug)]
 pub struct Catalog {
	wrapped: Arc<dyn CatalogProvider>,
 }
diff --git a/convergence-arrow/src/table.rs b/convergence-arrow/src/table.rs
index fb23e98..992d373 100644
--- a/convergence-arrow/src/table.rs
+++ b/convergence-arrow/src/table.rs
@@ -3,9 +3,10 @@
 use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
 use convergence::protocol_ext::DataRowBatch;
 use datafusion::arrow::array::{
-	BooleanArray, Date32Array, Date64Array, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
-	Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-	TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+	BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array,
+	Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
+	StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+	TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array
 };
 use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
 use datafusion::arrow::record_batch::RecordBatch;
@@ -47,7 +48,9 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat
			DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
			DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
			DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
+			DataType::Decimal128(p, s) => row.write_numeric_16(array_val!(Decimal128Array, col, row_idx), p, s),
			DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
+			DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)),
			DataType::Date32 => {
				row.write_date(array_val!(Date32Array, col, row_idx, value_as_date).ok_or_else(|| {
					ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
@@ -102,7 +105,8 @@ pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
		DataType::UInt64 => DataTypeOid::Int8,
		DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
		DataType::Float64 => DataTypeOid::Float8,
-		DataType::Utf8 => DataTypeOid::Text,
+		DataType::Decimal128(_, _) => DataTypeOid::Numeric,
+		DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
		DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
		DataType::Timestamp(_, None) => DataTypeOid::Timestamp,
		other => {
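The two new match arms above lean on how these Arrow arrays expose values: `Decimal128Array` hands back the unscaled `i128`, with precision and scale carried on the `DataType`, while `Utf8View` columns read out as ordinary `&str`, so `write_string` works unchanged. A small self-contained sketch of both behaviours (the values are made up for illustration):

```rust
use datafusion::arrow::array::{Decimal128Array, StringViewArray};

fn main() {
	// Decimal128 stores unscaled i128 values; precision and scale live on
	// the data type, so 123 at scale 2 represents 1.23.
	let decimals = Decimal128Array::from(vec![123_i128, 4567])
		.with_precision_and_scale(5, 2)
		.expect("valid precision and scale");
	assert_eq!(decimals.value(0), 123);
	assert_eq!(decimals.value_as_string(0), "1.23");

	// Utf8View uses Arrow's string-view layout, but values still read out
	// as plain &str.
	let views = StringViewArray::from(vec!["aa", "bb"]);
	assert_eq!(views.value(1), "bb");
}
```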
diff --git a/convergence-arrow/tests/test_arrow.rs b/convergence-arrow/tests/test_arrow.rs
index f1dc31e..aecf492 100644
--- a/convergence-arrow/tests/test_arrow.rs
+++ b/convergence-arrow/tests/test_arrow.rs
@@ -6,10 +6,11 @@
 use convergence::protocol_ext::DataRowBatch;
 use convergence::server::{self, BindOptions};
 use convergence::sqlparser::ast::Statement;
 use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
-use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, TimestampSecondArray};
+use datafusion::arrow::array::{ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use datafusion::arrow::record_batch::RecordBatch;
 use std::sync::Arc;
+use rust_decimal::Decimal;
 use tokio_postgres::{connect, NoTls};

 struct ArrowPortal {
@@ -31,20 +32,24 @@ impl ArrowEngine {
	fn new() -> Self {
		let int_col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
		let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef;
+		let decimal_col = Arc::new(Decimal128Array::from(vec![11, 22, 33]).with_precision_and_scale(2, 0).unwrap()) as ArrayRef;
		let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+		let string_view_col = Arc::new(StringViewArray::from(vec!["aa", "bb", "cc"])) as ArrayRef;
		let ts_col = Arc::new(TimestampSecondArray::from(vec![1577836800, 1580515200, 1583020800])) as ArrayRef;
		let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef;

		let schema = Schema::new(vec![
			Field::new("int_col", DataType::Int32, true),
			Field::new("float_col", DataType::Float32, true),
+			Field::new("decimal_col", DataType::Decimal128(2, 0), true),
			Field::new("string_col", DataType::Utf8, true),
+			Field::new("string_view_col", DataType::Utf8View, true),
			Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true),
			Field::new("date_col", DataType::Date32, true),
		]);

		Self {
-			batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, ts_col, date_col])
+			batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, decimal_col, string_col, string_view_col, ts_col, date_col])
				.expect("failed to create batch"),
		}
	}
@@ -89,8 +94,8 @@ async fn basic_data_types() {
	let rows = client.query("select 1", &[]).await.unwrap();
	let get_row = |idx: usize| {
		let row = &rows[idx];
-		let cols: (i32, f32, &str, NaiveDateTime, NaiveDate) =
-			(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
+		let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, NaiveDate) =
+			(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5), row.get(6));
		cols
	};

@@ -99,7 +104,9 @@ async fn basic_data_types() {
		(
			1,
			1.5,
+			Decimal::from(11),
			"a",
+			"aa",
			NaiveDate::from_ymd_opt(2020, 1, 1)
				.unwrap()
				.and_hms_opt(0, 0, 0)
@@ -112,7 +119,9 @@ async fn basic_data_types() {
		(
			2,
			2.5,
+			Decimal::from(22),
			"b",
+			"bb",
			NaiveDate::from_ymd_opt(2020, 2, 1)
				.unwrap()
				.and_hms_opt(0, 0, 0)
@@ -125,7 +134,9 @@ async fn basic_data_types() {
		(
			3,
			3.5,
+			Decimal::from(33),
			"c",
+			"cc",
			NaiveDate::from_ymd_opt(2020, 3, 1)
				.unwrap()
				.and_hms_opt(0, 0, 0)
diff --git a/convergence/Cargo.toml b/convergence/Cargo.toml
index b433b33..9f6287e 100644
--- a/convergence/Cargo.toml
+++ b/convergence/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "convergence"
-version = "0.16.0"
+version = "0.17.0"
 authors = ["Ruan Pearce-Authers "]
 edition = "2018"
 description = "Write servers that speak PostgreSQL's wire protocol"
@@ -15,7 +15,6 @@ bytes = "1"
 futures = "0.3"
 sqlparser = "0.46"
 async-trait = "0.1"
-chrono = "0.4"
-
-[dev-dependencies]
+chrono = "=0.4.39"
+rust_decimal = { version = "1.36.0", features = ["default", "db-postgres"] }
 tokio-postgres = "0.7"
diff --git a/convergence/src/protocol.rs b/convergence/src/protocol.rs
index 8ac5bdf..bfae419 100644
--- a/convergence/src/protocol.rs
+++ b/convergence/src/protocol.rs
@@ -75,6 +75,8 @@ data_types! {
	Float4 = 700, 4
	Float8 = 701, 8

+	Numeric = 1700, -1
+
	Date = 1082, 4
	Timestamp = 1114, 8

diff --git a/convergence/src/protocol_ext.rs b/convergence/src/protocol_ext.rs
index 575090f..f2db117 100644
--- a/convergence/src/protocol_ext.rs
+++ b/convergence/src/protocol_ext.rs
@@ -3,6 +3,8 @@
 use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
 use bytes::{BufMut, BytesMut};
 use chrono::{NaiveDate, NaiveDateTime};
+use rust_decimal::Decimal;
+use tokio_postgres::types::{ToSql, Type};
 use tokio_util::codec::Encoder;

 /// Supports batched rows for e.g. returning portal result sets.
@@ -131,6 +133,24 @@ impl<'a> DataRowWriter<'a> {
		}
	}

+	/// Writes a numeric value for the next column.
+	pub fn write_numeric_16(&mut self, val: i128, _p: &u8, s: &i8) {
+		let decimal = Decimal::from_i128_with_scale(val, *s as u32);
+		match self.parent.format_code {
+			FormatCode::Text => {
+				self.write_string(&decimal.to_string())
+			}
+			FormatCode::Binary => {
+				let numeric_type = Type::from_oid(1700).expect("failed to create numeric type");
+				let mut buf = BytesMut::new();
+				decimal.to_sql(&numeric_type, &mut buf)
+					.expect("failed to write numeric");
+
+				self.write_value(&buf.freeze())
+			}
+		};
+	}
+
	primitive_write!(write_int2, i16);
	primitive_write!(write_int4, i32);
	primitive_write!(write_int8, i64);
@@ -138,7 +158,7 @@ impl<'a> DataRowWriter<'a> {
	primitive_write!(write_float4, f32);
	primitive_write!(write_float8, f64);
 }
-impl<'a> Drop for DataRowWriter<'a> {
+impl Drop for DataRowWriter<'_> {
	fn drop(&mut self) {
		assert_eq!(
			self.parent.num_cols, self.current_col,
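`write_numeric_16` rebuilds a `rust_decimal::Decimal` from the unscaled value and scale, then picks the wire encoding from the portal's format code: the text format is just the decimal's string form, while the binary format delegates to the `ToSql` impl that rust_decimal's `db-postgres` feature provides. A standalone sketch of both paths, using `Type::NUMERIC`, which is the same type `Type::from_oid(1700)` resolves to:

```rust
use bytes::BytesMut;
use rust_decimal::Decimal;
use tokio_postgres::types::{ToSql, Type};

fn main() {
	// Rebuild 1.23 from a Decimal128-style (unscaled value, scale) pair.
	let decimal = Decimal::from_i128_with_scale(123, 2);

	// Text format: the string form goes on the wire.
	assert_eq!(decimal.to_string(), "1.23");

	// Binary format: rust_decimal's ToSql impl emits the Postgres NUMERIC
	// binary encoding into the buffer.
	let mut buf = BytesMut::new();
	decimal.to_sql(&Type::NUMERIC, &mut buf).expect("failed to encode numeric");
	assert!(!buf.is_empty());
}
```

The new `Numeric = 1700, -1` entry in protocol.rs matches pg_type: OID 1700 with a typlen of -1, since NUMERIC is a variable-length type.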
diff --git a/convergence/tests/test_connection.rs b/convergence/tests/test_connection.rs
index c234a57..1f23bdf 100644
--- a/convergence/tests/test_connection.rs
+++ b/convergence/tests/test_connection.rs
@@ -79,16 +79,16 @@ async fn extended_query_flow() {
 async fn simple_query_flow() {
	let client = setup().await;
	let messages = client.simple_query("select 1").await.unwrap();
-	assert_eq!(messages.len(), 2);
+	assert_eq!(messages.len(), 3);

-	let row = match &messages[0] {
+	let row = match &messages[1] {
		SimpleQueryMessage::Row(row) => row,
		_ => panic!("expected row"),
	};

	assert_eq!(row.get(0), Some("1"));

-	let num_rows = match &messages[1] {
+	let num_rows = match &messages[2] {
		SimpleQueryMessage::CommandComplete(rows) => *rows,
		_ => panic!("expected command complete"),
	};
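The extra message accounted for here comes from newer tokio-postgres 0.7 releases, which surface a `SimpleQueryMessage::RowDescription` ahead of any rows in simple-query results, shifting `Row` and `CommandComplete` by one index. A sketch of consuming all three variants without depending on positions (the helper name is illustrative; the enum is `#[non_exhaustive]`, hence the wildcard arm):

```rust
use tokio_postgres::SimpleQueryMessage;

// Walk a simple_query result, handling each message kind as it appears.
fn dump_messages(messages: &[SimpleQueryMessage]) {
	for message in messages {
		match message {
			SimpleQueryMessage::RowDescription(columns) => {
				println!("result has {} column(s)", columns.len());
			}
			SimpleQueryMessage::Row(row) => println!("first column: {:?}", row.get(0)),
			SimpleQueryMessage::CommandComplete(rows) => println!("{rows} row(s)"),
			_ => {} // SimpleQueryMessage is #[non_exhaustive]
		}
	}
}
```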