forked from NVIDIA/cub
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d583228
commit 1e2d115
Showing
3 changed files
with
175 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
#!/bin/env python3 | ||
|
||
from scipy import stats | ||
|
||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
import pandas as pd | ||
import numpy as np | ||
import statistics | ||
import functools | ||
import argparse | ||
import cub | ||
import os | ||
|
||
|
||
def valid_alpha(value): | ||
fvalue = float(value) | ||
if fvalue < 0 or fvalue > 1: | ||
raise argparse.ArgumentTypeError( | ||
"Alpha must be a float between 0 and 1.") | ||
return fvalue | ||
|
||
|
||
def file_exists(value): | ||
if not os.path.isfile(value): | ||
raise argparse.ArgumentTypeError(f"The file '{value}' does not exist.") | ||
return value | ||
|
||
|
||
def parse_arguments(): | ||
parser = argparse.ArgumentParser(description='Process alpha and files.') | ||
parser.add_argument('--alpha', type=valid_alpha, default=0.01, required=False, | ||
help='Alpha value must be a float between 0 and 1.') | ||
parser.add_argument('files', type=file_exists, nargs='+', | ||
help='At least two files are required.') | ||
parser.add_argument('--plot', | ||
action=argparse.BooleanOptionalAction, | ||
help="Show base distributions.") | ||
args = parser.parse_args() | ||
|
||
if len(args.files) < 2: | ||
parser.error("At least two files are required.") | ||
|
||
return args | ||
|
||
|
||
def distributions_are_different(alpha, samples_list): | ||
# H0: the distributions are not different | ||
# H1: the distribution are different | ||
result = stats.kruskal(*samples_list) | ||
|
||
# Reject H0 | ||
return result.pvalue < alpha | ||
|
||
|
||
def get_group_id(alpha, num_files, row): | ||
if distributions_are_different(alpha, [row[fid] for fid in range(num_files)]): | ||
data = {} | ||
for fid in range(num_files): | ||
data[f'file{fid}'] = row[fid] | ||
sns.displot(data, kind="kde") | ||
plt.show() | ||
|
||
|
||
def plot(args): | ||
dfs = {} | ||
for fid, file in enumerate(args.files): | ||
storage = cub.bench.StorageBase(file) | ||
for alg in storage.algnames(): | ||
df = storage.alg_to_df(alg) | ||
df = df[df['variant'] == 'base'].drop(columns=['variant', 'center', 'elapsed']) | ||
df['file'] = fid | ||
if alg not in dfs: | ||
dfs[alg] = [df] | ||
else: | ||
dfs[alg].append(df) | ||
|
||
get_group_id_closure = functools.partial(get_group_id, args.alpha, len(args.files)) | ||
|
||
for alg in dfs: | ||
print(alg) | ||
df = pd.concat(dfs[alg], ignore_index=True) | ||
index = list(df.columns) | ||
index.remove('samples') | ||
index.remove('file') | ||
df_pivot = df.pivot(index=index, columns='file', values='samples') | ||
df_pivot.apply(get_group_id_closure, axis=1) | ||
|
||
|
||
def combine_samples(num_files, row): | ||
row = row.dropna() | ||
combined_samples = [] | ||
for fid in range(num_files): | ||
if fid in row: | ||
combined_samples.extend(row[fid]) | ||
return np.asarray(sorted(combined_samples), dtype=np.float32) | ||
|
||
|
||
def compute_center(row): | ||
if len(row['samples']) == 0: | ||
return float('inf') | ||
return statistics.median(row['samples']) | ||
|
||
|
||
def merge(args): | ||
dfs = {} | ||
for fid, file in enumerate(args.files): | ||
storage = cub.bench.StorageBase(file) | ||
for alg in storage.algnames(): | ||
df = storage.alg_to_df(alg) | ||
df['file'] = fid | ||
if alg not in dfs: | ||
dfs[alg] = [df] | ||
else: | ||
dfs[alg].append(df) | ||
|
||
combine_closure = functools.partial(combine_samples, len(args.files)) | ||
|
||
storage = cub.bench.StorageBase(cub.bench.db_name) | ||
|
||
for alg in dfs: | ||
df = pd.concat(dfs[alg], ignore_index=True) | ||
index = list(df.columns) | ||
index.remove('samples') | ||
index.remove('elapsed') | ||
index.remove('center') | ||
index.remove('file') | ||
df_pivot = df.pivot(index=index, columns='file', values='samples') | ||
df_pivot['samples'] = df_pivot.apply(combine_closure, axis=1) | ||
df_pivot['center'] = df_pivot.apply(compute_center, axis=1) | ||
df_pivot['elapsed'] = 0.0 # TODO compute min or sum, not sure | ||
df_pivot.drop(columns=list(range(len(args.files))), inplace=True) | ||
df_pivot.reset_index(inplace=True) | ||
storage.store_df(alg, df_pivot) | ||
|
||
|
||
def main(): | ||
args = parse_arguments() | ||
if args.plot: | ||
plot(args) | ||
merge(args) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |