Skip to content

Commit

Permalink
Better plots
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 2, 2023
1 parent 451f2f6 commit 24fcb7e
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions benchmarks/scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,31 +199,51 @@ def case_variants(pattern, algname, ct_point_name, case_df):
df = case_df[case_df['variant'].str.contains(pattern, regex=True)].reset_index(drop=True)
num_records = len(df)
rt_axes = get_rt_axes(df)
rt_axes_values = extract_rt_axes_values(df)

vertical_axis_name = rt_axes[0]
if 'Elements{io}[pow2]' in rt_axes:
vertical_axis_name = 'Elements{io}[pow2]'
vertical_axis_values = extract_rt_axes_values(df)[vertical_axis_name]
vertical_axis_ids = {}
horizontal_axes = rt_axes
horizontal_axes.remove(vertical_axis_name)
vertical_axis_values = rt_axes_values[vertical_axis_name]

vertical_axis_ids = {}
for idx, val in enumerate(vertical_axis_values):
vertical_axis_ids[val] = idx

if len(horizontal_axes) > 0:
def extract_horizontal_space(df):
values = []
for rt_axis in horizontal_axes:
values.append(["{}={}".format(rt_axis, v) for v in df[rt_axis].unique()])
return list(itertools.product(*values))

idx = 0
horizontal_axis_ids = {}
for point in extract_horizontal_space(df):
horizontal_axis_ids[" / ".join(point)] = idx
idx = idx + 1

num_rows = len(vertical_axis_ids)
num_cols = num_records // num_rows
fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, gridspec_kw = {'wspace': 0, 'hspace': 0})

col_id = 0
for id, row in df.iterrows():
for _, row in df.iterrows():
description = row['variant']
for rt_axis in rt_axes:
description += ' / ' + rt_axis + '=' + row[rt_axis]
data = {description: row['samples'],
'base': row['base_samples']}
sns.histplot(data, ax=axes[vertical_axis_ids[row[vertical_axis_name]], col_id], kde=True)
col_id = col_id + 1
if col_id >= num_cols:
col_id = 0

vertical_id = vertical_axis_ids[row[vertical_axis_name]]

if len(horizontal_axes) > 0:
horizontal_point = []
for rt_axis in horizontal_axes:
horizontal_point.append("{}={}".format(rt_axis, row[rt_axis]))
horizontal_id = horizontal_axis_ids[" / ".join(horizontal_point)]
sns.histplot(data, ax=axes[vertical_id, horizontal_id], kde=True)
else:
sns.histplot(data, ax=axes[vertical_id], kde=True)

for ax in axes.flat:
ax.set_xticklabels([])
Expand Down

0 comments on commit 24fcb7e

Please sign in to comment.