diff --git a/benchmarks/scripts/analyze.py b/benchmarks/scripts/analyze.py index ae56dc451..7c96699ee 100755 --- a/benchmarks/scripts/analyze.py +++ b/benchmarks/scripts/analyze.py @@ -199,31 +199,51 @@ def case_variants(pattern, algname, ct_point_name, case_df): df = case_df[case_df['variant'].str.contains(pattern, regex=True)].reset_index(drop=True) num_records = len(df) rt_axes = get_rt_axes(df) + rt_axes_values = extract_rt_axes_values(df) vertical_axis_name = rt_axes[0] if 'Elements{io}[pow2]' in rt_axes: vertical_axis_name = 'Elements{io}[pow2]' - vertical_axis_values = extract_rt_axes_values(df)[vertical_axis_name] - vertical_axis_ids = {} + horizontal_axes = rt_axes + horizontal_axes.remove(vertical_axis_name) + vertical_axis_values = rt_axes_values[vertical_axis_name] + vertical_axis_ids = {} for idx, val in enumerate(vertical_axis_values): vertical_axis_ids[val] = idx + if len(horizontal_axes) > 0: + def extract_horizontal_space(df): + values = [] + for rt_axis in horizontal_axes: + values.append(["{}={}".format(rt_axis, v) for v in df[rt_axis].unique()]) + return list(itertools.product(*values)) + + idx = 0 + horizontal_axis_ids = {} + for point in extract_horizontal_space(df): + horizontal_axis_ids[" / ".join(point)] = idx + idx = idx + 1 + num_rows = len(vertical_axis_ids) num_cols = num_records // num_rows fig, axes = plt.subplots(nrows=num_rows, ncols=num_cols, gridspec_kw = {'wspace': 0, 'hspace': 0}) - col_id = 0 - for id, row in df.iterrows(): + for _, row in df.iterrows(): description = row['variant'] - for rt_axis in rt_axes: - description += ' / ' + rt_axis + '=' + row[rt_axis] data = {description: row['samples'], 'base': row['base_samples']} - sns.histplot(data, ax=axes[vertical_axis_ids[row[vertical_axis_name]], col_id], kde=True) - col_id = col_id + 1 - if col_id >= num_cols: - col_id = 0 + + vertical_id = vertical_axis_ids[row[vertical_axis_name]] + + if len(horizontal_axes) > 0: + horizontal_point = [] + for rt_axis in horizontal_axes: + horizontal_point.append("{}={}".format(rt_axis, row[rt_axis])) + horizontal_id = horizontal_axis_ids[" / ".join(horizontal_point)] + sns.histplot(data, ax=axes[vertical_id, horizontal_id], kde=True) + else: + sns.histplot(data, ax=axes[vertical_id], kde=True) for ax in axes.flat: ax.set_xticklabels([])