From 426bf6c746e3a53f24865f78eee46df317e6edfc Mon Sep 17 00:00:00 2001 From: edgar Date: Sat, 28 Sep 2024 16:40:01 +0200 Subject: [PATCH 1/4] expose to python --- Cargo.toml | 1 + crates/kornia-imgproc/src/color/gray.rs | 168 +++++++++++++++++- crates/kornia-imgproc/src/color/mod.rs | 2 +- examples/onnx/Cargo.toml | 1 - kornia-py/Cargo.toml | 1 - kornia-py/src/color.rs | 58 ++++++ kornia-py/src/image.rs | 11 +- kornia-py/src/lib.rs | 5 + kornia-py/src/warp.rs | 42 +++++ kornia-py/tests/test_color.py | 24 +++ .../{test_warp_affine.py => test_warp.py} | 14 ++ 11 files changed, 315 insertions(+), 12 deletions(-) create mode 100644 kornia-py/src/color.rs create mode 100644 kornia-py/tests/test_color.py rename kornia-py/tests/{test_warp_affine.py => test_warp.py} (59%) diff --git a/Cargo.toml b/Cargo.toml index 6512758a..a65e9eee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "crates/kornia-imgproc", "crates/kornia", "examples/*", + "kornia-py", ] exclude = ["kornia-py", "kornia-serve"] diff --git a/crates/kornia-imgproc/src/color/gray.rs b/crates/kornia-imgproc/src/color/gray.rs index cc3783c9..9ec69d3e 100644 --- a/crates/kornia-imgproc/src/color/gray.rs +++ b/crates/kornia-imgproc/src/color/gray.rs @@ -70,6 +70,93 @@ where Ok(()) } +/// Convert a grayscale image to an RGB image by replicating the grayscale value across all three channels. +/// +/// # Arguments +/// +/// * `src` - The input grayscale image. +/// * `dst` - The output RGB image. +/// +/// Precondition: the input image must have 1 channel. +/// Precondition: the output image must have 3 channels. +/// Precondition: the input and output images must have the same size. +/// +/// # Example +/// +/// ``` +/// use kornia_image::{Image, ImageSize}; +/// use kornia_imgproc::color::rgb_from_grayscale; +/// +/// let image = Image::::new( +/// ImageSize { +/// width: 4, +/// height: 5, +/// }, +/// vec![0f32; 4 * 5 * 1], +/// ) +/// .unwrap(); +/// +/// let mut rgb = Image::::from_size_val(image.size(), 0.0).unwrap(); +/// +/// rgb_from_gray(&image, &mut rgb).unwrap(); +/// ``` +pub fn rgb_from_gray(src: &Image, dst: &mut Image) -> Result<(), ImageError> +where + T: SafeTensorType, +{ + if src.size() != dst.size() { + return Err(ImageError::InvalidImageSize( + src.cols(), + src.rows(), + dst.cols(), + dst.rows(), + )); + } + + // parallelize the grayscale conversion by rows + parallel::par_iter_rows(src, dst, |src_pixel, dst_pixel| { + let gray = src_pixel[0]; + dst_pixel.iter_mut().for_each(|dst_pixel| { + *dst_pixel = gray; + }); + }); + + Ok(()) +} + +/// Convert an RGB image to BGR by swapping the red and blue channels. +/// +/// # Arguments +/// +/// * `src` - The input RGB image. +/// * `dst` - The output BGR image. +/// +/// Precondition: the input and output images must have the same size. +pub fn bgr_from_rgb(src: &Image, dst: &mut Image) -> Result<(), ImageError> +where + T: SafeTensorType, +{ + if src.size() != dst.size() { + return Err(ImageError::InvalidImageSize( + src.cols(), + src.rows(), + dst.cols(), + dst.rows(), + )); + } + + parallel::par_iter_rows(src, dst, |src_pixel, dst_pixel| { + dst_pixel + .iter_mut() + .zip(src_pixel.iter().rev()) + .for_each(|(d, s)| { + *d = *s; + }); + }); + + Ok(()) +} + #[cfg(test)] mod tests { use kornia_image::{ops, Image, ImageSize}; @@ -94,14 +181,19 @@ mod tests { #[test] fn gray_from_rgb_regression() -> Result<(), Box> { + #[rustfmt::skip] let image = Image::new( ImageSize { width: 2, height: 3, }, vec![ - 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, + 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, ], )?; @@ -123,4 +215,76 @@ mod tests { Ok(()) } + + #[test] + fn rgb_from_grayscale() -> Result<(), Box> { + let image = Image::new( + ImageSize { + width: 2, + height: 3, + }, + vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], + )?; + + let mut rgb = Image::::from_size_val(image.size(), 0.0)?; + + super::rgb_from_gray(&image, &mut rgb)?; + + #[rustfmt::skip] + let expected: Image = Image::new( + ImageSize { + width: 2, + height: 3, + }, + vec![ + 0.0, 0.0, 0.0, + 1.0, 1.0, 1.0, + 2.0, 2.0, 2.0, + 3.0, 3.0, 3.0, + 4.0, 4.0, 4.0, + 5.0, 5.0, 5.0, + ], + )?; + + assert_eq!(rgb.as_slice(), expected.as_slice()); + + Ok(()) + } + + #[test] + fn bgr_from_rgb() -> Result<(), Box> { + #[rustfmt::skip] + let image = Image::new( + ImageSize { + width: 1, + height: 3, + }, + vec![ + 0.0, 1.0, 2.0, + 3.0, 4.0, 5.0, + 6.0, 7.0, 8.0, + ], + )?; + + let mut bgr = Image::::from_size_val(image.size(), 0.0)?; + + super::bgr_from_rgb(&image, &mut bgr)?; + + #[rustfmt::skip] + let expected: Image = Image::new( + ImageSize { + width: 1, + height: 3, + }, + vec![ + 2.0, 1.0, 0.0, + 5.0, 4.0, 3.0, + 8.0, 7.0, 6.0, + ], + )?; + + assert_eq!(bgr.as_slice(), expected.as_slice()); + + Ok(()) + } } diff --git a/crates/kornia-imgproc/src/color/mod.rs b/crates/kornia-imgproc/src/color/mod.rs index 6e1fac93..7c7242cf 100644 --- a/crates/kornia-imgproc/src/color/mod.rs +++ b/crates/kornia-imgproc/src/color/mod.rs @@ -1,5 +1,5 @@ mod gray; mod hsv; -pub use gray::gray_from_rgb; +pub use gray::{bgr_from_rgb, gray_from_rgb, rgb_from_gray}; pub use hsv::hsv_from_rgb; diff --git a/examples/onnx/Cargo.toml b/examples/onnx/Cargo.toml index 6c7ce852..d395987a 100644 --- a/examples/onnx/Cargo.toml +++ b/examples/onnx/Cargo.toml @@ -7,7 +7,6 @@ edition.workspace = true homepage.workspace = true include.workspace = true license.workspace = true -license-file.workspace = true readme.workspace = true repository.workspace = true rust-version.workspace = true diff --git a/kornia-py/Cargo.toml b/kornia-py/Cargo.toml index 2a430da5..a53a356e 100644 --- a/kornia-py/Cargo.toml +++ b/kornia-py/Cargo.toml @@ -6,7 +6,6 @@ edition = "2021" homepage = "http://kornia.org" include = ["Cargo.toml"] license = "Apache-2.0" -license-file = "LICENSE" repository = "https://github.com/kornia/kornia-rs" rust-version = "1.76" version = "0.1.6-rc.5" diff --git a/kornia-py/src/color.rs b/kornia-py/src/color.rs new file mode 100644 index 00000000..01dc743a --- /dev/null +++ b/kornia-py/src/color.rs @@ -0,0 +1,58 @@ +use pyo3::prelude::*; + +use crate::image::{FromPyImage, PyImage, ToPyImage}; +use kornia_image::Image; +use kornia_imgproc::color; + +#[pyfunction] +pub fn rgb_from_gray(image: PyImage) -> PyResult { + let image_gray = Image::from_pyimage(image) + .map_err(|e| PyErr::new::(format!("src image: {}", e)))?; + + let mut image_rgb = Image::from_size_val(image_gray.size(), 0u8) + .map_err(|e| PyErr::new::(format!("dst image: {}", e)))?; + + color::rgb_from_gray(&image_gray, &mut image_rgb).map_err(|e| { + PyErr::new::(format!("failed to convert image: {}", e)) + })?; + + Ok(image_rgb.to_pyimage()) +} + +#[pyfunction] +pub fn bgr_from_rgb(image: PyImage) -> PyResult { + let image_rgb = Image::from_pyimage(image) + .map_err(|e| PyErr::new::(format!("src image: {}", e)))?; + + let mut image_bgr = Image::from_size_val(image_rgb.size(), 0u8) + .map_err(|e| PyErr::new::(format!("dst image: {}", e)))?; + + color::bgr_from_rgb(&image_rgb, &mut image_bgr).map_err(|e| { + PyErr::new::(format!("failed to convert image: {}", e)) + })?; + + Ok(image_bgr.to_pyimage()) +} + +#[pyfunction] +pub fn gray_from_rgb(image: PyImage) -> PyResult { + let image_rgb = Image::from_pyimage(image) + .map_err(|e| PyErr::new::(format!("src image: {}", e)))?; + + let image_rgb = image_rgb.cast::().map_err(|e| { + PyErr::new::(format!("failed to convert image: {}", e)) + })?; + + let mut image_gray = Image::from_size_val(image_rgb.size(), 0f32) + .map_err(|e| PyErr::new::(format!("dst image: {}", e)))?; + + color::gray_from_rgb(&image_rgb, &mut image_gray).map_err(|e| { + PyErr::new::(format!("failed to convert image: {}", e)) + })?; + + let image_gray = image_gray.cast::().map_err(|e| { + PyErr::new::(format!("failed to convert image: {}", e)) + })?; + + Ok(image_gray.to_pyimage()) +} diff --git a/kornia-py/src/image.rs b/kornia-py/src/image.rs index ddd96aa9..d3f93ac3 100644 --- a/kornia-py/src/image.rs +++ b/kornia-py/src/image.rs @@ -5,7 +5,6 @@ use pyo3::prelude::*; // type alias for a 3D numpy array of u8 pub type PyImage = Py>; -//pub type PyImage<'a> = Bound<'a, PyArray3>; /// Trait to convert an image to a PyImage (3D numpy array of u8) pub trait ToPyImage { @@ -36,12 +35,10 @@ impl FromPyImage for Image { // TODO: we should find a way to avoid copying the data // Possible solutions: // - Use a custom ndarray wrapper that does not copy the data - // - Return direectly pyarray and use it in the Rust code - let data = unsafe { - match pyarray.as_slice() { - Ok(d) => d.to_vec(), - Err(_) => return Err(ImageError::ImageDataNotContiguous), - } + // - Return directly pyarray and use it in the Rust code + let data = match pyarray.to_vec() { + Ok(d) => d, + Err(_) => return Err(ImageError::ImageDataNotContiguous), }; let size = ImageSize { diff --git a/kornia-py/src/lib.rs b/kornia-py/src/lib.rs index 30217baf..6fac034f 100644 --- a/kornia-py/src/lib.rs +++ b/kornia-py/src/lib.rs @@ -1,3 +1,4 @@ +mod color; mod histogram; mod image; mod io; @@ -22,11 +23,15 @@ pub fn get_version() -> String { #[pymodule] pub fn kornia_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("__version__", get_version())?; + m.add_function(wrap_pyfunction!(color::rgb_from_gray, m)?)?; + m.add_function(wrap_pyfunction!(color::bgr_from_rgb, m)?)?; + m.add_function(wrap_pyfunction!(color::gray_from_rgb, m)?)?; m.add_function(wrap_pyfunction!(read_image_jpeg, m)?)?; m.add_function(wrap_pyfunction!(write_image_jpeg, m)?)?; m.add_function(wrap_pyfunction!(read_image_any, m)?)?; m.add_function(wrap_pyfunction!(resize::resize, m)?)?; m.add_function(wrap_pyfunction!(warp::warp_affine, m)?)?; + m.add_function(wrap_pyfunction!(warp::warp_perspective, m)?)?; m.add_function(wrap_pyfunction!(histogram::compute_histogram, m)?)?; m.add_class::()?; m.add_class::()?; diff --git a/kornia-py/src/warp.rs b/kornia-py/src/warp.rs index 402aa5b1..866276af 100644 --- a/kornia-py/src/warp.rs +++ b/kornia-py/src/warp.rs @@ -50,3 +50,45 @@ pub fn warp_affine( Ok(image_warped.to_pyimage()) } + +#[pyfunction] +pub fn warp_perspective( + image: PyImage, + m: [f32; 9], + new_size: (usize, usize), + interpolation: &str, +) -> PyResult { + let image: Image = Image::from_pyimage(image) + .map_err(|e| PyErr::new::(format!("{}", e)))?; + + let new_size = ImageSize { + height: new_size.0, + width: new_size.1, + }; + + let interpolation = match interpolation.to_lowercase().as_str() { + "nearest" => InterpolationMode::Nearest, + "bilinear" => InterpolationMode::Bilinear, + _ => { + return Err(PyErr::new::( + "Invalid interpolation mode", + )) + } + }; + + let image = image + .cast::() + .map_err(|e| PyErr::new::(format!("{}", e)))?; + + let mut image_warped = Image::from_size_val(new_size, 0f32) + .map_err(|e| PyErr::new::(format!("{}", e)))?; + + warp::warp_perspective(&image, &mut image_warped, &m, interpolation) + .map_err(|e| PyErr::new::(format!("{}", e)))?; + + let image_warped = image_warped + .cast::() + .map_err(|e| PyErr::new::(format!("{}", e)))?; + + Ok(image_warped.to_pyimage()) +} diff --git a/kornia-py/tests/test_color.py b/kornia-py/tests/test_color.py new file mode 100644 index 00000000..6d6feb53 --- /dev/null +++ b/kornia-py/tests/test_color.py @@ -0,0 +1,24 @@ +from pathlib import Path +import kornia_rs as K + +import numpy as np + + +def test_rgb_from_gray(): + img: np.ndarray = np.array([[[1]]], dtype=np.uint8) + img_rgb: np.ndarray = K.rgb_from_gray(img) + assert img_rgb.shape == (1, 1, 3) + assert np.allclose(img_rgb, np.array([[[1, 1, 1]]])) + + +def test_bgr_from_rgb(): + img: np.ndarray = np.array([[[1, 2, 3]]], dtype=np.uint8) + img_bgr: np.ndarray = K.bgr_from_rgb(img) + assert img_bgr.shape == (1, 1, 3) + assert np.allclose(img_bgr, np.array([[[3, 2, 1]]])) + +def test_gray_from_rgb(): + img: np.ndarray = np.array([[[1, 1, 1]]], dtype=np.uint8) + img_gray: np.ndarray = K.gray_from_rgb(img) + assert img_gray.shape == (1, 1, 1) + assert np.allclose(img_gray, np.array([[[1]]])) diff --git a/kornia-py/tests/test_warp_affine.py b/kornia-py/tests/test_warp.py similarity index 59% rename from kornia-py/tests/test_warp_affine.py rename to kornia-py/tests/test_warp.py index 962701a8..ef24a72b 100644 --- a/kornia-py/tests/test_warp_affine.py +++ b/kornia-py/tests/test_warp.py @@ -21,3 +21,17 @@ def test_warp_affine(): img, affine_matrix, img.shape[:2], "bilinear" ) assert (img_transformed == img).all() + + +def test_warp_perspective(): + img_path: Path = DATA_DIR / "dog.jpeg" + img: np.ndarray = K.read_image_jpeg(str(img_path.absolute())) + + assert img.shape == (195, 258, 3) + + perspective_matrix = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) + + img_transformed: np.ndarray = K.warp_perspective( + img, perspective_matrix, img.shape[:2], "bilinear" + ) + assert (img_transformed == img).all() From 820f9543acb66c03662836545a53b74fa4f97e93 Mon Sep 17 00:00:00 2001 From: edgar Date: Sat, 28 Sep 2024 16:52:02 +0200 Subject: [PATCH 2/4] implement add_weighted --- crates/kornia-imgproc/src/enhance.rs | 4 +-- kornia-py/src/enhance.rs | 44 ++++++++++++++++++++++++++++ kornia-py/src/lib.rs | 2 ++ kornia-py/tests/test_color.py | 1 - kornia-py/tests/test_enhance.py | 11 +++++++ 5 files changed, 59 insertions(+), 3 deletions(-) create mode 100644 kornia-py/src/enhance.rs create mode 100644 kornia-py/tests/test_enhance.py diff --git a/crates/kornia-imgproc/src/enhance.rs b/crates/kornia-imgproc/src/enhance.rs index 57daeb58..a6f9588f 100644 --- a/crates/kornia-imgproc/src/enhance.rs +++ b/crates/kornia-imgproc/src/enhance.rs @@ -28,9 +28,9 @@ pub fn add_weighted( src1: &Image, alpha: T, src2: &Image, - dst: &mut Image, beta: T, gamma: T, + dst: &mut Image, ) -> Result<(), ImageError> where T: num_traits::Float @@ -96,7 +96,7 @@ mod tests { let mut weighted = Image::::from_size_val(src1.size(), 0.0)?; - super::add_weighted(&src1, alpha, &src2, &mut weighted, beta, gamma)?; + super::add_weighted(&src1, alpha, &src2, beta, gamma, &mut weighted)?; weighted .as_slice() diff --git a/kornia-py/src/enhance.rs b/kornia-py/src/enhance.rs new file mode 100644 index 00000000..daaea925 --- /dev/null +++ b/kornia-py/src/enhance.rs @@ -0,0 +1,44 @@ +use pyo3::prelude::*; + +use crate::image::{FromPyImage, PyImage, ToPyImage}; +use kornia_image::Image; +use kornia_imgproc::enhance; + +#[pyfunction] +pub fn add_weighted( + src1: PyImage, + alpha: f32, + src2: PyImage, + beta: f32, + gamma: f32, +) -> PyResult { + let image1: Image = Image::from_pyimage(src1).map_err(|e| { + PyErr::new::(format!("src1 image: {}", e)) + })?; + + let image2: Image = Image::from_pyimage(src2).map_err(|e| { + PyErr::new::(format!("src2 image: {}", e)) + })?; + + // cast input images to f32 + let image1 = image1.cast::().map_err(|e| { + PyErr::new::(format!("src1 image: {}", e)) + })?; + + let image2 = image2.cast::().map_err(|e| { + PyErr::new::(format!("src2 image: {}", e)) + })?; + + let mut dst: Image = Image::from_size_val(image1.size(), 0.0f32) + .map_err(|e| PyErr::new::(format!("dst image: {}", e)))?; + + enhance::add_weighted(&image1, alpha, &image2, beta, gamma, &mut dst) + .map_err(|e| PyErr::new::(format!("dst image: {}", e)))?; + + // cast dst image to u8 + let dst = dst + .cast::() + .map_err(|e| PyErr::new::(format!("dst image: {}", e)))?; + + Ok(dst.to_pyimage()) +} diff --git a/kornia-py/src/lib.rs b/kornia-py/src/lib.rs index 6fac034f..6efc361d 100644 --- a/kornia-py/src/lib.rs +++ b/kornia-py/src/lib.rs @@ -1,4 +1,5 @@ mod color; +mod enhance; mod histogram; mod image; mod io; @@ -26,6 +27,7 @@ pub fn kornia_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(color::rgb_from_gray, m)?)?; m.add_function(wrap_pyfunction!(color::bgr_from_rgb, m)?)?; m.add_function(wrap_pyfunction!(color::gray_from_rgb, m)?)?; + m.add_function(wrap_pyfunction!(enhance::add_weighted, m)?)?; m.add_function(wrap_pyfunction!(read_image_jpeg, m)?)?; m.add_function(wrap_pyfunction!(write_image_jpeg, m)?)?; m.add_function(wrap_pyfunction!(read_image_any, m)?)?; diff --git a/kornia-py/tests/test_color.py b/kornia-py/tests/test_color.py index 6d6feb53..b6e4e46b 100644 --- a/kornia-py/tests/test_color.py +++ b/kornia-py/tests/test_color.py @@ -1,4 +1,3 @@ -from pathlib import Path import kornia_rs as K import numpy as np diff --git a/kornia-py/tests/test_enhance.py b/kornia-py/tests/test_enhance.py new file mode 100644 index 00000000..3cc2a658 --- /dev/null +++ b/kornia-py/tests/test_enhance.py @@ -0,0 +1,11 @@ +import kornia_rs as K + +import numpy as np + + +def test_add_weighted(): + img1: np.ndarray = np.array([[[1, 2, 3]]], dtype=np.uint8) + img2: np.ndarray = np.array([[[4, 5, 6]]], dtype=np.uint8) + img_weighted: np.ndarray = K.add_weighted(img1, 0.5, img2, 0.5, 0.0) + assert img_weighted.shape == (1, 1, 3) + assert np.allclose(img_weighted, np.array([[[2, 3, 4]]])) From d937bbde709555e3fe84211b2d868b98cb252046 Mon Sep 17 00:00:00 2001 From: edgar Date: Sat, 28 Sep 2024 16:52:57 +0200 Subject: [PATCH 3/4] disable py crate --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index a65e9eee..6401d10e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ members = [ "crates/kornia-imgproc", "crates/kornia", "examples/*", - "kornia-py", + # "kornia-py", ] exclude = ["kornia-py", "kornia-serve"] From 14693396acb730f76bf44f4f3b85206a2ba2e290 Mon Sep 17 00:00:00 2001 From: edgar Date: Sat, 28 Sep 2024 17:00:16 +0200 Subject: [PATCH 4/4] fix doctest --- crates/kornia-imgproc/src/color/gray.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/kornia-imgproc/src/color/gray.rs b/crates/kornia-imgproc/src/color/gray.rs index 9ec69d3e..51887259 100644 --- a/crates/kornia-imgproc/src/color/gray.rs +++ b/crates/kornia-imgproc/src/color/gray.rs @@ -85,7 +85,7 @@ where /// /// ``` /// use kornia_image::{Image, ImageSize}; -/// use kornia_imgproc::color::rgb_from_grayscale; +/// use kornia_imgproc::color::rgb_from_gray; /// /// let image = Image::::new( /// ImageSize {