Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Iterative Closest Point algorithm and 3d utils #182

Merged
merged 17 commits into from
Dec 16, 2024
11 changes: 11 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"crates/kornia-imgproc",
"crates/kornia",
"examples/*",
"crates/kornia-icp",
# "kornia-py",
"kornia-viz",
"crates/kornia-3d",
Expand All @@ -32,6 +33,7 @@ version = "0.1.8-rc.1"
# NOTE: remember to update the kornia-py package version in `kornia-py/Cargo.toml` when updating the Rust package version
kornia-core = { path = "crates/kornia-core", version = "0.1.8-rc.1" }
kornia-core-ops = { path = "crates/kornia-core-ops", version = "0.1.8-rc.1" }
kornia-icp = { path = "crates/kornia-icp", version = "0.1.8-rc.1" }
kornia-image = { path = "crates/kornia-image", version = "0.1.8-rc.1" }
kornia-io = { path = "crates/kornia-io", version = "0.1.8-rc.1" }
kornia-imgproc = { path = "crates/kornia-imgproc", version = "0.1.8-rc.1" }
Expand All @@ -40,4 +42,13 @@ kornia = { path = "crates/kornia", version = "0.1.8-rc.1" }

# dev dependencies for workspace
argh = "0.1.12"
approx = "0.5.1"
bincode = "1.3"
criterion = "0.5.1"
env_logger = "0.11.5"
faer = { version = "0.19.4", features = ["rayon"] }
log = "0.4.22"
rand = "0.8.5"
rerun = "^0.20"
serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0"
15 changes: 12 additions & 3 deletions crates/kornia-3d/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ rust-version.workspace = true
version.workspace = true

[dependencies]
bincode = "1.3"
serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0"
bincode = { workspace = true }
faer = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }

[dev-dependencies]
approx = { workspace = true }
criterion = { workspace = true }

[[bench]]
name = "bench_linalg"
harness = false
176 changes: 176 additions & 0 deletions crates/kornia-3d/benches/bench_linalg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};

use kornia_3d::linalg;

// transform_points3d_col using faer with cols point by point
fn transform_points3d_col(
src_points: &[[f64; 3]],
dst_r_src: &[[f64; 3]; 3],
dst_t_src: &[f64; 3],
dst_points: &mut [[f64; 3]],
) {
assert_eq!(src_points.len(), dst_points.len());

// create views of the rotation and translation matrices
let dst_r_src_mat = faer::Mat::<f64>::from_fn(3, 3, |i, j| dst_r_src[i][j]);
let dst_t_src_col = faer::col![dst_t_src[0], dst_t_src[1], dst_t_src[2]];

for (point_dst, point_src) in dst_points.iter_mut().zip(src_points.iter()) {
let point_src_col = faer::col![point_src[0], point_src[1], point_src[2]];
let point_dst_col = &dst_r_src_mat * point_src_col + &dst_t_src_col;
for (i, point_dst_col_val) in point_dst_col.iter().enumerate().take(3) {
point_dst[i] = *point_dst_col_val;
}
}
}

// transform_points3d_matmul using faer with matmul
fn transform_points3d_matmul(
src_points: &[[f64; 3]],
dst_r_src: &[[f64; 3]; 3],
dst_t_src: &[f64; 3],
dst_points: &mut [[f64; 3]],
) {
// create views of the rotation and translation matrices
let dst_r_src_mat = {
let dst_r_src_slice = unsafe {
std::slice::from_raw_parts(dst_r_src.as_ptr() as *const f64, dst_r_src.len() * 3)
};
faer::mat::from_row_major_slice(dst_r_src_slice, 3, 3)
};
let dst_t_src_col = faer::col![dst_t_src[0], dst_t_src[1], dst_t_src[2]];

// create view of the source points
let points_in_src: faer::MatRef<'_, f64> = {
let src_points_slice = unsafe {
std::slice::from_raw_parts(src_points.as_ptr() as *const f64, src_points.len() * 3)
};
// SAFETY: src_points_slice is a 3xN matrix where each column represents a 3D point
faer::mat::from_row_major_slice(src_points_slice, 3, src_points.len())
};

// create a mutable view of the destination points
let mut points_in_dst = {
let dst_points_slice = unsafe {
std::slice::from_raw_parts_mut(
dst_points.as_mut_ptr() as *mut f64,
dst_points.len() * 3,
)
};
// SAFETY: dst_points_slice is a 3xN matrix where each column represents a 3D point
faer::mat::from_column_major_slice_mut(dst_points_slice, 3, dst_points.len())
};

// perform the matrix multiplication
faer::linalg::matmul::matmul(
&mut points_in_dst,
dst_r_src_mat,
points_in_src,
None,
1.0,
faer::Parallelism::Rayon(4),
);

// apply translation to each point
for mut col_mut in points_in_dst.col_iter_mut() {
let sum = &dst_t_src_col + col_mut.to_owned();
col_mut.copy_from(&sum);
}
}

fn bench_transform_points3d(c: &mut Criterion) {
let mut group = c.benchmark_group("transform_points3d");

for num_points in [1000, 10000, 100000, 200000, 500000].iter() {
group.throughput(criterion::Throughput::Elements(*num_points as u64));
let parameter_string = format!("{}", num_points);

let src_points = vec![[2.0, 2.0, 2.0]; *num_points];
let rotation = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
let translation = [0.0, 0.0, 0.0];
let mut dst_points = vec![[0.0; 3]; src_points.len()];

group.bench_with_input(
BenchmarkId::new("transform_points3d", &parameter_string),
&(&src_points, &rotation, &translation, &mut dst_points),
|b, i| {
let (src, rot, trans, mut dst) = (i.0, i.1, i.2, i.3.clone());
b.iter(|| {
linalg::transform_points3d(src, rot, trans, &mut dst);
black_box(());
});
},
);

group.bench_with_input(
BenchmarkId::new("transform_points3d_col", &parameter_string),
&(&src_points, &rotation, &translation, &mut dst_points),
|b, i| {
let (src, rot, trans, mut dst) = (i.0, i.1, i.2, i.3.clone());
b.iter(|| {
transform_points3d_col(src, rot, trans, &mut dst);
black_box(());
});
},
);

group.bench_with_input(
BenchmarkId::new("transform_points3d_matmul", &parameter_string),
&(&src_points, &rotation, &translation, &mut dst_points),
|b, i| {
let (src, rot, trans, mut dst) = (i.0, i.1, i.2, i.3.clone());
b.iter(|| {
transform_points3d_matmul(src, rot, trans, &mut dst);
black_box(());
});
},
);
}
}

fn matmul33_dot(a: &[[f64; 3]; 3], b: &[[f64; 3]; 3], m: &mut [[f64; 3]; 3]) {
let row0 = &a[0];
let row1 = &a[1];
let row2 = &a[2];

let col0 = &[b[0][0], b[1][0], b[2][0]];
let col1 = &[b[0][1], b[1][1], b[2][1]];
let col2 = &[b[0][2], b[1][2], b[2][2]];

m[0][0] = linalg::dot_product3(row0, col0);
m[0][1] = linalg::dot_product3(row0, col1);
m[0][2] = linalg::dot_product3(row0, col2);

m[1][0] = linalg::dot_product3(row1, col0);
m[1][1] = linalg::dot_product3(row1, col1);
m[1][2] = linalg::dot_product3(row1, col2);

m[2][0] = linalg::dot_product3(row2, col0);
m[2][1] = linalg::dot_product3(row2, col1);
m[2][2] = linalg::dot_product3(row2, col2);
}

fn bench_matmul33(c: &mut Criterion) {
let mut group = c.benchmark_group("matmul33");

let a_mat = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
let b_mat = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
let mut m_mat = [[0.0; 3]; 3];

group.bench_function(BenchmarkId::new("matmul33", ""), |b| {
b.iter(|| {
linalg::matmul33(&a_mat, &b_mat, &mut m_mat);
black_box(());
});
});

group.bench_function(BenchmarkId::new("matmul33_dot", ""), |b| {
b.iter(|| {
matmul33_dot(&a_mat, &b_mat, &mut m_mat);
black_box(());
});
});
}

criterion_group!(benches, bench_transform_points3d, bench_matmul33);
criterion_main!(benches);
1 change: 0 additions & 1 deletion crates/kornia-3d/src/io/colmap/text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ pub fn read_images_txt(path: impl AsRef<Path>) -> Result<Vec<ColmapImage>, Colma
}

/// Utility functions for parsing COLMAP text files

fn parse_part<T: std::str::FromStr>(s: &str) -> Result<T, ColmapError>
where
T::Err: std::fmt::Display,
Expand Down
1 change: 1 addition & 0 deletions crates/kornia-3d/src/io/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod colmap;
pub mod pcd;
pub mod ply;
3 changes: 3 additions & 0 deletions crates/kornia-3d/src/io/pcd/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/// PCD file parser
mod parser;
pub use parser::*;
93 changes: 93 additions & 0 deletions crates/kornia-3d/src/io/pcd/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use serde::Deserialize;
use std::io::{BufRead, Read};
use std::path::Path;

use crate::pointcloud::PointCloud;

#[derive(Debug, thiserror::Error)]
pub enum PcdError {
#[error("Failed to read PCD file")]
Io(#[from] std::io::Error),

#[error("Failed to deserialize PCD file")]
Deserialize(#[from] bincode::Error),

#[error("Unsupported PCD property")]
UnsupportedProperty,

#[error("Invalid PCD file extension. Got:{0}")]
InvalidFileExtension(String),
}

/// A property of a point in a PCD file.
#[derive(Debug, Deserialize)]
pub struct PcdPropertyXYZRGBNCurvature {
pub x: f32,
pub y: f32,
pub z: f32,
pub rgb: u32,
pub nx: f32,
pub ny: f32,
pub nz: f32,
pub curvature: f32,
}

/// Read a PCD file in binary format.
///
/// Args:
/// path: The path to the PCD file.
///
/// Returns:
/// A `PointCloud` struct containing the points, colors, and normals.
pub fn read_pcd_binary(path: impl AsRef<Path>) -> Result<PointCloud, PcdError> {
let Some(file_ext) = path.as_ref().extension() else {
return Err(PcdError::Io(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"File extension is missing",
)));
};

if file_ext != "pcd" {
return Err(PcdError::InvalidFileExtension(
file_ext.to_string_lossy().to_string(),
));
}

// open the file
let file = std::fs::File::open(path)?;
let mut reader = std::io::BufReader::new(file);

// read the header
// TODO support other formats headers
let mut header = String::new();
loop {
let mut line = String::new();
reader.read_line(&mut line)?;
if line.starts_with("DATA binary") {
break;
}
header.push_str(&line);
}

// create a buffer for the points
let mut buffer = vec![0u8; std::mem::size_of::<PcdPropertyXYZRGBNCurvature>()];

// read the points and store them in a vector
let mut points = Vec::new();
let mut colors = Vec::new();
let mut normals = Vec::new();

while reader.read_exact(&mut buffer).is_ok() {
let property: PcdPropertyXYZRGBNCurvature = bincode::deserialize(&buffer)?;
points.push([property.x as f64, property.y as f64, property.z as f64]);
let rgb = property.rgb as u32;
colors.push([
((rgb >> 16) & 0xFF) as u8,
((rgb >> 8) & 0xFF) as u8,
rgb as u8,
]);
normals.push([property.nx as f64, property.ny as f64, property.nz as f64]);
}

Ok(PointCloud::new(points, Some(colors), Some(normals)))
}
24 changes: 8 additions & 16 deletions crates/kornia-3d/src/io/ply/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::{BufRead, Read};
use std::path::Path;

use super::properties::{OpenSplatProperty, PlyProperty};
use crate::pointcloud::{PointCloud, Vec3};
use crate::pointcloud::PointCloud;

#[derive(Debug, thiserror::Error)]
pub enum PlyError {
Expand Down Expand Up @@ -62,21 +62,13 @@ pub fn read_ply_binary(
match property {
PlyProperty::OpenSplat => {
let property: OpenSplatProperty = bincode::deserialize(&buffer)?;
points.push(Vec3 {
x: property.x,
y: property.y,
z: property.z,
});
colors.push(Vec3 {
x: property.f_dc_0,
y: property.f_dc_1,
z: property.f_dc_2,
});
normals.push(Vec3 {
x: property.nx,
y: property.ny,
z: property.nz,
});
points.push([property.x as f64, property.y as f64, property.z as f64]);
colors.push([
(property.f_dc_0 * 255.0) as u8,
(property.f_dc_1 * 255.0) as u8,
(property.f_dc_2 * 255.0) as u8,
]);
normals.push([property.nx as f64, property.ny as f64, property.nz as f64]);
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/kornia-3d/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
pub mod io;
pub mod linalg;
pub mod ops;
pub mod pointcloud;
pub mod transforms;
pub mod vector;
Loading
Loading