From b734b6a85fa547075fb8b23746d4ff6d45d8d3c5 Mon Sep 17 00:00:00 2001 From: espressolee <70549809+espressolee@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:11:31 +0900 Subject: [PATCH] fix: Vec extraction from buffer exporters --- src/conversion.rs | 4 +- src/conversions/std/num.rs | 90 ++++++++++++++++++++++++++++++++--- src/conversions/std/vec.rs | 2 +- tests/test_buffer_protocol.rs | 45 +++++++++++++++++- 4 files changed, 130 insertions(+), 11 deletions(-) diff --git a/src/conversion.rs b/src/conversion.rs index 2e5f3fef507..b55b9a7e53d 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -451,7 +451,7 @@ pub trait FromPyObject<'a, 'py>: Sized { impl FromPyObjectSequence for NeverASequence { type Target = T; - fn to_vec(&self) -> Vec { + fn to_vec(&self) -> PyResult> { unreachable!() } @@ -480,7 +480,7 @@ mod from_py_object_sequence { pub trait FromPyObjectSequence { type Target; - fn to_vec(&self) -> Vec; + fn to_vec(&self) -> PyResult>; fn to_array(&self) -> PyResult<[Self::Target; N]>; } diff --git a/src/conversions/std/num.rs b/src/conversions/std/num.rs index 2161c2f8935..92305653f72 100644 --- a/src/conversions/std/num.rs +++ b/src/conversions/std/num.rs @@ -1,3 +1,5 @@ +#[cfg(any(not(Py_LIMITED_API), Py_3_11))] +use crate::buffer::PyBuffer; use crate::conversion::private::Reference; use crate::conversion::{FromPyObjectSequence, IntoPyObject}; use crate::ffi_ptr_ext::FfiPtrExt; @@ -6,10 +8,14 @@ use crate::inspect::types::TypeInfo; #[cfg(feature = "experimental-inspect")] use crate::inspect::PyStaticExpr; use crate::py_result_ext::PyResultExt; -#[cfg(feature = "experimental-inspect")] -use crate::type_object::PyTypeInfo; -use crate::types::{PyByteArray, PyByteArrayMethods, PyBytes, PyInt}; -use crate::{exceptions, ffi, Borrowed, Bound, FromPyObject, PyAny, PyErr, PyResult, Python}; +use crate::types::sequence::PySequenceMethods; +use crate::types::{ + any::PyAnyMethods, PyByteArray, PyByteArrayMethods, PyBytes, PyInt, PySequence, +}; +use crate::{ + exceptions, ffi, Borrowed, Bound, CastError, FromPyObject, PyAny, PyErr, PyResult, PyTypeInfo, + Python, +}; use std::convert::Infallible; use std::ffi::c_long; use std::mem::MaybeUninit; @@ -317,6 +323,10 @@ impl<'py> FromPyObject<'_, 'py> for u8 { } else if let Ok(byte_array) = obj.cast::() { Some(BytesSequenceExtractor::ByteArray(byte_array)) } else { + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + if unsafe { ffi::PyObject_CheckBuffer(obj.as_ptr()) != 0 } { + return Some(BytesSequenceExtractor::Buffer(obj.to_any())); + } None } } @@ -325,6 +335,8 @@ impl<'py> FromPyObject<'_, 'py> for u8 { pub(crate) enum BytesSequenceExtractor<'a, 'py> { Bytes(Borrowed<'a, 'py, PyBytes>), ByteArray(Borrowed<'a, 'py, PyByteArray>), + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + Buffer(Borrowed<'a, 'py, PyAny>), } impl BytesSequenceExtractor<'_, '_> { @@ -348,6 +360,21 @@ impl BytesSequenceExtractor<'_, '_> { copy_slice(unsafe { b.as_bytes() }) }) } + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + BytesSequenceExtractor::Buffer(any) => { + // Fall back to sequence semantics if the buffer is incompatible with u8 + // (e.g., array('I')). + if let Ok(buf) = PyBuffer::::get(any) { + // Safety: we're about to write the entire `out` slice. + let target = unsafe { + std::slice::from_raw_parts_mut(out.as_mut_ptr().cast::(), out.len()) + }; + buf.copy_to_slice(any.py(), target)?; + Ok(()) + } else { + fill_u8_slice_from_sequence(*any, out) + } + } } } } @@ -355,11 +382,21 @@ impl BytesSequenceExtractor<'_, '_> { impl FromPyObjectSequence for BytesSequenceExtractor<'_, '_> { type Target = u8; - fn to_vec(&self) -> Vec { - match self { + fn to_vec(&self) -> PyResult> { + Ok(match self { BytesSequenceExtractor::Bytes(b) => b.as_bytes().to_vec(), BytesSequenceExtractor::ByteArray(b) => b.to_vec(), - } + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + BytesSequenceExtractor::Buffer(any) => { + // Fall back to sequence semantics if the buffer is incompatible with u8 + // (e.g., array('I')). + if let Ok(buf) = PyBuffer::::get(any) { + return buf.to_vec(any.py()); + } else { + return extract_u8_vec_from_sequence(*any); + } + } + }) } fn to_array(&self) -> PyResult<[u8; N]> { @@ -377,6 +414,45 @@ impl FromPyObjectSequence for BytesSequenceExtractor<'_, '_> { } } +fn extract_u8_vec_from_sequence<'a, 'py>(obj: Borrowed<'a, 'py, PyAny>) -> PyResult> { + // Types that pass `PySequence_Check` usually implement enough of the sequence protocol + // to support this function and if not, we will only fail extraction safely. + let seq = unsafe { + if ffi::PySequence_Check(obj.as_ptr()) != 0 { + obj.cast_unchecked::() + } else { + return Err(CastError::new(obj, PySequence::type_object(obj.py()).into_any()).into()); + } + }; + + let mut v = Vec::with_capacity(seq.len().unwrap_or(0)); + for item in seq.try_iter()? { + v.push(item?.extract::()?); + } + Ok(v) +} + +fn fill_u8_slice_from_sequence<'a, 'py>( + obj: Borrowed<'a, 'py, PyAny>, + out: &mut [MaybeUninit], +) -> PyResult<()> { + let seq = unsafe { + if ffi::PySequence_Check(obj.as_ptr()) != 0 { + obj.cast_unchecked::() + } else { + return Err(CastError::new(obj, PySequence::type_object(obj.py()).into_any()).into()); + } + }; + let seq_len = seq.len()?; + if seq_len != out.len() { + return Err(invalid_sequence_length(out.len(), seq_len)); + } + for (idx, item) in seq.try_iter()?.enumerate() { + out[idx].write(item?.extract::()?); + } + Ok(()) +} + int_fits_c_long!(i8); int_fits_c_long!(i16); int_fits_c_long!(u16); diff --git a/src/conversions/std/vec.rs b/src/conversions/std/vec.rs index f56e362c41d..afb582a2cd1 100644 --- a/src/conversions/std/vec.rs +++ b/src/conversions/std/vec.rs @@ -73,7 +73,7 @@ where fn extract(obj: Borrowed<'_, 'py, PyAny>) -> PyResult { if let Some(extractor) = T::sequence_extractor(obj, crate::conversion::private::Token) { - return Ok(extractor.to_vec()); + return extractor.to_vec(); } if obj.is_instance_of::() { diff --git a/tests/test_buffer_protocol.rs b/tests/test_buffer_protocol.rs index 9af163f515b..0e9ccc7f5a9 100644 --- a/tests/test_buffer_protocol.rs +++ b/tests/test_buffer_protocol.rs @@ -6,7 +6,7 @@ use pyo3::buffer::PyBuffer; use pyo3::exceptions::PyBufferError; use pyo3::ffi; use pyo3::prelude::*; -use pyo3::types::IntoPyDict; +use pyo3::types::{IntoPyDict, PyBytes, PyDict}; use std::ffi::CString; use std::ffi::{c_int, c_void}; use std::ptr; @@ -15,6 +15,11 @@ use std::sync::Arc; mod test_utils; +#[pyfunction] +fn vec_u8_to_pybytes(py: Python<'_>, bytes: Vec) -> Bound<'_, PyBytes> { + PyBytes::new(py, &bytes) +} + #[pyclass] struct TestBufferClass { vec: Vec, @@ -94,6 +99,44 @@ fn test_buffer_referenced() { assert!(drop_called.load(Ordering::Relaxed)); } +#[test] +fn test_extract_vec_u8_from_buffer_exporter() { + let drop_called = Arc::new(AtomicBool::new(false)); + + Python::attach(|py| { + let instance = Py::new( + py, + TestBufferClass { + vec: vec![b'A', b'B', b'C'], + drop_called: drop_called.clone(), + }, + ) + .unwrap(); + let f = wrap_pyfunction!(vec_u8_to_pybytes)(py).unwrap(); + let env = PyDict::new(py); + env.set_item("ob", instance).unwrap(); + env.set_item("f", f).unwrap(); + py_assert!(py, *env, "f(ob) == b'ABC'"); + }); + + assert!(drop_called.load(Ordering::Relaxed)); +} + +#[test] +fn test_extract_vec_u8_falls_back_when_buffer_incompatible() { + Python::attach(|py| { + let array_mod = py.import("array").unwrap(); + let ob = array_mod + .call_method1("array", ("I", vec![65u32, 66u32, 67u32])) + .unwrap(); + let f = wrap_pyfunction!(vec_u8_to_pybytes)(py).unwrap(); + let env = PyDict::new(py); + env.set_item("ob", ob).unwrap(); + env.set_item("f", f).unwrap(); + py_assert!(py, *env, "f(ob) == b'ABC'"); + }); +} + #[test] #[cfg(Py_3_8)] // sys.unraisablehook not available until Python 3.8 fn test_releasebuffer_unraisable_error() {