Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sample: replace --faster RNG sampling option with --rng <kind> option #1532

Merged
merged 3 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ qsv-sniffer = { version = "0.10", default-features = false, features = [
"runtime-dispatch-simd",
] }
rand = "0.8"
rand_hc = "0.3"
rand_xoshiro = "0.6"
rayon = "1.8"
redis = { version = "0.24", features = [
"ahash",
Expand Down
211 changes: 150 additions & 61 deletions src/cmd/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,16 @@ sample arguments:

sample options:
--seed <number> Random Number Generator (RNG) seed.
--faster Use a faster RNG that uses the Wyrand algorithm instead
of the ChaCha algorithm used by the standard RNG.
--rng <kind> The RNG algorithm to use.
Three RNGs are supported:
- standard: Use the standard RNG.
1.5 GB/s throughput.
- faster: Use faster RNG using the Xoshiro256Plus algorithm.
8 GB/s throughput.
- cryptosecure: Use cryptographically secure HC128 algorithm.
Recommended by eSTREAM (https://www.ecrypt.eu.org/stream/).
2.1 GB/s throughput though slow initialization.
[default: standard]
--user-agent <agent> Specify custom user agent to use when the input is a URL.
It supports the following variables -
$QSV_VERSION, $QSV_TARGET, $QSV_BIN_NAME, $QSV_KIND and $QSV_COMMAND.
Expand All @@ -55,11 +63,13 @@ Common options:
Must be a single character. (default: ,)
"#;

use std::io;
use std::{io, str::FromStr};

use fastrand; //DevSkim: ignore DS148264
use rand::{self, rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use rand_hc::Hc128Rng;
use rand_xoshiro::Xoshiro256Plus;
use serde::Deserialize;
use strum_macros::EnumString;
use tempfile::NamedTempFile;
use url::Url;

Expand All @@ -75,19 +85,34 @@ struct Args {
flag_output: Option<String>,
flag_no_headers: bool,
flag_delimiter: Option<Delimiter>,
flag_seed: Option<usize>,
flag_faster: bool,
flag_seed: Option<u64>,
flag_rng: String,
flag_user_agent: Option<String>,
flag_timeout: Option<u16>,
}

#[derive(Debug, EnumString, PartialEq)]
#[strum(ascii_case_insensitive)]
enum RngKind {
Standard,
Faster,
Cryptosecure,
}

pub fn run(argv: &[&str]) -> CliResult<()> {
let mut args: Args = util::get_args(USAGE, argv)?;

if args.arg_sample_size.is_sign_negative() {
return fail_incorrectusage_clierror!("Sample size cannot be negative.");
}

let Ok(rng_kind) = RngKind::from_str(&args.flag_rng) else {
return fail_incorrectusage_clierror!(
"Invalid RNG algorithm `{}`. Supported RNGs are: standard, faster, cryptosecure.",
args.flag_rng
);
};

let temp_download = NamedTempFile::new()?;

args.arg_input = match args.arg_input {
Expand Down Expand Up @@ -135,25 +160,50 @@ pub fn run(argv: &[&str]) -> CliResult<()> {

let mut all_indices = (0..idx.count()).collect::<Vec<_>>();

if args.flag_faster {
log::info!(
"doing --faster sample_random_access. Seed: {:?}",
args.flag_seed
);
if let Some(seed) = args.flag_seed {
fastrand::seed(seed as u64); //DevSkim: ignore DS148264
}
all_indices = fastrand::choose_multiple(all_indices.into_iter(), sample_size as usize); //DevSkim: ignore DS148264
} else {
log::info!(
"doing standard sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng: StdRng = match args.flag_seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => StdRng::seed_from_u64(seed as u64), //DevSkim: ignore DS148264
};
SliceRandom::shuffle(&mut *all_indices, &mut rng); //DevSkim: ignore DS148264
match rng_kind {
RngKind::Standard => {
log::info!(
"doing standard sample_random_access. Seed: {:?}",
args.flag_seed
);
let mut rng: StdRng = match args.flag_seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => StdRng::seed_from_u64(seed), //DevSkim: ignore DS148264
};
SliceRandom::shuffle(&mut *all_indices, &mut rng); //DevSkim: ignore DS148264
},
RngKind::Faster => {
log::info!(
"doing --faster sample_random_access. Seed: {:?}",
args.flag_seed
);

let mut rng = match args.flag_seed {
None => Xoshiro256Plus::from_rng(rand::thread_rng()).unwrap(),
Some(seed) => Xoshiro256Plus::seed_from_u64(seed), //DevSkim: ignore DS148264
};
SliceRandom::shuffle(&mut *all_indices, &mut rng); //DevSkim: ignore DS148264
},
RngKind::Cryptosecure => {
log::info!(
"doing cryptosecure sample_random_access. Seed: {:?}",
args.flag_seed
);
let seed_32 = match args.flag_seed {
None => rand::thread_rng().gen::<[u8; 32]>(),
Some(seed) => {
let seed_u8 = seed.to_le_bytes();
let mut seed_32 = [0u8; 32];
seed_32[..8].copy_from_slice(&seed_u8);
seed_32
},
};
let mut rng: Hc128Rng = match args.flag_seed {
None => Hc128Rng::from_rng(rand::thread_rng()).unwrap(),
Some(_) => Hc128Rng::from_seed(seed_32),
};
SliceRandom::shuffle(&mut *all_indices, &mut rng);
},
}

for i in all_indices.into_iter().take(sample_size as usize) {
Expand All @@ -171,12 +221,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
}
let mut rdr = rconfig.reader()?;
rconfig.write_headers(&mut rdr, &mut wtr)?;
let sampled = sample_reservoir(
&mut rdr,
sample_size as u64,
args.flag_seed,
args.flag_faster,
)?;
let sampled = sample_reservoir(&mut rdr, sample_size as u64, args.flag_seed, &rng_kind)?;
for row in sampled {
wtr.write_byte_record(&row)?;
}
Expand All @@ -188,8 +233,8 @@ pub fn run(argv: &[&str]) -> CliResult<()> {
fn sample_reservoir<R: io::Read>(
rdr: &mut csv::Reader<R>,
sample_size: u64,
seed: Option<usize>,
faster: bool,
seed: Option<u64>,
rng_kind: &RngKind,
) -> CliResult<Vec<csv::ByteRecord>> {
// The following algorithm has been adapted from:
// https://en.wikipedia.org/wiki/Reservoir_sampling
Expand All @@ -199,37 +244,81 @@ fn sample_reservoir<R: io::Read>(
reservoir.push(row?);
}

if faster {
log::info!("doing --faster sample_reservoir. Seed: {seed:?}");
if let Some(seed) = seed {
fastrand::seed(seed as u64); //DevSkim: ignore DS148264
}
match *rng_kind {
RngKind::Standard => {
log::info!("doing standard sample_random_access. Seed: {seed:?}",);
let mut rng: StdRng = match seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
// the non-cryptographic seed_from_u64 is sufficient for our use case
// as we're optimizing for performance
Some(seed) => StdRng::seed_from_u64(seed), //DevSkim: ignore DS148264
};

let mut random: usize;
for (i, row) in records {
random = fastrand::usize(0..=i); //DevSkim: ignore DS148264
if random < sample_size as usize {
reservoir[random] = row?;
let mut random: usize;
// Now do the sampling.
for (i, row) in records {
random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
}
}
}
} else {
log::info!("doing standard sample_reservoir. Seed: {seed:?}");
// Seeding RNG
let mut rng: StdRng = match seed {
None => StdRng::from_rng(rand::thread_rng()).unwrap(),
// the non-cryptographic seed_from_u64 is sufficient for our use case
// as we're optimizing for performance
Some(seed) => StdRng::seed_from_u64(seed as u64), //DevSkim: ignore DS148264
};

let mut random: usize;
// Now do the sampling.
for (i, row) in records {
random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
},
RngKind::Faster => {
log::info!("doing --faster sample_random_access. Seed: {seed:?}",);

let mut rng = match seed {
None => Xoshiro256Plus::from_rng(rand::thread_rng()).unwrap(),
// the non-cryptographic seed_from_u64 is sufficient for our use case
// as we're optimizing for performance
Some(seed) => Xoshiro256Plus::seed_from_u64(seed), //DevSkim: ignore DS148264
};

let mut random: usize;
// Now do the sampling.
for (i, row) in records {
random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
}
}
}

// if let Some(seed) = seed {
// fastrand::seed(seed); //DevSkim: ignore DS148264
// }

// let mut random: usize;
// for (i, row) in records {
// random = fastrand::usize(0..=i); //DevSkim: ignore DS148264
// if random < sample_size as usize {
// reservoir[random] = row?;
// }
// }
},
RngKind::Cryptosecure => {
log::info!("doing cryptosecure sample_random_access. Seed: {seed:?}",);

let seed_32 = match seed {
None => rand::thread_rng().gen::<[u8; 32]>(),
Some(seed) => {
let seed_u8 = seed.to_le_bytes();
let mut seed_32 = [0u8; 32];
seed_32[..8].copy_from_slice(&seed_u8);
seed_32
},
};
let mut rng: Hc128Rng = match seed {
None => Hc128Rng::from_rng(rand::thread_rng()).unwrap(),
Some(_) => Hc128Rng::from_seed(seed_32),
};

for (i, row) in records {
let random = rng.gen_range(0..=i);
if random < sample_size as usize {
reservoir[random] = row?;
}
}
},
}

Ok(reservoir)
}
Loading
Loading