Skip to content

Commit

Permalink
improve signature of ffi::PyIter_Send & add PyIterator::send
Browse files Browse the repository at this point in the history
  • Loading branch information
bschoenmaeckers committed Dec 9, 2024
1 parent 992865b commit 24e5272
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 1 deletion.
1 change: 1 addition & 0 deletions newsfragments/4746.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `PyIterator::send` method to allow sending values into a python generator.
1 change: 1 addition & 0 deletions newsfragments/4746.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed the return value of pyo3-ffi's PyIter_Send() function to return PySendResult.
6 changes: 5 additions & 1 deletion pyo3-ffi/src/abstract_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ extern "C" {
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
#[cfg(all(not(PyPy), Py_3_10))]
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
pub fn PyIter_Send(
iter: *mut PyObject,
arg: *mut PyObject,
presult: *mut *mut PyObject,
) -> PySendResult;

#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;
Expand Down
69 changes: 69 additions & 0 deletions src/types/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,35 @@ impl PyIterator {
}
}

#[derive(Debug)]
#[cfg(all(not(PyPy), Py_3_10))]
pub enum PySendResult<'py> {
Next(Bound<'py, PyAny>),
Return(Bound<'py, PyAny>),
}

impl Bound<'_, PyIterator> {
/// Sends a value into a python generator. This is the equivalent of calling `generator.send(value)` in Python.
/// This resumes the generator and continues its execution until the next `yield` or `return` statement.
/// If the generator exits without returning a value, this function returns a `StopException`.
/// The first call to `send` must be made with `None` as the argument to start the generator, failing to do so will raise a `TypeError`.
#[inline]
#[cfg(all(not(PyPy), Py_3_10))]
pub fn send<'py>(&self, value: &Bound<'py, PyAny>) -> PyResult<PySendResult<'py>> {
let py = self.py();
let mut result = std::ptr::null_mut();
match unsafe { ffi::PyIter_Send(self.as_ptr(), value.as_ptr(), &mut result) } {
ffi::PySendResult::PYGEN_ERROR => Err(PyErr::fetch(py)),
ffi::PySendResult::PYGEN_RETURN => Ok(PySendResult::Return(unsafe {
result.assume_owned_unchecked(py)
})),
ffi::PySendResult::PYGEN_NEXT => Ok(PySendResult::Next(unsafe {
result.assume_owned_unchecked(py)
})),
}
}
}

impl<'py> Iterator for Bound<'py, PyIterator> {
type Item = PyResult<Bound<'py, PyAny>>;

Expand Down Expand Up @@ -106,7 +135,11 @@ impl PyTypeCheck for PyIterator {
#[cfg(test)]
mod tests {
use super::PyIterator;
#[cfg(all(not(PyPy), Py_3_10))]
use super::PySendResult;
use crate::exceptions::PyTypeError;
#[cfg(all(not(PyPy), Py_3_10))]
use crate::types::PyNone;
use crate::types::{PyAnyMethods, PyDict, PyList, PyListMethods};
use crate::{ffi, IntoPyObject, Python};

Expand Down Expand Up @@ -201,6 +234,42 @@ def fibonacci(target):
});
}

#[test]
#[cfg(all(not(PyPy), Py_3_10))]
fn send_generator() {
let generator = ffi::c_str!(
r#"
def gen():
value = None
while(True):
value = yield value
if value is None:
return
"#
);

Python::with_gil(|py| {
let context = PyDict::new(py);
py.run(generator, None, Some(&context)).unwrap();

let generator = py.eval(ffi::c_str!("gen()"), None, Some(&context)).unwrap();

let one = 1i32.into_pyobject(py).unwrap();
assert!(matches!(
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
PySendResult::Next(value) if value.is_none()
));
assert!(matches!(
generator.try_iter().unwrap().send(&one).unwrap(),
PySendResult::Next(value) if value.is(&one)
));
assert!(matches!(
generator.try_iter().unwrap().send(&PyNone::get(py)).unwrap(),
PySendResult::Return(value) if value.is_none()
));
});
}

#[test]
fn fibonacci_generator_bound() {
use crate::types::any::PyAnyMethods;
Expand Down

0 comments on commit 24e5272

Please sign in to comment.