diff --git a/Cargo.toml b/Cargo.toml index c7385c5..6996e96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ resolver = "2" members = [ + "examples/simple-logs", "hyperfuel-client", "hyperfuel-format", "hyperfuel-net-types", diff --git a/examples/simple-logs/Cargo.toml b/examples/simple-logs/Cargo.toml new file mode 100644 index 0000000..0989b2f --- /dev/null +++ b/examples/simple-logs/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "simple-logs" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +hex-literal = "0.4.1" +hyperfuel-client = { path = "../../hyperfuel-client" } +tokio = { version = "1", features = ["full"] } +url = "2.5.0" diff --git a/examples/simple-logs/src/main.rs b/examples/simple-logs/src/main.rs new file mode 100644 index 0000000..0404ac1 --- /dev/null +++ b/examples/simple-logs/src/main.rs @@ -0,0 +1,30 @@ +use hyperfuel_client::{Client, ClientConfig}; +use url::Url; + +#[tokio::main] +async fn main() { + let client_config = ClientConfig { + url: Some(Url::parse("https://fuel.hypersync.xyz").unwrap()), + ..Default::default() + }; + let client = Client::new(client_config).unwrap(); + + let contracts = vec![hex_literal::hex!( + "4a2ce054e3e94155f7092f7365b212f7f45105b74819c623744ebcc5d065c6ac" + )]; + let from_block = 0; + let to_block = Some(50_000); + + let logs = client + .preset_query_get_logs(contracts, from_block, to_block) + .await + .unwrap(); + + println!( + "archive_height={:?} next_block={} total_execution_time={}ms logs={}", + logs.archive_height, + logs.next_block, + logs.total_execution_time, + logs.data.len() + ); +} diff --git a/hyperfuel-client/Cargo.toml b/hyperfuel-client/Cargo.toml index a2fc0a6..3624287 100644 --- a/hyperfuel-client/Cargo.toml +++ b/hyperfuel-client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyperfuel-client" -version = "3.0.1" +version = "3.1.0" edition = "2021" description = "client library for hyperfuel" license = "MPL-2.0" @@ -27,7 +27,6 @@ arrayvec = { version = "0.7", features = ["serde"] } tokio = { version = "1", default-features = false, features = [ "rt-multi-thread", "fs", - "test-util", "rt", "macros", ] } diff --git a/hyperfuel-client/src/column_mapping.rs b/hyperfuel-client/src/column_mapping.rs index 0a744fb..28f642f 100644 --- a/hyperfuel-client/src/column_mapping.rs +++ b/hyperfuel-client/src/column_mapping.rs @@ -26,13 +26,13 @@ pub struct ColumnMapping { /// Mapping for transaction data. #[serde(default)] pub transaction: BTreeMap, - /// Mapping for log data. + /// Mapping for receipt data. #[serde(default)] pub receipt: BTreeMap, - /// Mapping for trace data. + /// Mapping for input data. #[serde(default)] pub input: BTreeMap, - /// Mapping for decoded log data. + /// Mapping for output data. #[serde(default)] pub output: BTreeMap, } @@ -85,7 +85,7 @@ pub fn apply_to_batch( .context(format!("apply cast to column '{}'", field.name))? } else { map_column(&**col, dt) - .context(format!("apply cast to colum '{}'", field.name))? + .context(format!("apply cast to column '{}'", field.name))? } } None => col.clone(), diff --git a/hyperfuel-client/src/lib.rs b/hyperfuel-client/src/lib.rs index 26213bc..8e1b4f8 100644 --- a/hyperfuel-client/src/lib.rs +++ b/hyperfuel-client/src/lib.rs @@ -261,7 +261,7 @@ impl Client { let height: ArchiveHeight = res.json().await.context("read response body json")?; - Ok(height.height.unwrap_or(0)) + height.height.context("missing height in response") } /// Get the chain_id from the server with retries. @@ -367,12 +367,15 @@ impl Client { } let bytes = res.bytes().await.context("read response body bytes")?; + let byte_len = bytes.len(); - let res = tokio::task::block_in_place(|| { + let res = tokio::task::spawn_blocking(move || { parse_query_response(&bytes).context("parse query response") - })?; + }) + .await + .context("join parse task")??; - Ok((res, bytes.len().try_into().unwrap())) + Ok((res, byte_len.try_into().unwrap())) } /// Executes query with retries and returns the response in Arrow format. diff --git a/hyperfuel-client/src/parse_response.rs b/hyperfuel-client/src/parse_response.rs index 0e7056b..28b37a7 100644 --- a/hyperfuel-client/src/parse_response.rs +++ b/hyperfuel-client/src/parse_response.rs @@ -28,7 +28,12 @@ fn read_chunks(bytes: &[u8]) -> Result> { pub fn parse_query_response(bytes: &[u8]) -> Result { let mut opts = capnp::message::ReaderOptions::new(); - opts.nesting_limit(i32::MAX).traversal_limit_in_words(None); + // Bounded limits for untrusted network input. Default capnp limits are 64 + // nesting / 64 MiB traversal; we raise the traversal cap to 512 MiB (64M + // words * 8 bytes/word) to fit large paginated arrow payloads. Callers + // hitting this should reduce per-query block ranges via max_num_blocks. + opts.nesting_limit(64) + .traversal_limit_in_words(Some(64 * 1024 * 1024)); let message_reader = capnp::serialize_packed::read_message(bytes, opts).context("create message reader")?; diff --git a/hyperfuel-client/tests/api_test.rs b/hyperfuel-client/tests/api_test.rs deleted file mode 100644 index ab5ce6c..0000000 --- a/hyperfuel-client/tests/api_test.rs +++ /dev/null @@ -1,522 +0,0 @@ -// use std::{collections::BTreeSet, env::temp_dir, sync::Arc}; - -// use alloy_json_abi::JsonAbi; -// use hypersync_client::{ -// preset_query, simple_types::Transaction, Client, ClientConfig, ColumnMapping, StreamConfig, -// }; -// use hypersync_format::{Address, FilterWrapper, Hex, LogArgument}; -// use hypersync_net_types::{FieldSelection, Query, TransactionSelection}; -// use polars_arrow::array::UInt64Array; - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_arrow_ipc() { -// let client = Client::new(ClientConfig::default()).unwrap(); - -// let mut block_field_selection = BTreeSet::new(); -// block_field_selection.insert("number".to_owned()); -// block_field_selection.insert("timestamp".to_owned()); -// block_field_selection.insert("hash".to_owned()); - -// let res = client -// .get_arrow(&Query { -// from_block: 14000000, -// to_block: None, -// logs: Vec::new(), -// transactions: Vec::new(), -// include_all_blocks: true, -// field_selection: FieldSelection { -// block: block_field_selection, -// log: Default::default(), -// transaction: Default::default(), -// trace: Default::default(), -// }, -// ..Default::default() -// }) -// .await -// .unwrap(); - -// dbg!(res.next_block); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_arrow_ipc_ordering() { -// let client = Client::new(ClientConfig::default()).unwrap(); - -// let mut block_field_selection = BTreeSet::new(); -// block_field_selection.insert("number".to_owned()); - -// let query: Query = serde_json::from_value(serde_json::json!({ -// "from_block": 13171881, -// "to_block": 18270333, -// "logs": [ -// { -// "address": [ -// "0x15b7c0c907e4C6b9AdaAaabC300C08991D6CEA05" -// ], -// "topics": [ -// [ -// "0x8c5be1e5ebec7d5bd14f71427d1e84f3dd0314c0f7b2291e5b200ac8c7c3b925", -// "0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef" -// ] -// ] -// } -// ], -// "field_selection": { -// "block": [ -// "number" -// ], -// "log": [ -// "log_index", -// "block_number" -// ] -// } -// })) -// .unwrap(); - -// let res = client.get_arrow(&query).await.unwrap(); - -// assert!(res.next_block > 13223105); - -// let mut last = (0, 0); -// for batch in res.data.logs { -// let block_number = batch.column::("block_number").unwrap(); -// let log_index = batch.column::("log_index").unwrap(); - -// for (&block_number, &log_index) in block_number.values_iter().zip(log_index.values_iter()) { -// let number = (block_number, log_index); -// assert!(last < number, "last: {:?};number: {:?};", last, number); -// last = number; -// } -// } -// } - -// fn get_file_path(name: &str) -> String { -// format!("{}/test-data/{name}", env!("CARGO_MANIFEST_DIR")) -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_decode_logs() { -// env_logger::try_init().ok(); - -// const ADDR: &str = "0xc18360217d8f7ab5e7c516566761ea12ce7f9d72"; - -// let client = Arc::new(Client::new(ClientConfig::default()).unwrap()); - -// let query: Query = serde_json::from_value(serde_json::json!({ -// "from_block": 18680952, -// "to_block": 18680953, -// "logs": [ -// { -// "address": [ -// ADDR -// ] -// } -// ], -// "field_selection": { -// "log": [ -// "address", -// "data", -// "topic0", -// "topic1", -// "topic2", -// "topic3" -// ] -// } -// })) -// .unwrap(); - -// let mut rx = client -// .stream_arrow( -// query, -// StreamConfig { -// event_signature: Some( -// "Transfer(address indexed from, address indexed to, uint indexed amount)" -// .into(), -// ), -// ..Default::default() -// }, -// ) -// .await -// .unwrap(); - -// let res = rx.recv().await.unwrap().unwrap(); - -// let decoded_logs = res.data.decoded_logs; - -// dbg!(res.data.logs); - -// assert_eq!(decoded_logs[0].chunk.len(), 1); - -// println!("{:?}", decoded_logs[0]); -// } - -// #[test] -// fn parse_nameless_abi() { -// let path = get_file_path("nameless.abi.json"); -// let abi = std::fs::read_to_string(path).unwrap(); -// let _abi: JsonAbi = serde_json::from_str(&abi).unwrap(); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_get_events_without_join_fields() { -// env_logger::try_init().ok(); - -// let client = Client::new(ClientConfig { -// url: Some("https://base.hypersync.xyz".parse().unwrap()), -// ..Default::default() -// }) -// .unwrap(); - -// let query: Query = serde_json::from_value(serde_json::json!({ -// "from_block": 6589327, -// "to_block": 6589328, -// "logs": [{ -// "address": ["0xd981ed72b1b3bf866563a9755d41a887d3e4721a"], -// "topics": [["0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"]], -// }], -// "field_selection": { -// "log": ["block_number", "topic0", "topic1", "topic2", "topic3", "data", "address"], -// "transaction": ["value"], -// "block": ["gas_used"], -// } -// })) -// .unwrap(); - -// let res = client.get_events(query).await.unwrap(); - -// dbg!(res.data); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_stream_decode_with_invalid_log() { -// env_logger::try_init().ok(); - -// let client = Client::new(ClientConfig { -// url: Some("https://base.hypersync.xyz".parse().unwrap()), -// ..Default::default() -// }) -// .unwrap(); -// let client = Arc::new(client); - -// let query: Query = serde_json::from_value(serde_json::json!({ -// "from_block": 6589327, -// "to_block": 6589328, -// "logs": [{ -// "address": ["0xd981ed72b1b3bf866563a9755d41a887d3e4721a"], -// "topics": [["0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"]], -// }], -// "field_selection": { -// "log": ["block_number", "topic0", "topic1", "topic2", "topic3", "data", "address"], -// } -// })) -// .unwrap(); - -// let data = client -// .collect_arrow( -// query, -// StreamConfig { -// column_mapping: Some(ColumnMapping { -// block: maplit::btreemap! { -// "number".to_owned() => hypersync_client::DataType::Float32, -// }, -// transaction: maplit::btreemap! { -// "value".to_owned() => hypersync_client::DataType::Float64, -// }, -// log: Default::default(), -// trace: Default::default(), -// decoded_log: maplit::btreemap! { -// "amount".to_owned() => hypersync_client::DataType::Float64, -// }, -// }), -// event_signature: Some( -// "Transfer(address indexed from, address indexed to, uint indexed amount)" -// .into(), -// ), -// ..Default::default() -// }, -// ) -// .await -// .unwrap(); - -// dbg!(data); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_parquet_out() { -// env_logger::try_init().ok(); - -// let client = Arc::new(Client::new(ClientConfig::default()).unwrap()); - -// let path = format!("{}/{}", temp_dir().to_string_lossy(), uuid::Uuid::new_v4()); - -// let query: Query = serde_json::from_value(serde_json::json!({ -// "from_block": 19277345, -// "to_block": 19277346, -// "logs": [{ -// "address": ["0xdAC17F958D2ee523a2206206994597C13D831ec7"], -// "topics": [["0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"]], -// }], -// "transactions": [{}], -// "include_all_blocks": true, -// "field_selection": { -// "log": ["block_number", "topic0", "topic1", "topic2", "topic3", "data", "address"], -// } -// })) -// .unwrap(); - -// client -// .collect_parquet( -// &path, -// query, -// StreamConfig { -// column_mapping: Some(ColumnMapping { -// block: maplit::btreemap! { -// "number".to_owned() => hypersync_client::DataType::Float32, -// }, -// transaction: maplit::btreemap! { -// "value".to_owned() => hypersync_client::DataType::Float64, -// }, -// log: Default::default(), -// trace: Default::default(), -// decoded_log: maplit::btreemap! { -// //"amount".to_owned() => hypersync_client::DataType::Float64, -// }, -// }), -// event_signature: Some( -// "Transfer(address indexed from, address indexed to, uint indexed amount)" -// .into(), -// ), -// ..Default::default() -// }, -// ) -// .await -// .unwrap(); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_preset_query_blocks_and_transactions() { -// let client = Arc::new(Client::new(ClientConfig::default()).unwrap()); -// let query = preset_query::blocks_and_transactions(18_000_000, Some(18_000_010)); -// let res = client.get_arrow(&query).await.unwrap(); - -// let num_blocks: usize = res -// .data -// .blocks -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); -// let num_txs: usize = res -// .data -// .transactions -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); - -// assert_eq!(res.next_block, 18_000_010); -// assert_eq!(num_blocks, 10); -// assert!(num_txs > 1); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_preset_query_blocks_and_transaction_hashes() { -// let client = Client::new(ClientConfig::default()).unwrap(); -// let query = preset_query::blocks_and_transaction_hashes(18_000_000, Some(18_000_010)); -// let res = client.get_arrow(&query).await.unwrap(); - -// let num_blocks: usize = res -// .data -// .blocks -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); -// let num_txs: usize = res -// .data -// .transactions -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); - -// assert_eq!(res.next_block, 18_000_010); -// assert_eq!(num_blocks, 10); -// assert!(num_txs > 1); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_preset_query_logs() { -// let client = Client::new(ClientConfig::default()).unwrap(); - -// let usdt_addr = Address::decode_hex("0xdAC17F958D2ee523a2206206994597C13D831ec7").unwrap(); -// let query = preset_query::logs(18_000_000, Some(18_000_010), usdt_addr); -// let res = client.get_arrow(&query).await.unwrap(); - -// let num_logs: usize = res -// .data -// .logs -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); - -// assert_eq!(res.next_block, 18_000_010); -// assert!(num_logs > 1); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_preset_query_logs_of_event() { -// let client = Client::new(ClientConfig::default()).unwrap(); - -// let usdt_addr = Address::decode_hex("0xdAC17F958D2ee523a2206206994597C13D831ec7").unwrap(); -// let transfer_topic0 = LogArgument::decode_hex( -// "0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef", -// ) -// .unwrap(); -// let query = -// preset_query::logs_of_event(18_000_000, Some(18_000_010), transfer_topic0, usdt_addr); - -// let res = client.get_arrow(&query).await.unwrap(); - -// let num_logs: usize = res -// .data -// .logs -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); - -// assert_eq!(res.next_block, 18_000_010); -// assert!(num_logs > 1); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_preset_query_transactions() { -// let client = Client::new(ClientConfig::default()).unwrap(); -// let query = preset_query::transactions(18_000_000, Some(18_000_010)); -// let res = client.get_arrow(&query).await.unwrap(); - -// let num_txs: usize = res -// .data -// .transactions -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); - -// assert_eq!(res.next_block, 18_000_010); -// assert!(num_txs > 1); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_api_preset_query_transactions_from_address() { -// let client = Client::new(ClientConfig::default()).unwrap(); - -// let vitalik_eth_addr = -// Address::decode_hex("0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045").unwrap(); -// let query = -// preset_query::transactions_from_address(19_000_000, Some(19_300_000), vitalik_eth_addr); -// let res = client.get_arrow(&query).await.unwrap(); - -// let num_txs: usize = res -// .data -// .transactions -// .into_iter() -// .map(|batch| batch.chunk.len()) -// .sum(); - -// assert!(res.next_block == 19_300_000); -// assert!(num_txs == 21); -// } - -// // same query as above (test_api_preset_query_transactions_from_address) except it uses a bloom filter instead of a -// // vector of addresses to target the specified address -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_small_bloom_filter_query() { -// let client = Arc::new(Client::new(ClientConfig::default()).unwrap()); - -// let vitalik_eth_addr = -// Address::decode_hex("0xd8dA6BF26964aF9D7eEd9e03E53415D37aA96045").unwrap(); - -// let mut txn_field_selection = BTreeSet::new(); -// txn_field_selection.insert("block_number".to_owned()); -// txn_field_selection.insert("from".to_owned()); -// txn_field_selection.insert("hash".to_owned()); - -// let addrs = [vitalik_eth_addr.clone()]; -// let from_address_filter = -// FilterWrapper::from_keys(addrs.iter().map(|d| d.as_ref()), None).unwrap(); - -// let query = Query { -// from_block: 19_000_000, -// to_block: Some(19_300_000), -// logs: Vec::new(), -// transactions: vec![TransactionSelection { -// from_filter: Some(from_address_filter), -// ..Default::default() -// }], -// field_selection: FieldSelection { -// block: Default::default(), -// log: Default::default(), -// transaction: txn_field_selection, -// trace: Default::default(), -// }, -// ..Default::default() -// }; - -// let stream_config = StreamConfig::default(); - -// let res = client.collect(query, stream_config).await.unwrap(); - -// let txns: Vec = res.data.transactions.into_iter().flatten().collect(); -// let num_txns = txns.len(); - -// for txn in txns { -// if txn.from.as_ref() != Some(&vitalik_eth_addr) { -// panic!("returned an address not in the bloom filter") -// } -// } - -// assert_eq!(res.next_block, 19_300_000); -// assert_eq!(num_txns, 21); -// } - -// #[tokio::test(flavor = "multi_thread")] -// #[ignore] -// async fn test_decode_string_param_into_arrow() { -// let client = Arc::new( -// Client::new(ClientConfig { -// url: Some("https://mev-commit.hypersync.xyz".parse().unwrap()), -// ..Default::default() -// }) -// .unwrap(), -// ); - -// let query: Query = serde_json::from_value(serde_json::json!({ -// "from_block": 0, -// "logs": [{ -// "address": ["0xCAC68D97a56b19204Dd3dbDC103CB24D47A825A3"], -// "topics": [["0xe44dd4d002deb2c79cf08ce285a9d80c69753f31ca65c8e49f0a60d27ed9fea3"]], -// }], -// "field_selection": { -// "log": ["block_number", "topic0", "topic1", "topic2", "topic3", "data", "address"], -// } -// })) -// .unwrap(); - -// let conf = StreamConfig { -// event_signature: Some("CommitmentStored(bytes32 indexed commitmentIndex, address bidder, address commiter, uint256 bid, uint64 blockNumber, bytes32 bidHash, uint64 decayStartTimeStamp, uint64 decayEndTimeStamp, string txnHash, string revertingTxHashes, bytes32 commitmentHash, bytes bidSignature, bytes commitmentSignature, uint64 dispatchTimestamp, bytes sharedSecretKey)".into()), -// ..Default::default() -// }; - -// let data = client.collect_arrow(query, conf).await.unwrap(); - -// dbg!(data.data.decoded_logs); -// } diff --git a/hyperfuel-format/Cargo.toml b/hyperfuel-format/Cargo.toml index 921ac9e..d8e74d3 100644 --- a/hyperfuel-format/Cargo.toml +++ b/hyperfuel-format/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyperfuel-format" -version = "4.0.1" +version = "4.1.0" edition = "2021" description = "fuel network format library" license = "MPL-2.0" diff --git a/hyperfuel-format/src/types/quantity.rs b/hyperfuel-format/src/types/quantity.rs index 2ad7255..ef65483 100644 --- a/hyperfuel-format/src/types/quantity.rs +++ b/hyperfuel-format/src/types/quantity.rs @@ -26,7 +26,12 @@ impl From> for Quantity { impl From for Quantity { fn from(value: u64) -> Self { - Self(value.to_be_bytes().into()) + if value == 0 { + return Quantity::default(); + } + let bytes = value.to_be_bytes(); + let first_non_zero = bytes.iter().position(|b| *b != 0).unwrap(); + Self(bytes[first_non_zero..].to_vec().into()) } } @@ -208,4 +213,15 @@ mod tests { fn test_from_slice_leading_zeroes() { let _ = Quantity::from(vec![0, 1].as_slice()); } + + #[test] + fn test_from_u64_canonical() { + assert_eq!(Quantity::from(0u64), Quantity::default()); + assert_eq!(Quantity::from(5u64), Quantity::from(vec![5])); + assert_eq!(Quantity::from(0x4200u64), Quantity::from(hex!("4200"))); + assert_eq!( + Quantity::from(u64::MAX), + Quantity::from(hex!("ffffffffffffffff")) + ); + } } diff --git a/hyperfuel-net-types/Cargo.toml b/hyperfuel-net-types/Cargo.toml index ca7efcd..f725966 100644 --- a/hyperfuel-net-types/Cargo.toml +++ b/hyperfuel-net-types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyperfuel-net-types" -version = "4.0.3" +version = "4.1.0" edition = "2021" description = "hyperfuel types for transport over network" license = "MPL-2.0" diff --git a/hyperfuel-schema/Cargo.toml b/hyperfuel-schema/Cargo.toml index f12389b..3875740 100644 --- a/hyperfuel-schema/Cargo.toml +++ b/hyperfuel-schema/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hyperfuel-schema" -version = "4.0.0" +version = "4.1.0" edition = "2021" description = "schema utilities for hyperfuel" license = "MPL-2.0" diff --git a/hyperfuel-schema/src/lib.rs b/hyperfuel-schema/src/lib.rs index a09b735..70b3c75 100644 --- a/hyperfuel-schema/src/lib.rs +++ b/hyperfuel-schema/src/lib.rs @@ -8,7 +8,7 @@ use polars_arrow::record_batch::RecordBatchT as Chunk; mod util; -pub use util::project_schema; +pub use util::{project_schema, try_project_schema}; pub type ArrowChunk = Chunk>; diff --git a/hyperfuel-schema/src/util.rs b/hyperfuel-schema/src/util.rs index c514be2..5268acd 100644 --- a/hyperfuel-schema/src/util.rs +++ b/hyperfuel-schema/src/util.rs @@ -1,27 +1,42 @@ use std::collections::BTreeSet; +use anyhow::{anyhow, Result}; use polars_arrow::datatypes::ArrowSchema as Schema; +/// Project a schema down to the named fields. Missing field names are silently +/// dropped from the result. +/// +/// Prefer [`try_project_schema`], which returns an error listing any unknown +/// names. This non-erroring variant is retained for backward compatibility. pub fn project_schema(schema: &Schema, field_selection: &BTreeSet) -> Schema { - let mut select_indices = Vec::new(); - for col_name in field_selection.iter() { - if let Some((idx, _)) = schema - .fields - .iter() - .enumerate() - .find(|(_, f)| &f.name == col_name) - { - select_indices.push(idx); - } - } - - let schema: Schema = schema + schema .fields .iter() .filter(|f| field_selection.contains(&f.name)) .cloned() .collect::>() - .into(); + .into() +} - schema +/// Project a schema down to the named fields, returning an error if any +/// requested field is not present in the source schema. Use this to catch +/// typos in field selections instead of silently receiving a partial schema. +pub fn try_project_schema(schema: &Schema, field_selection: &BTreeSet) -> Result { + let known: BTreeSet<&String> = schema.fields.iter().map(|f| &f.name).collect(); + let missing: Vec<&String> = field_selection + .iter() + .filter(|name| !known.contains(name)) + .collect(); + if !missing.is_empty() { + return Err(anyhow!( + "selected columns not found in schema: {}", + missing + .into_iter() + .map(String::as_str) + .collect::>() + .join(", ") + )); + } + + Ok(project_schema(schema, field_selection)) }