diff --git a/crates/kornia-io/Cargo.toml b/crates/kornia-io/Cargo.toml index d904ee1b..37f3246d 100644 --- a/crates/kornia-io/Cargo.toml +++ b/crates/kornia-io/Cargo.toml @@ -20,6 +20,7 @@ kornia-image.workspace = true # external image = { version = "0.25" } +log = "0.4" thiserror = "1" # optional dependencies @@ -27,7 +28,11 @@ futures = { version = "0.3.1", optional = true } gst = { version = "0.23.0", package = "gstreamer", optional = true } gst-app = { version = "0.23.0", package = "gstreamer-app", optional = true } memmap2 = "0.9.4" -tokio = { version = "1", features = ["full"], optional = true } +tokio = { version = "1", features = [ + "sync", + "rt-multi-thread", + "macros", +], optional = true } turbojpeg = { version = "1.0.0", optional = true } [dev-dependencies] diff --git a/crates/kornia-io/src/stream/capture.rs b/crates/kornia-io/src/stream/capture.rs index 22219b18..d8f477eb 100644 --- a/crates/kornia-io/src/stream/capture.rs +++ b/crates/kornia-io/src/stream/capture.rs @@ -122,7 +122,7 @@ impl StreamCapture { } _ = signal_rx.changed() => { self.close()?; - return Err(StreamCaptureError::PipelineCancelled); + break; } _ = async { if let Some(ref mut s) = sig { s.as_mut().await } }, if sig.is_some() => { self.close()?; diff --git a/crates/kornia-io/src/stream/error.rs b/crates/kornia-io/src/stream/error.rs index 3333ebd0..d28e5985 100644 --- a/crates/kornia-io/src/stream/error.rs +++ b/crates/kornia-io/src/stream/error.rs @@ -71,4 +71,12 @@ pub enum StreamCaptureError { /// An error occurred during GStreamer to send end of stream event. #[error("Error ocurred in the gstreamer flow")] GstreamerFlowError(#[from] gst::FlowError), + + /// An error occurred during checking the image format. + #[error("Invalid image format: {0}")] + InvalidImageFormat(String), } + +// ensure that can be sent over threads +unsafe impl Send for StreamCaptureError {} +unsafe impl Sync for StreamCaptureError {} diff --git a/crates/kornia-io/src/stream/video.rs b/crates/kornia-io/src/stream/video.rs index 116bd232..f886339a 100644 --- a/crates/kornia-io/src/stream/video.rs +++ b/crates/kornia-io/src/stream/video.rs @@ -1,6 +1,5 @@ use std::path::Path; -use futures::prelude::*; use gst::prelude::*; use kornia_image::{Image, ImageSize}; @@ -8,18 +7,29 @@ use kornia_image::{Image, ImageSize}; use super::StreamCaptureError; /// The codec to use for the video writer. -pub enum VideoWriterCodec { +pub enum VideoCodec { /// H.264 codec. H264, } +/// The format of the image to write to the video file. +/// +/// Usually will be the combination of the image format and the pixel type. +pub enum ImageFormat { + /// 8-bit RGB format. + Rgb8, + /// 8-bit mono format. + Mono8, +} + /// A struct for writing video files. pub struct VideoWriter { pipeline: gst::Pipeline, appsrc: gst_app::AppSrc, fps: i32, + format: ImageFormat, counter: u64, - handle: Option>, + handle: Option>, } impl VideoWriter { @@ -29,11 +39,13 @@ impl VideoWriter { /// /// * `path` - The path to save the video file. /// * `codec` - The codec to use for the video writer. + /// * `format` - The expected image format. /// * `fps` - The frames per second of the video. /// * `size` - The size of the video. pub fn new( path: impl AsRef, - codec: VideoWriterCodec, + codec: VideoCodec, + format: ImageFormat, fps: i32, size: ImageSize, ) -> Result { @@ -42,7 +54,7 @@ impl VideoWriter { // TODO: Add support for other codecs #[allow(unreachable_patterns)] let _codec = match codec { - VideoWriterCodec::H264 => "x264enc", + VideoCodec::H264 => "x264enc", _ => { return Err(StreamCaptureError::InvalidConfig( "Unsupported codec".to_string(), @@ -50,6 +62,12 @@ impl VideoWriter { } }; + // TODO: Add support for other formats + let format_str = match format { + ImageFormat::Mono8 => "GRAY8", + ImageFormat::Rgb8 => "RGB", + }; + let path = path.as_ref().to_owned(); let pipeline_str = format!( @@ -76,7 +94,7 @@ impl VideoWriter { appsrc.set_format(gst::Format::Time); let caps = gst::Caps::builder("video/x-raw") - .field("format", "RGB") + .field("format", format_str) .field("width", size.width as i32) .field("height", size.height as i32) .field("framerate", gst::Fraction::new(fps, 1)) @@ -91,32 +109,37 @@ impl VideoWriter { pipeline, appsrc, fps, + format, counter: 0, handle: None, }) } - /// Start the video writer + /// Start the video writer. + /// + /// Set the pipeline to playing and launch a task to handle the bus messages. pub fn start(&mut self) -> Result<(), StreamCaptureError> { + // set the pipeline to playing self.pipeline.set_state(gst::State::Playing)?; let bus = self.pipeline.bus().ok_or(StreamCaptureError::BusError)?; - let mut messages = bus.stream(); - let handle = tokio::spawn(async move { - while let Some(msg) = messages.next().await { + // launch a task to handle the bus messages, exit when EOS is received and set the pipeline to null + let handle = std::thread::spawn(move || { + for msg in bus.iter_timed(gst::ClockTime::NONE) { match msg.view() { gst::MessageView::Eos(..) => { - println!("EOS"); + log::debug!("gstreamer received EOS"); break; } gst::MessageView::Error(err) => { - eprintln!( + log::error!( "Error from {:?}: {} ({:?})", msg.src().map(|s| s.path_string()), err.error(), err.debug() ); + break; } _ => {} } @@ -128,38 +151,49 @@ impl VideoWriter { Ok(()) } - /// Stop the video writer + /// Stop the video writer. + /// + /// Set the pipeline to null and join the thread. + /// pub fn stop(&mut self) -> Result<(), StreamCaptureError> { - // Send end of stream to the appsrc - self.appsrc - .end_of_stream() - .map_err(StreamCaptureError::GstreamerFlowError)?; + // send end of stream to the appsrc + self.appsrc.end_of_stream()?; - // Take the handle and await it - // TODO: This is a blocking call, we need to make it non-blocking if let Some(handle) = self.handle.take() { - tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(async { - if let Err(e) = handle.await { - eprintln!("Error waiting for handle: {:?}", e); - } - }); - }); + handle.join().expect("Failed to join thread"); } - // Set the pipeline to null self.pipeline.set_state(gst::State::Null)?; Ok(()) } - /// Write an image to the video file. /// /// # Arguments /// /// * `img` - The image to write to the video file. - // TODO: support write_async - pub fn write(&mut self, img: &Image) -> Result<(), StreamCaptureError> { + // TODO: explore supporting write_async + pub fn write(&mut self, img: &Image) -> Result<(), StreamCaptureError> { + // check if the image channels are correct + match self.format { + ImageFormat::Mono8 => { + if C != 1 { + return Err(StreamCaptureError::InvalidImageFormat(format!( + "Invalid number of channels: expected 1, got {}", + C + ))); + } + } + ImageFormat::Rgb8 => { + if C != 3 { + return Err(StreamCaptureError::InvalidImageFormat(format!( + "Invalid number of channels: expected 3, got {}", + C + ))); + } + } + } + // TODO: verify is there is a cheaper way to copy the buffer let mut buffer = gst::Buffer::from_mut_slice(img.as_slice().to_vec()); @@ -182,20 +216,46 @@ impl VideoWriter { impl Drop for VideoWriter { fn drop(&mut self) { - self.stop().unwrap_or_else(|e| { - eprintln!("Error stopping video writer: {:?}", e); - }); + if self.handle.is_some() { + self.stop().expect("Failed to stop video writer"); + } } } #[cfg(test)] mod tests { - use super::{VideoWriter, VideoWriterCodec}; + use super::{ImageFormat, VideoCodec, VideoWriter}; use kornia_image::{Image, ImageSize}; + #[ignore = "need gstreamer in CI"] + #[test] + fn video_writer_rgb8u() -> Result<(), Box> { + let tmp_dir = tempfile::tempdir()?; + std::fs::create_dir_all(tmp_dir.path())?; + + let file_path = tmp_dir.path().join("test.mp4"); + + let size = ImageSize { + width: 6, + height: 4, + }; + + let mut writer = + VideoWriter::new(&file_path, VideoCodec::H264, ImageFormat::Rgb8, 30, size)?; + writer.start()?; + + let img = Image::::new(size, vec![0; size.width * size.height * 3])?; + writer.write(&img)?; + writer.stop()?; + + assert!(file_path.exists(), "File does not exist: {:?}", file_path); + + Ok(()) + } + + #[ignore = "need gstreamer in CI"] #[test] - #[ignore = "TODO: fix this test as there's a race condition in the gstreamer flow"] - fn video_writer() -> Result<(), Box> { + fn video_writer_mono8u() -> Result<(), Box> { let tmp_dir = tempfile::tempdir()?; std::fs::create_dir_all(tmp_dir.path())?; @@ -205,10 +265,12 @@ mod tests { width: 6, height: 4, }; - let mut writer = VideoWriter::new(&file_path, VideoWriterCodec::H264, 30, size)?; + + let mut writer = + VideoWriter::new(&file_path, VideoCodec::H264, ImageFormat::Mono8, 30, size)?; writer.start()?; - let img = Image::new(size, vec![0; size.width * size.height * 3])?; + let img = Image::::new(size, vec![0; size.width * size.height])?; writer.write(&img)?; writer.stop()?; diff --git a/examples/video_write tasks/src/main.rs b/examples/video_write tasks/src/main.rs index 02ef76ac..43d13a91 100644 --- a/examples/video_write tasks/src/main.rs +++ b/examples/video_write tasks/src/main.rs @@ -5,7 +5,10 @@ use tokio::sync::Mutex; use kornia::{ image::{Image, ImageSize}, - io::stream::{video::VideoWriterCodec, V4L2CameraConfig, VideoWriter}, + io::stream::{ + video::{ImageFormat, VideoCodec}, + V4L2CameraConfig, VideoWriter, + }, }; #[derive(Parser)] @@ -50,7 +53,13 @@ async fn main() -> Result<(), Box> { .build()?; // start the video writer - let video_writer = VideoWriter::new(args.output, VideoWriterCodec::H264, args.fps, frame_size)?; + let video_writer = VideoWriter::new( + args.output, + VideoCodec::H264, + ImageFormat::Rgb8, + args.fps, + frame_size, + )?; let video_writer = Arc::new(Mutex::new(video_writer)); video_writer.lock().await.start()?; diff --git a/examples/video_write/Cargo.toml b/examples/video_write/Cargo.toml index 3f8d5d00..56dd3a9b 100644 --- a/examples/video_write/Cargo.toml +++ b/examples/video_write/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] clap = { version = "4.5.4", features = ["derive"] } ctrlc = "3.4.4" +env_logger = "0.11.5" kornia = { workspace = true, features = ["gstreamer"] } rerun = "0.18" tokio = { version = "1" } diff --git a/examples/video_write/src/main.rs b/examples/video_write/src/main.rs index a7f5db20..1f605141 100644 --- a/examples/video_write/src/main.rs +++ b/examples/video_write/src/main.rs @@ -5,7 +5,10 @@ use tokio::sync::Mutex; use kornia::{ image::ImageSize, - io::stream::{video::VideoWriterCodec, V4L2CameraConfig, VideoWriter}, + io::stream::{ + video::{ImageFormat, VideoCodec}, + V4L2CameraConfig, VideoWriter, + }, }; #[derive(Parser)] @@ -22,6 +25,9 @@ struct Args { #[tokio::main] async fn main() -> Result<(), Box> { + // setup logging + let _ = env_logger::builder().is_test(true).try_init(); + let args = Args::parse(); // Ensure the output path ends with .mp4 @@ -47,7 +53,13 @@ async fn main() -> Result<(), Box> { .build()?; // start the video writer - let video_writer = VideoWriter::new(args.output, VideoWriterCodec::H264, args.fps, frame_size)?; + let video_writer = VideoWriter::new( + args.output, + VideoCodec::H264, + ImageFormat::Rgb8, + args.fps, + frame_size, + )?; let video_writer = Arc::new(Mutex::new(video_writer)); video_writer.lock().await.start()?; @@ -80,11 +92,8 @@ async fn main() -> Result<(), Box> { ) .await?; - video_writer - .lock() - .await - .stop() - .expect("Failed to stop video writer"); + // stop the video writer + video_writer.lock().await.stop()?; Ok(()) } diff --git a/examples/webcam/Cargo.toml b/examples/webcam/Cargo.toml index 443505b1..3b34307e 100644 --- a/examples/webcam/Cargo.toml +++ b/examples/webcam/Cargo.toml @@ -11,4 +11,4 @@ clap = { version = "4.5.4", features = ["derive"] } ctrlc = "3.4.4" kornia = { workspace = true, features = ["gstreamer"] } rerun = "0.18" -tokio = { version = "1" } +tokio = { version = "1", features = ["full"] }