Skip to content

Commit

Permalink
perf: micro-optimize sample weighted method
Browse files Browse the repository at this point in the history
- amortize allocations
- cache selected.len()
- also improve weighted method description
  • Loading branch information
jqnatividad committed Feb 10, 2025
1 parent ba63982 commit 7f5903d
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions src/cmd/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ It supports seven sampling methods:
https://en.wikipedia.org/wiki/Stratified_sampling
- WEIGHTED: the sampling method when the --weighted option is specified.
Samples records with probability proportional to weights in the specified weight column.
If the weight column contains a value that is not a number for a record, the record will be
skipped. The weights are automatically scaled based on the maximum weight in the sample.
The number of records to sample is specified by the <sample-size> argument.
Useful when some records are more important than others.
Uses MEMORY PROPORTIONAL to the sample size (k) - O(k).
Samples records with probabilities proportional to values in a specified weight column.
Records with higher weights are more likely to be selected. For example, if you have
sales data and want to sample transactions weighted by revenue, high-value transactions
will have a higher chance of being included. Non-numeric weights are treated as zero.
The weights are automatically normalized using the maximum weight in the dataset.
Specify the desired sample size with <sample-size>. Uses MEMORY PROPORTIONAL to the
sample size (k) - O(k).
"Weighted random sampling with a reservoir" https://doi.org/10.1016/j.ipl.2005.11.003
- CLUSTER: the sampling method when the --cluster option is specified.
Expand Down Expand Up @@ -852,11 +853,12 @@ fn sample_weighted<R: io::Read, W: io::Write>(
) -> CliResult<()> {
// First pass: find maximum weight
let mut max_weight = 0.0f64;
let mut curr_record;
for record in rdr.byte_records() {
let record = record?;
curr_record = record?;

let weight: f64 = fast_float2::parse(
record
curr_record
.get(weight_column)
.ok_or_else(|| format!("Weight column index {weight_column} out of bounds"))?,
)
Expand Down Expand Up @@ -931,17 +933,19 @@ fn do_weighted_sampling<T: Rng + ?Sized>(
let mut selected = HashSet::with_capacity(sample_size);
let mut attempts = 0;
let max_attempts = sample_size * 100; // Prevent infinite loops
let mut curr_record;
let mut selected_len = 0;

while selected.len() < sample_size && attempts < max_attempts {
while selected_len < sample_size && attempts < max_attempts {
for (i, record) in records.enumerate() {
if selected.len() >= sample_size {
if selected_len >= sample_size {
break;
}

let record = record?;
curr_record = record?;

let weight: f64 = fast_float2::parse(
record
curr_record
.get(weight_column)
.ok_or_else(|| format!("Weight column index {weight_column} out of bounds"))?,
)
Expand All @@ -960,7 +964,8 @@ fn do_weighted_sampling<T: Rng + ?Sized>(

if include_flag && !selected.contains(&i) {
selected.insert(i);
wtr.write_byte_record(&record)?;
selected_len += 1;
wtr.write_byte_record(&curr_record)?;
}

attempts += 1;
Expand All @@ -970,11 +975,9 @@ fn do_weighted_sampling<T: Rng + ?Sized>(
}
}

if selected.len() < sample_size {
log::warn!(
"Could only sample {} records out of requested {}",
selected.len(),
sample_size
if selected_len < sample_size {
wwarn!(
"Could only sample {selected_len} records out of requested {sample_size}"
);
}

Expand Down

0 comments on commit 7f5903d

Please sign in to comment.