-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #65 from outbrain/approx-test
Framework for data regression tests
- Loading branch information
Showing
4 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
# a suite of a bit longer (regression) tests | ||
|
||
By running `data_regression_experiment.sh`, you can conduct a stand-alone experiment that demonstrates the rankings' capability of approximating the scores obtained by using the full data set. | ||
|
||
![comparison](./comparison.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from __future__ import annotations | ||
|
||
import glob | ||
import os | ||
import sys | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
def extract_just_ranking(dfile):
    """Return the ranking column (second tab-separated field) of an output file.

    The first line is treated as a header and skipped. Blank or single-field
    lines (e.g. a trailing newline at end of file) are ignored instead of
    raising IndexError, which the previous version did.

    Args:
        dfile: Path to a TSV ranking file (e.g. ``feature_singles.tsv``).

    Returns:
        List of ranking values as strings, in file order.
    """
    ranks = []
    with open(dfile) as df:
        next(df, None)  # skip header; tolerate an empty file
        for line in df:
            parts = line.strip().split('\t')
            if len(parts) > 1:  # guard against blank/malformed lines
                ranks.append(parts[1])
    return ranks
|
||
def calculate_mismatch_scores(all_folders, mismatches, data_folder=None):
    """Compute hits@k overlap scores between each ranking and the reference.

    The ranking computed from the most rows acts as the "full data" reference;
    every other ranking is scored by how many of its top-``mismatches``
    features also appear in the reference's top-``mismatches``.

    Args:
        all_folders: Folder paths; those containing ``ranking`` are parsed as
            ``ranking_<row_count>`` output folders.
        mismatches: The k in hits@k — how many top features to compare.
        data_folder: Root directory holding the ``ranking_<count>`` folders.
            Defaults to the module-level ``dfolder`` set by the CLI entry
            point (kept for backward compatibility with existing callers).

    Returns:
        Dict mapping row count -> overlap percentage (0-100), sorted by
        row count ascending.

    Raises:
        ValueError: If no ``ranking`` folders are present (``max`` of empty).
    """
    # Fall back to the module global so the original call sites keep working.
    root = dfolder if data_folder is None else data_folder

    all_counts = [int(folder.split('_').pop()) for folder in all_folders if 'ranking' in folder]

    ranking_out_struct = {}
    for count in all_counts:
        rpath = os.path.join(root, f'ranking_{count}', 'feature_singles.tsv')
        ranking_out_struct[count] = extract_just_ranking(rpath)

    # The run with the most rows serves as the reference ranking.
    pivot_score_key = max(all_counts)
    # Prebuild the top-k membership set: O(1) lookups instead of re-slicing
    # and linearly scanning the reference list for every element.
    reference_top_k = set(ranking_out_struct[pivot_score_key][:mismatches])

    out_results = {}
    for ranking_id, ranking in ranking_out_struct.items():
        mismatches_counter = sum(
            1 for el in ranking[:mismatches] if el not in reference_top_k
        )
        out_results[ranking_id] = 100 * (1 - mismatches_counter / mismatches)

    return dict(sorted(out_results.items()))
|
||
def plot_precision_curve(results, pivot_score_key, mismatches, axs, c1, c2):
    """Draw one hits@k approximation curve into subplot (c1, c2).

    ``results`` maps row counts to hits@``mismatches`` percentages; x-values
    are expressed as a percentage of ``pivot_score_key`` (the full-data row
    count), and the x-axis is inverted so "more data" appears on the left.
    """
    x_axis = []
    y_axis = []
    for row_count, score in results.items():
        x_axis.append(100 * (row_count / pivot_score_key))
        y_axis.append(score)

    panel = axs[c1, c2]
    panel.plot(x_axis, y_axis, marker='o', linestyle='-', color='black')
    panel.invert_xaxis()
    panel.set(
        xlabel='Proportion of data used (%)',
        ylabel=f'hits@{mismatches} (%)',
        title=f'Approximation, top {mismatches} Features',
    )
    panel.grid(True)
|
||
if __name__ == '__main__':
    # CLI: python analyse_rankings.py <dir containing ranking_<count> folders>
    if len(sys.argv) < 2:
        print('Usage: python script.py <directory>')
        sys.exit(1)

    dfolder = sys.argv[1]  # also read as a global by calculate_mismatch_scores
    mismatch_range = [1, 5, 10, 20]  # the k values, one per subplot

    fig, axs = plt.subplots(2, 2)
    fig.set_figheight(10)
    fig.set_figwidth(10)

    # The folder listing does not change between panels; scan it once
    # instead of re-globbing inside the loop.
    all_folders = list(glob.glob(os.path.join(dfolder, '*')))

    # Lay the four k values out on the 2x2 grid in row-major order.
    row = -1
    for enx, mismatches in enumerate(mismatch_range):
        if enx % 2 == 0:
            row += 1
        col = enx % 2
        out_results = calculate_mismatch_scores(all_folders, mismatches)
        pivot_score_key = max(out_results)
        plot_precision_curve(out_results, pivot_score_key, mismatches, axs, row, col)

    plt.tight_layout()
    plt.savefig('comparison.png', dpi=300)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#!/bin/bash

set -euo pipefail # Enable strict mode for safety

# Configurable variables
NUM_ROWS=1000000      # rows of synthetic data to generate
NUM_FEATURES=100      # synthetic feature columns to generate
INPUT_FILE="test_data_synthetic/data.csv"  # expected output of generate_data — confirm against outrank's defaults
# Row-count subsets to sample from INPUT_FILE and rank independently.
SIZES=('50000' '100000' '200000' '500000' '600000' '700000' '800000' '900000' '1000000')
|
||
# Function to remove a directory safely
remove_directory_safely() {
    # Delete the given directory tree if it exists; otherwise just report
    # and continue. `local` keeps the helper from leaking a global variable
    # into the rest of the script (the original assignment was global).
    local directory_to_remove="$1"
    if [ -d "$directory_to_remove" ]; then
        echo "Removing directory: $directory_to_remove"
        rm -rvf "$directory_to_remove"
    else
        echo "Directory does not exist, skipping: $directory_to_remove"
    fi
}
|
||
# Function to generate random data
generate_data() {
    # Produce the synthetic data set via the external `outrank` CLI
    # (data_generator task); requires `outrank` on PATH. Presumably writes
    # INPUT_FILE (test_data_synthetic/data.csv), which sample_subspaces
    # reads next — confirm against outrank's output location.
    echo "Generating random data files with $NUM_ROWS rows and $NUM_FEATURES features..."
    outrank --task data_generator --num_synthetic_rows $NUM_ROWS --num_synthetic_features $NUM_FEATURES
    echo "Random data generation complete."
}
|
||
# Function to create subspaces from the data
sample_subspaces() {
    # Carve row-limited prefixes of INPUT_FILE into per-size dataset dirs.
    # NOTE(review): if data.csv starts with a header row, `head -n $i`
    # counts it, so each sample holds i-1 data rows — TODO confirm whether
    # that off-by-one matters for the experiment.
    for i in "${SIZES[@]}"
    do
        dataset="test_data_synthetic/dataset_$i"
        outfile="$dataset/data.csv"
        mkdir -p "$dataset"

        if [ -f "$INPUT_FILE" ]; then
            echo "Sampling $i rows into $outfile..."
            # Quote "$i" (ShellCheck SC2086) — avoids word splitting/globbing.
            head -n "$i" "$INPUT_FILE" > "$outfile"
            echo "Sampling for $outfile done."
        else
            echo "Input file $INPUT_FILE not found. Skipping sampling for $i rows."
        fi
    done
}
|
||
# Function to perform feature ranking
feature_ranking() {
    # For each sampled subset, run the outrank ranking task into a matching
    # ranking_<size> folder, then produce a ranking summary for it.
    for i in "${SIZES[@]}"; do
        dataset="test_data_synthetic/dataset_$i"
        output_folder="./test_data_synthetic/ranking_$i"

        if [ -d "$dataset" ]; then
            echo "Proceeding with feature ranking for $i rows..."
            outrank --task ranking --data_path "$dataset" --data_source csv-raw \
                --combination_number_upper_bound 60 --output_folder "$output_folder" \
                --disable_tqdm True

            echo "Feature ranking summary for $i rows."
            outrank --task ranking_summary --output_folder "$output_folder" --data_path "$dataset"
            echo "Ranking for $i done."
        else
            echo "Dataset directory $dataset does not exist. Skipping ranking for $i rows."
        fi
    done
}
|
||
# Function to analyze the rankings
analyse_rankings() {
    # Hand off to the Python plotting script, which reads the
    # ranking_<count> folders under test_data_synthetic and writes
    # comparison.png.
    echo "Analyzing the rankings..."
    python analyse_rankings.py test_data_synthetic
    echo "Analysis complete."
}
|
||
# Main script execution
# Pipeline: wipe any previous run, regenerate synthetic data, sample row
# subsets, rank features for each subset, then plot the comparison figure.
remove_directory_safely test_data_synthetic/
generate_data
sample_subspaces
feature_ranking
analyse_rankings

echo "Script execution finished."