Bug Description
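When using PyArrayLikeDyn with AllowTypeChange, a trailing singleton axis may be dropped from inputs that are ndarrays whose dtype differs from the target element type (e.g. an int32 array passed to a PyArrayLikeDyn<'py, f64, AllowTypeChange> parameter).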
Steps to Reproduce
Cargo.toml
[package]
name = "singleton-removed"
version = "0.1.0"
edition = "2024"
[dependencies]
numpy = "0.24.0"
pyo3 = { version = "0.24.2", features = ["auto-initialize"] }
main.rs
use numpy::{AllowTypeChange, PyArrayDyn, PyArrayLikeDyn, PyArrayMethods};
use pyo3::ffi::c_str;
use pyo3::prelude::*;

#[pyfunction]
fn double<'py>(
    py: Python<'py>,
    a: PyArrayLikeDyn<'py, f64, AllowTypeChange>,
) -> Bound<'py, PyArrayDyn<f64>> {
    PyArrayDyn::from_owned_array(py, a.to_owned_array() * 2.0)
}

#[pymodule]
fn singleton_removed(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(double, m)?)?;
    Ok(())
}

fn main() -> PyResult<()> {
    pyo3::append_to_inittab!(singleton_removed);
    Python::with_gil(|py| {
        let code = c_str!(include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/example.py"
        )));
        py.run(code, None, None)?;
        Ok(())
    })
}
example.py
import singleton_removed
import numpy as np

a = np.ones((3, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
This results in the following error (plus a DeprecationWarning from NumPy, seemingly emitted when the singleton axis is implicitly dropped):
<string>:5: DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated, and will error in future. Ensure you extract a single element from your array before performing this operation. (Deprecated NumPy 1.25.)
Error: PyErr { type: <class 'AssertionError'>, value: AssertionError('a.shape=(3, 1), b.shape=(3,)'), traceback: Some("Traceback (most recent call last):\n File \"<string>\", line 6, in <module>\n") }
If no type change occurs, the axis is preserved, e.g.
import singleton_removed
import numpy as np
a = np.ones((3, 1), dtype=np.float64)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
succeeds. Non-array inputs also behave properly, e.g.
import singleton_removed
import numpy as np
a = [[1], [2], [3]]
b = singleton_removed.double(a)
assert b.shape == (3, 1), f"{b.shape=}"
also succeeds.
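If the conversion does try a flat-sequence path before a shape-preserving one, that would also explain which inputs keep their shape. The model below assumes this order, which is a guess and not taken from rust-numpy: (1) an input that is already a float64 ndarray is used as-is; (2) otherwise it is read as a flat sequence of floats; (3) if that raises TypeError, it falls back to numpy.asarray(obj, dtype=np.float64). Step (2) fails for a nested list like [[1], [2], [3]] (float([1]) is a TypeError) and for the (3, 2, 1) array in the three-axis case that follows (float() of a (2, 1) row is a TypeError), so only the int32 (3, 1) input loses its axis. The function name extract_like_f64 is hypothetical:

```python
import warnings

import numpy as np


def extract_like_f64(obj):
    """Hypothetical model of the assumed conversion order (not rust-numpy code)."""
    # (1) Already a float64 ndarray: used as-is, shape untouched.
    if isinstance(obj, np.ndarray) and obj.dtype == np.float64:
        return obj
    # (2) Try to read the input as a flat sequence of floats.
    try:
        return np.array([float(x) for x in obj], dtype=np.float64)
    except TypeError:
        pass
    # (3) Fall back to a shape-preserving conversion.
    return np.asarray(obj, dtype=np.float64)


cases = {
    "int32 (3, 1)": np.ones((3, 1), dtype=np.int32),
    "float64 (3, 1)": np.ones((3, 1), dtype=np.float64),
    "list [[1], [2], [3]]": [[1], [2], [3]],
    "int32 (3, 2, 1)": np.ones((3, 2, 1), dtype=np.int32),
}

with warnings.catch_warnings():
    # Silence the "ndim > 0 to a scalar" DeprecationWarning triggered in step (2).
    warnings.simplefilter("ignore", DeprecationWarning)
    for name, value in cases.items():
        print(f"{name}: {extract_like_f64(value).shape}")
```

With NumPy 2.2 this model prints (3,) for the int32 (3, 1) input and the original shape for the other three, matching the behaviour described in this report.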
Oddly enough, if I add a third axis, the trailing singleton dimension is no longer removed:
import singleton_removed
import numpy as np
a = np.ones((3, 2, 1), dtype=np.int32)
b = singleton_removed.double(a)
assert a.shape == b.shape, f"{a.shape=}, {b.shape=}"
Relevant Info
Python Version: 3.13.3
NumPy Version: 2.2.5
PyO3 Version: 0.24.2
rust-numpy Version: 0.24.0
rustc Version: 1.86.0
OS: Ubuntu 24.04.2 LTS (release 24.04, codename noble), via WSL