Skip to content

Commit

Permalink
Improve Video writer workflow and others (#150)
Browse files Browse the repository at this point in the history
* replace tokio by async_std in video writter

* fix examples

* fix tests

* reduce tokio features

* remove async_str

* remove unused code

* skip gstreamer tests for now in ci

* namings image format

* establish image formar and video codec
  • Loading branch information
edgarriba authored Sep 28, 2024
1 parent eeabb12 commit a4b5fd4
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 50 deletions.
7 changes: 6 additions & 1 deletion crates/kornia-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@ kornia-image.workspace = true

# external
image = { version = "0.25" }
log = "0.4"
thiserror = "1"

# optional dependencies
futures = { version = "0.3.1", optional = true }
gst = { version = "0.23.0", package = "gstreamer", optional = true }
gst-app = { version = "0.23.0", package = "gstreamer-app", optional = true }
memmap2 = "0.9.4"
tokio = { version = "1", features = ["full"], optional = true }
tokio = { version = "1", features = [
"sync",
"rt-multi-thread",
"macros",
], optional = true }
turbojpeg = { version = "1.0.0", optional = true }

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion crates/kornia-io/src/stream/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl StreamCapture {
}
_ = signal_rx.changed() => {
self.close()?;
return Err(StreamCaptureError::PipelineCancelled);
break;
}
_ = async { if let Some(ref mut s) = sig { s.as_mut().await } }, if sig.is_some() => {
self.close()?;
Expand Down
8 changes: 8 additions & 0 deletions crates/kornia-io/src/stream/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,12 @@ pub enum StreamCaptureError {
/// An error occurred during GStreamer to send end of stream event.
#[error("Error ocurred in the gstreamer flow")]
GstreamerFlowError(#[from] gst::FlowError),

/// An error occurred during checking the image format.
#[error("Invalid image format: {0}")]
InvalidImageFormat(String),
}

// ensure that can be sent over threads
unsafe impl Send for StreamCaptureError {}
unsafe impl Sync for StreamCaptureError {}
138 changes: 100 additions & 38 deletions crates/kornia-io/src/stream/video.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
use std::path::Path;

use futures::prelude::*;
use gst::prelude::*;

use kornia_image::{Image, ImageSize};

use super::StreamCaptureError;

/// The codec to use for the video writer.
pub enum VideoWriterCodec {
pub enum VideoCodec {
/// H.264 codec.
H264,
}

/// The format of the image to write to the video file.
///
/// Usually will be the combination of the image format and the pixel type.
pub enum ImageFormat {
/// 8-bit RGB format.
Rgb8,
/// 8-bit mono format.
Mono8,
}

/// A struct for writing video files.
pub struct VideoWriter {
pipeline: gst::Pipeline,
appsrc: gst_app::AppSrc,
fps: i32,
format: ImageFormat,
counter: u64,
handle: Option<tokio::task::JoinHandle<()>>,
handle: Option<std::thread::JoinHandle<()>>,
}

impl VideoWriter {
Expand All @@ -29,11 +39,13 @@ impl VideoWriter {
///
/// * `path` - The path to save the video file.
/// * `codec` - The codec to use for the video writer.
/// * `format` - The expected image format.
/// * `fps` - The frames per second of the video.
/// * `size` - The size of the video.
pub fn new(
path: impl AsRef<Path>,
codec: VideoWriterCodec,
codec: VideoCodec,
format: ImageFormat,
fps: i32,
size: ImageSize,
) -> Result<Self, StreamCaptureError> {
Expand All @@ -42,14 +54,20 @@ impl VideoWriter {
// TODO: Add support for other codecs
#[allow(unreachable_patterns)]
let _codec = match codec {
VideoWriterCodec::H264 => "x264enc",
VideoCodec::H264 => "x264enc",
_ => {
return Err(StreamCaptureError::InvalidConfig(
"Unsupported codec".to_string(),
))
}
};

// TODO: Add support for other formats
let format_str = match format {
ImageFormat::Mono8 => "GRAY8",
ImageFormat::Rgb8 => "RGB",
};

let path = path.as_ref().to_owned();

let pipeline_str = format!(
Expand All @@ -76,7 +94,7 @@ impl VideoWriter {
appsrc.set_format(gst::Format::Time);

let caps = gst::Caps::builder("video/x-raw")
.field("format", "RGB")
.field("format", format_str)
.field("width", size.width as i32)
.field("height", size.height as i32)
.field("framerate", gst::Fraction::new(fps, 1))
Expand All @@ -91,32 +109,37 @@ impl VideoWriter {
pipeline,
appsrc,
fps,
format,
counter: 0,
handle: None,
})
}

/// Start the video writer
/// Start the video writer.
///
/// Set the pipeline to playing and launch a task to handle the bus messages.
pub fn start(&mut self) -> Result<(), StreamCaptureError> {
// set the pipeline to playing
self.pipeline.set_state(gst::State::Playing)?;

let bus = self.pipeline.bus().ok_or(StreamCaptureError::BusError)?;
let mut messages = bus.stream();

let handle = tokio::spawn(async move {
while let Some(msg) = messages.next().await {
// launch a task to handle the bus messages, exit when EOS is received and set the pipeline to null
let handle = std::thread::spawn(move || {
for msg in bus.iter_timed(gst::ClockTime::NONE) {
match msg.view() {
gst::MessageView::Eos(..) => {
println!("EOS");
log::debug!("gstreamer received EOS");
break;
}
gst::MessageView::Error(err) => {
eprintln!(
log::error!(
"Error from {:?}: {} ({:?})",
msg.src().map(|s| s.path_string()),
err.error(),
err.debug()
);
break;
}
_ => {}
}
Expand All @@ -128,38 +151,49 @@ impl VideoWriter {
Ok(())
}

/// Stop the video writer
/// Stop the video writer.
///
/// Set the pipeline to null and join the thread.
///
pub fn stop(&mut self) -> Result<(), StreamCaptureError> {
// Send end of stream to the appsrc
self.appsrc
.end_of_stream()
.map_err(StreamCaptureError::GstreamerFlowError)?;
// send end of stream to the appsrc
self.appsrc.end_of_stream()?;

// Take the handle and await it
// TODO: This is a blocking call, we need to make it non-blocking
if let Some(handle) = self.handle.take() {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
if let Err(e) = handle.await {
eprintln!("Error waiting for handle: {:?}", e);
}
});
});
handle.join().expect("Failed to join thread");
}

// Set the pipeline to null
self.pipeline.set_state(gst::State::Null)?;

Ok(())
}

/// Write an image to the video file.
///
/// # Arguments
///
/// * `img` - The image to write to the video file.
// TODO: support write_async
pub fn write(&mut self, img: &Image<u8, 3>) -> Result<(), StreamCaptureError> {
// TODO: explore supporting write_async
pub fn write<const C: usize>(&mut self, img: &Image<u8, C>) -> Result<(), StreamCaptureError> {
// check if the image channels are correct
match self.format {
ImageFormat::Mono8 => {
if C != 1 {
return Err(StreamCaptureError::InvalidImageFormat(format!(
"Invalid number of channels: expected 1, got {}",
C
)));
}
}
ImageFormat::Rgb8 => {
if C != 3 {
return Err(StreamCaptureError::InvalidImageFormat(format!(
"Invalid number of channels: expected 3, got {}",
C
)));
}
}
}

// TODO: verify is there is a cheaper way to copy the buffer
let mut buffer = gst::Buffer::from_mut_slice(img.as_slice().to_vec());

Expand All @@ -182,20 +216,46 @@ impl VideoWriter {

impl Drop for VideoWriter {
fn drop(&mut self) {
self.stop().unwrap_or_else(|e| {
eprintln!("Error stopping video writer: {:?}", e);
});
if self.handle.is_some() {
self.stop().expect("Failed to stop video writer");
}
}
}

#[cfg(test)]
mod tests {
use super::{VideoWriter, VideoWriterCodec};
use super::{ImageFormat, VideoCodec, VideoWriter};
use kornia_image::{Image, ImageSize};

#[ignore = "need gstreamer in CI"]
#[test]
fn video_writer_rgb8u() -> Result<(), Box<dyn std::error::Error>> {
let tmp_dir = tempfile::tempdir()?;
std::fs::create_dir_all(tmp_dir.path())?;

let file_path = tmp_dir.path().join("test.mp4");

let size = ImageSize {
width: 6,
height: 4,
};

let mut writer =
VideoWriter::new(&file_path, VideoCodec::H264, ImageFormat::Rgb8, 30, size)?;
writer.start()?;

let img = Image::<u8, 3>::new(size, vec![0; size.width * size.height * 3])?;
writer.write(&img)?;
writer.stop()?;

assert!(file_path.exists(), "File does not exist: {:?}", file_path);

Ok(())
}

#[ignore = "need gstreamer in CI"]
#[test]
#[ignore = "TODO: fix this test as there's a race condition in the gstreamer flow"]
fn video_writer() -> Result<(), Box<dyn std::error::Error>> {
fn video_writer_mono8u() -> Result<(), Box<dyn std::error::Error>> {
let tmp_dir = tempfile::tempdir()?;
std::fs::create_dir_all(tmp_dir.path())?;

Expand All @@ -205,10 +265,12 @@ mod tests {
width: 6,
height: 4,
};
let mut writer = VideoWriter::new(&file_path, VideoWriterCodec::H264, 30, size)?;

let mut writer =
VideoWriter::new(&file_path, VideoCodec::H264, ImageFormat::Mono8, 30, size)?;
writer.start()?;

let img = Image::new(size, vec![0; size.width * size.height * 3])?;
let img = Image::<u8, 1>::new(size, vec![0; size.width * size.height])?;
writer.write(&img)?;
writer.stop()?;

Expand Down
13 changes: 11 additions & 2 deletions examples/video_write tasks/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use tokio::sync::Mutex;

use kornia::{
image::{Image, ImageSize},
io::stream::{video::VideoWriterCodec, V4L2CameraConfig, VideoWriter},
io::stream::{
video::{ImageFormat, VideoCodec},
V4L2CameraConfig, VideoWriter,
},
};

#[derive(Parser)]
Expand Down Expand Up @@ -50,7 +53,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.build()?;

// start the video writer
let video_writer = VideoWriter::new(args.output, VideoWriterCodec::H264, args.fps, frame_size)?;
let video_writer = VideoWriter::new(
args.output,
VideoCodec::H264,
ImageFormat::Rgb8,
args.fps,
frame_size,
)?;
let video_writer = Arc::new(Mutex::new(video_writer));
video_writer.lock().await.start()?;

Expand Down
1 change: 1 addition & 0 deletions examples/video_write/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ publish = false
[dependencies]
clap = { version = "4.5.4", features = ["derive"] }
ctrlc = "3.4.4"
env_logger = "0.11.5"
kornia = { workspace = true, features = ["gstreamer"] }
rerun = "0.18"
tokio = { version = "1" }
23 changes: 16 additions & 7 deletions examples/video_write/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use tokio::sync::Mutex;

use kornia::{
image::ImageSize,
io::stream::{video::VideoWriterCodec, V4L2CameraConfig, VideoWriter},
io::stream::{
video::{ImageFormat, VideoCodec},
V4L2CameraConfig, VideoWriter,
},
};

#[derive(Parser)]
Expand All @@ -22,6 +25,9 @@ struct Args {

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// setup logging
let _ = env_logger::builder().is_test(true).try_init();

let args = Args::parse();

// Ensure the output path ends with .mp4
Expand All @@ -47,7 +53,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.build()?;

// start the video writer
let video_writer = VideoWriter::new(args.output, VideoWriterCodec::H264, args.fps, frame_size)?;
let video_writer = VideoWriter::new(
args.output,
VideoCodec::H264,
ImageFormat::Rgb8,
args.fps,
frame_size,
)?;
let video_writer = Arc::new(Mutex::new(video_writer));
video_writer.lock().await.start()?;

Expand Down Expand Up @@ -80,11 +92,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.await?;

video_writer
.lock()
.await
.stop()
.expect("Failed to stop video writer");
// stop the video writer
video_writer.lock().await.stop()?;

Ok(())
}
2 changes: 1 addition & 1 deletion examples/webcam/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ clap = { version = "4.5.4", features = ["derive"] }
ctrlc = "3.4.4"
kornia = { workspace = true, features = ["gstreamer"] }
rerun = "0.18"
tokio = { version = "1" }
tokio = { version = "1", features = ["full"] }

0 comments on commit a4b5fd4

Please sign in to comment.