diff --git a/flake.nix b/flake.nix index ddd6bd4..fda9da7 100644 --- a/flake.nix +++ b/flake.nix @@ -22,6 +22,8 @@ default = pkgs.mkShell { buildInputs = with pkgs; [ rustup + pkg-config + openssl ]; }; diff --git a/src/api/mod.rs b/src/api/mod.rs index a5bc6a4..cbd218b 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,6 @@ -use indicatif::{ProgressBar, ProgressStyle}; +use std::{collections::VecDeque, time::Duration}; + +use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle}; use serde::Deserialize; /// The asynchronous version of the API @@ -33,9 +35,9 @@ impl Progress for ProgressBar { self.set_length(size as u64); self.set_style( ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), + "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})", + ).unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default()) + , ); let maxlength = 30; let message = if filename.len() > maxlength { @@ -71,3 +73,48 @@ pub struct RepoInfo { /// The commit sha of the repo. pub sha: String, } + +#[derive(Clone, Default)] +struct MovingAvgRate { + samples: VecDeque<(std::time::Instant, u64)>, +} + +impl ProgressTracker for MovingAvgRate { + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) { + // sample at most every 20ms + if self + .samples + .back() + .map_or(true, |(prev, _)| (now - *prev) > Duration::from_millis(20)) + { + self.samples.push_back((now, state.pos())); + } + + while let Some(first) = self.samples.front() { + if now - first.0 > Duration::from_secs(1) { + self.samples.pop_front(); + } else { + break; + } + } + } + + fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) { + self.samples = Default::default(); + } + + fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) { + match (self.samples.front(), self.samples.back()) { + (Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => { + let elapsed_ms = (*t1 - *t0).as_millis(); + let rate = ((p1 - p0) as f64 * 1000f64 / elapsed_ms as f64) as u64; + write!(w, "{}/s", HumanBytes(rate)).unwrap() + } + _ => write!(w, "-").unwrap(), + } + } +} diff --git a/src/api/tokio.rs b/src/api/tokio.rs index e828c4e..f5a92c1 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -599,20 +599,24 @@ impl ApiRepo { let (start, stop) = chunk?; temporaries.push(Reverse((start, stop))); + let mut modified = false; while let Some(Reverse((min, max))) = temporaries.pop() { if min as u64 == committed { - let mut file = tokio::fs::OpenOptions::new() - .write(true) - .open(&filename) - .await?; - file.seek(SeekFrom::Start(length as u64)).await?; - file.write_all(&committed.to_le_bytes()).await?; committed = max as u64 + 1; + modified = true; } else { temporaries.push(Reverse((min, max))); break; } } + if modified { + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .open(&filename) + .await?; + file.seek(SeekFrom::Start(length as u64)).await?; + file.write_all(&committed.to_le_bytes()).await?; + } } tokio::fs::OpenOptions::new() .write(true) @@ -637,11 +641,6 @@ impl ApiRepo { { // Process each socket concurrently. let range = format!("bytes={start}-{stop}"); - let mut file = tokio::fs::OpenOptions::new() - .write(true) - .open(filename) - .await?; - file.seek(SeekFrom::Start(start as u64)).await?; let response = client .get(url) .header(RANGE, range) @@ -649,11 +648,18 @@ impl ApiRepo { .await? .error_for_status()?; let mut byte_stream = response.bytes_stream(); + let mut buf: Vec = Vec::with_capacity(stop - start); while let Some(next) = byte_stream.next().await { let next = next?; - file.write_all(&next).await?; + buf.extend(&next); progress.update(next.len()).await; } + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .open(filename) + .await?; + file.seek(SeekFrom::Start(start as u64)).await?; + file.write_all(&buf).await?; Ok((start, stop)) }