From a2265a75af171b11a859fef7151b16ae0110e350 Mon Sep 17 00:00:00 2001 From: Bas Schoenmaeckers Date: Fri, 29 Nov 2024 17:21:20 +0100 Subject: [PATCH] improve signature of `ffi::PyIter_Send` & add `PyIterator::send` --- newsfragments/4746.added.md | 1 + newsfragments/4746.fixed.md | 1 + pyo3-ffi/src/abstract_.rs | 6 +++- src/types/iterator.rs | 69 +++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 newsfragments/4746.added.md create mode 100644 newsfragments/4746.fixed.md diff --git a/newsfragments/4746.added.md b/newsfragments/4746.added.md new file mode 100644 index 00000000000..43fbab18f2c --- /dev/null +++ b/newsfragments/4746.added.md @@ -0,0 +1 @@ +Added `PyIterator::send` method to allow sending values into a python generator. diff --git a/newsfragments/4746.fixed.md b/newsfragments/4746.fixed.md new file mode 100644 index 00000000000..51611432e30 --- /dev/null +++ b/newsfragments/4746.fixed.md @@ -0,0 +1 @@ +Fixed the return value of pyo3-ffi's PyIter_Send() function to return PySendResult. \ No newline at end of file diff --git a/pyo3-ffi/src/abstract_.rs b/pyo3-ffi/src/abstract_.rs index 1899545011a..123fbf0e35e 100644 --- a/pyo3-ffi/src/abstract_.rs +++ b/pyo3-ffi/src/abstract_.rs @@ -120,7 +120,11 @@ extern "C" { pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject; #[cfg(all(not(PyPy), Py_3_10))] #[cfg_attr(PyPy, link_name = "PyPyIter_Send")] - pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject); + pub fn PyIter_Send( + iter: *mut PyObject, + arg: *mut PyObject, + presult: *mut *mut PyObject, + ) -> PySendResult; #[cfg_attr(PyPy, link_name = "PyPyNumber_Check")] pub fn PyNumber_Check(o: *mut PyObject) -> c_int; diff --git a/src/types/iterator.rs b/src/types/iterator.rs index 068ab1fce34..d6b391ca753 100644 --- a/src/types/iterator.rs +++ b/src/types/iterator.rs @@ -52,6 +52,35 @@ impl PyIterator { } } +#[derive(Debug)] +#[cfg(all(not(PyPy), Py_3_10))] +pub enum PySendResult<'py> { + Next(Bound<'py, PyAny>), + Return(Bound<'py, PyAny>), +} + +impl Bound<'_, PyIterator> { + /// Sends a value into a python generator. This is the equivalent of calling `generator.send(value)` in Python. + /// This resumes the generator and continues its execution until the next `yield` or `return` statement. + /// If the generator exits without returning a value, this function returns a `StopException`. + /// The first call to `send` must be made with `None` as the argument to start the generator, failing to do so will raise a `TypeError`. + #[inline] + #[cfg(all(not(PyPy), Py_3_10))] + pub fn send(&self, value: &Bound<'py, PyAny>) -> PyResult> { + let py = self.py(); + let mut result = std::ptr::null_mut(); + match unsafe { ffi::PyIter_Send(self.as_ptr(), value.as_ptr(), &mut result) } { + ffi::PySendResult::PYGEN_ERROR => Err(PyErr::fetch(py)), + ffi::PySendResult::PYGEN_RETURN => Ok(PySendResult::Return(unsafe { + result.assume_owned_unchecked(py) + })), + ffi::PySendResult::PYGEN_NEXT => Ok(PySendResult::Next(unsafe { + result.assume_owned_unchecked(py) + })), + } + } +} + impl<'py> Iterator for Bound<'py, PyIterator> { type Item = PyResult>; @@ -106,7 +135,11 @@ impl PyTypeCheck for PyIterator { #[cfg(test)] mod tests { use super::PyIterator; + #[cfg(all(not(PyPy), Py_3_10))] + use super::PySendResult; use crate::exceptions::PyTypeError; + #[cfg(all(not(PyPy), Py_3_10))] + use crate::types::PyNone; use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods}; use crate::{ffi, IntoPyObject, Python}; @@ -201,6 +234,42 @@ def fibonacci(target): }); } + #[test] + #[cfg(all(not(PyPy), Py_3_10))] + fn send_generator() { + let generator = ffi::c_str!( + r#" +def gen(): + value = None + while(True): + value = yield value + if value is None: + return +"# + ); + + Python::with_gil(|py| { + let context = PyDict::new(py); + py.run(generator, None, Some(&context)).unwrap(); + + let generator = py.eval(ffi::c_str!("gen()"), None, Some(&context)).unwrap(); + + let one = 1i32.into_pyobject(py).unwrap(); + assert!(matches!( + generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(), + PySendResult::Next(value) if value.is_none() + )); + assert!(matches!( + generator.try_iter().unwrap().send(&one).unwrap(), + PySendResult::Next(value) if value.is(&one) + )); + assert!(matches!( + generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(), + PySendResult::Return(value) if value.is_none() + )); + }); + } + #[test] fn fibonacci_generator_bound() { use crate::types::any::PyAnyMethods;