Skip to content

Commit

Permalink
implement fast image resize (#59)
Browse files Browse the repository at this point in the history
* add more benchmarks

* implement fast resize

* add python resize test

* fix readme
  • Loading branch information
edgarriba authored Mar 17, 2024
1 parent 15795e5 commit 04d3a6f
Show file tree
Hide file tree
Showing 14 changed files with 274 additions and 157 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ image = { version = "0.25.0" }
turbojpeg = {version = "1.0.0"}
memmap2 = "0.9.4"
ndarray = { version = "0.15.6", features = ["rayon"] }
fast_image_resize = "3.0.4"
# this is experimental and only used for benchmarking, so it's optional
# consider removing it in the future.
candle-core = { version = "0.3.2", optional = true }
Expand Down
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// convert the image to grayscale
let gray: Image<f32, 1> = kornia_rs::color::gray_from_rgb(&image_f32)?;

let gray_resize: Image<f32, 1> = kornia_rs::resize::resize(
let gray_resize: Image<f32, 1> = kornia_rs::resize::resize_native(
&gray,
kornia_rs::image::ImageSize {
width: 128,
height: 128,
},
kornia_rs::resize::ResizeOptions::default(),
kornia_rs::resize::InterpolationMode::Bilinear,
)?;

println!("gray_resize: {:?}", gray_resize.size());
Expand Down Expand Up @@ -186,6 +186,20 @@ image_decoder = K.ImageDecoder()
decoded_img: np.ndarray = image_decoder.decode(bytes(image_encoded))
```
Resize an image using the `kornia-rs` backend with SIMD acceleration
```python
import kornia_rs as K

# load image with kornia-rs
img = K.read_image_jpeg("dog.jpeg")

# resize the image
resized_img = K.resize(img, (128, 128), interpolation="bilinear")

assert resized_img.shape == (128, 128, 3)
```
## 🧑‍💻 Development
Pre-requisites: install `rust` and `python3` in your system.
Expand Down
1 change: 0 additions & 1 deletion benches/bench_io.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rerun::external::arrow2::ffi::mmap;

struct JpegReader {
decoder: kornia_rs::io::jpeg::ImageDecoder,
Expand Down
23 changes: 9 additions & 14 deletions benches/bench_resize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri

use kornia_rs::image::{Image, ImageSize};
use kornia_rs::resize as F;
use kornia_rs::resize::{InterpolationMode, ResizeOptions};
use kornia_rs::resize::InterpolationMode;

fn resize_image_crate(image: Image<u8, 3>, new_size: ImageSize) -> Image<u8, 3> {
let image_data = image.data.as_slice().unwrap();
Expand All @@ -17,14 +17,14 @@ fn resize_image_crate(image: Image<u8, 3>, new_size: ImageSize) -> Image<u8, 3>
let image_resized = image_crate.resize_exact(
new_size.width as u32,
new_size.height as u32,
image::imageops::FilterType::Gaussian,
image::imageops::FilterType::Nearest,
);
let data = image_resized.into_rgb8().into_raw();
Image::new(new_size, data).unwrap()
}

fn bench_resize(c: &mut Criterion) {
let mut group = c.benchmark_group("Resize");
let mut group = c.benchmark_group("resize");
let image_sizes = vec![(256, 224), (512, 448), (1024, 896)];

for (width, height) in image_sizes {
Expand All @@ -36,20 +36,15 @@ fn bench_resize(c: &mut Criterion) {
width: width / 2,
height: height / 2,
};
group.bench_with_input(BenchmarkId::new("zip", &id), &image_f32, |b, i| {
b.iter(|| {
F::resize(
black_box(i),
new_size,
ResizeOptions {
interpolation: InterpolationMode::Bilinear,
},
)
})
group.bench_with_input(BenchmarkId::new("native", &id), &image_f32, |b, i| {
b.iter(|| F::resize_native(black_box(i), new_size, InterpolationMode::Nearest))
});
group.bench_with_input(BenchmarkId::new("image_crate", &id), &image, |b, i| {
group.bench_with_input(BenchmarkId::new("image_rs", &id), &image, |b, i| {
b.iter(|| resize_image_crate(black_box(i.clone()), new_size))
});
group.bench_with_input(BenchmarkId::new("fast", &id), &image, |b, i| {
b.iter(|| F::resize_fast(black_box(i), new_size, InterpolationMode::Nearest))
});
}
group.finish();
}
Expand Down
4 changes: 2 additions & 2 deletions examples/imgproc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// convert the image to grayscale
let gray: Image<f32, 1> = kornia_rs::color::gray_from_rgb(&image_f32)?;

let gray_resize: Image<f32, 1> = kornia_rs::resize::resize(
let gray_resize: Image<f32, 1> = kornia_rs::resize::resize_native(
&gray,
kornia_rs::image::ImageSize {
width: 128,
height: 128,
},
kornia_rs::resize::ResizeOptions::default(),
kornia_rs::resize::InterpolationMode::Bilinear,
)?;

println!("gray_resize: {:?}", gray_resize.size());
Expand Down
12 changes: 12 additions & 0 deletions py-kornia/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions py-kornia/benchmark/bench_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import timeit

import cv2
from PIL import Image
import kornia_rs
import numpy as np
# import tensorflow as tf

image_path = "tests/data/dog.jpeg"
img = kornia_rs.read_image_jpeg(image_path)
img_pil = Image.open(image_path)
new_size = (128, 128)
N = 5000 # number of iterations


def resize_image_opencv(img: np.ndarray, new_size: tuple) -> None:
return cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)

def resize_image_pil(img: Image.Image, new_size: tuple) -> None:
return img.resize(new_size, Image.BILINEAR)

def resize_image_kornia(img: np.ndarray, new_size: tuple) -> None:
return kornia_rs.resize(img, new_size, "bilinear")

tests = [
{
"name": "OpenCV",
"stmt": "resize_image_opencv(img, new_size)",
"setup": "from __main__ import resize_image_opencv, img, new_size",
"globals": {"img": img, "new_size": new_size},
},
{
"name": "PIL",
"stmt": "resize_image_pil(img_pil, new_size)",
"setup": "from __main__ import resize_image_pil, img_pil, new_size",
"globals": {"img_pil": img_pil, "new_size": new_size},
},
{
"name": "Kornia",
"stmt": "resize_image_kornia(img, new_size)",
"setup": "from __main__ import resize_image_kornia, img, new_size",
"globals": {"img": img, "new_size": new_size},
},
]

for test in tests:
timer = timeit.Timer(
stmt=test["stmt"], setup=test["setup"], globals=test["globals"]
)
print(f"{test['name']}: {timer.timeit(N)/ N * 1e3:.2f} ms")
77 changes: 0 additions & 77 deletions py-kornia/benchmark/resize_benchmark.py

This file was deleted.

2 changes: 2 additions & 0 deletions py-kornia/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod image;
mod io;
mod resize;

use crate::image::PyImageSize;
use crate::io::functional::{read_image_any, read_image_jpeg, write_image_jpeg};
Expand All @@ -22,6 +23,7 @@ pub fn kornia_rs(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(read_image_jpeg, m)?)?;
m.add_function(wrap_pyfunction!(write_image_jpeg, m)?)?;
m.add_function(wrap_pyfunction!(read_image_any, m)?)?;
m.add_function(wrap_pyfunction!(resize::resize, m)?)?;
m.add_class::<PyImageSize>()?;
m.add_class::<PyImageDecoder>()?;
m.add_class::<PyImageEncoder>()?;
Expand Down
30 changes: 30 additions & 0 deletions py-kornia/src/resize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use pyo3::prelude::*;

use crate::image::{FromPyImage, PyImage, ToPyImage};
use kornia_rs::image::Image;

#[pyfunction]
pub fn resize(image: PyImage, new_size: (usize, usize), interpolation: &str) -> PyResult<PyImage> {
let image = Image::from_pyimage(image)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyException, _>(format!("{}", e)))?;

let new_size = kornia_rs::image::ImageSize {
height: new_size.0,
width: new_size.1,
};

let interpolation = match interpolation.to_lowercase().as_str() {
"nearest" => kornia_rs::resize::InterpolationMode::Nearest,
"bilinear" => kornia_rs::resize::InterpolationMode::Bilinear,
_ => {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Invalid interpolation mode",
))
}
};

let image = kornia_rs::resize::resize_fast(&image, new_size, interpolation)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyException, _>(format!("{}", e)))?;

Ok(image.to_pyimage())
}
20 changes: 20 additions & 0 deletions py-kornia/tests/test_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pathlib import Path
import kornia_rs as K

import torch
import numpy as np

# TODO: inject this from elsewhere
DATA_DIR = Path(__file__).parents[2] / "tests" / "data"


def test_resize():
# load an image with libjpeg-turbo
img_path: Path = DATA_DIR / "dog.jpeg"
img: np.ndarray = K.read_image_jpeg(str(img_path.absolute()))

# check the image properties
assert img.shape == (195, 258, 3)

img_resized: np.ndarray = K.resize(img, (43, 34), "bilinear")
assert img_resized.shape == (43, 34, 3)
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ pub mod io;
pub mod metrics;
pub mod normalize;
pub mod resize;
pub mod tensor;
// NOTE: not ready yet
// pub mod tensor;
pub mod threshold;
Loading

0 comments on commit 04d3a6f

Please sign in to comment.