add option to select metrics for spider plots of absolute performance
fernandomeyer committed Mar 9, 2020
1 parent 36613fe commit 3b24e5d
Showing 3 changed files with 48 additions and 33 deletions.

opal.py: 7 changes (4 additions, 3 deletions)

@@ -175,7 +175,7 @@ def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, filter_tail
             gs_rank_to_taxid_to_percentage = gs_id_to_rank_to_taxid_to_percentage[sample_id]
             gs_pf_profile = gs_id_to_pf_profile[sample_id]
         else:
-            sys.stderr.write("Skipping assessment of {} for sample {}. Make sure the SampleID of the gold standard and the profile are identical.\n".format(label, sample_id))
+            logging.getLogger('opal').warning("Skipping assessment of {} for sample {}. Make sure the SampleID of the gold standard and the profile are identical.\n".format(label, sample_id))
             continue

         rank_to_taxid_to_percentage = load_data.get_rank_to_taxid_to_percentage(profile)
@@ -303,7 +303,8 @@ def main():
     group2.add_argument('-m', '--memory', help='Comma-separated memory usages in gigabytes', required=False)
     group2.add_argument('-d', '--desc', help='Description for HTML page', required=False)
     group2.add_argument('-r', '--ranks', help='Highest and lowest taxonomic ranks to consider in performance rankings, comma-separated. Valid ranks: superkingdom, phylum, class, order, family, genus, species, strain (default:superkingdom,species)', required=False)
-    group2.add_argument('--metrics_plot', help='Metrics for spider plot of relative performances, first character, comma-separated. Valid metrics: w:weighted Unifrac, l:L1 norm, c:completeness, p:purity, f:false positives, t:true positives (default: w,l,c,p,f)', required=False)
+    group2.add_argument('--metrics_plot_rel', help='Metrics for spider plot of relative performances, first character, comma-separated. Valid metrics: w:weighted Unifrac, l:L1 norm, c:completeness, p:purity, f:false positives, t:true positives (default: w,l,c,p,f)', required=False)
+    group2.add_argument('--metrics_plot_abs', help='Metrics for spider plot of absolute performances, first character, comma-separated. Valid metrics: c:completeness, p:purity, b:Bray-Curtis (default: c,p)', required=False)
     group2.add_argument('--silent', help='Silent mode', action='store_true')
     group2.add_argument('-v', '--version', action='version', version='%(prog)s ' + __version__)
     group2.add_argument('-h', '--help', action='help', help='Show this help message and exit')
@@ -354,7 +355,7 @@ def main():
     logger.info('done')

     logger.info('Creating more plots...')
-    plots_list += pl.plot_all(pd_metrics, labels, output_dir, args.metrics_plot)
+    plots_list += pl.plot_all(pd_metrics, labels, output_dir, args.metrics_plot_rel, args.metrics_plot_abs)
     logger.info('done')

     logger.info('Computing rankings...')
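
Taken together with the plots.py changes below, opal.py now exposes two metric-selection flags instead of one. The following is a minimal sketch of how they parse; the add_argument calls mirror the diff above, while the invocation values are made up for illustration:

import argparse

parser = argparse.ArgumentParser(prog='opal.py', add_help=False)
parser.add_argument('--metrics_plot_rel', required=False)
parser.add_argument('--metrics_plot_abs', required=False)

# Hypothetical invocation: weighted UniFrac, L1 norm, completeness, and purity
# for the relative plot; completeness, purity, and Bray-Curtis for the absolute plot.
args = parser.parse_args(['--metrics_plot_rel', 'w,l,c,p', '--metrics_plot_abs', 'c,p,b'])
print(args.metrics_plot_rel)  # 'w,l,c,p'
print(args.metrics_plot_abs)  # 'c,p,b'
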

src/html_opal.py: 21 changes (11 additions, 10 deletions)

@@ -114,7 +114,7 @@ def get_rank_to_sample_pd(pd_metrics):
     return rank_to_sample_pd


-def get_formatted_pd_rankings(pd_rankings):
+def get_formatted_pd_rankings(pd_rankings, labels):
     df_list = []
     df_list_unsorted_pos = []
     metrics_list = []
@@ -126,7 +126,8 @@ def get_formatted_pd_rankings(pd_rankings):
         df_list_unsorted_pos.append(pd.DataFrame({metric: df2['tool'].tolist(), 'score' + metric: df2['position'].tolist()}))

     df_sum = pd_rankings.groupby(['tool'])['position'].sum().reset_index().sort_values('position')
-    df_sum_unsorted_pos = pd_rankings.groupby(['tool'])['position'].sum().reset_index()
+    df_sum_unsorted_pos = pd_rankings.groupby(['tool'])['position'].sum().loc[labels].reset_index()
+
     df_list.append(
         pd.DataFrame({SUM_OF_SCORES: df_sum['tool'].tolist(), 'score' + SUM_OF_SCORES: df_sum['position'].tolist()}))
     df_list_unsorted_pos.append(
@@ -137,8 +138,8 @@ def get_formatted_pd_rankings(pd_rankings):
     return pd_show, pd_show_unsorted_pos


-def create_rankings_html(pd_rankings, ranks_scored):
-    pd_show, pd_show_unsorted_pos = get_formatted_pd_rankings(pd_rankings)
+def create_rankings_html(pd_rankings, ranks_scored, labels):
+    pd_show, pd_show_unsorted_pos = get_formatted_pd_rankings(pd_rankings, labels)

     table_source = ColumnDataSource(pd_show)

@@ -196,7 +197,7 @@ def create_rankings_html(pd_rankings, ranks_scored):
     weight_unifrac = Slider(start=0, end=10, value=1, step=.1, title=c.UNIFRAC + " weight", callback=callback)
     callback.args["weight_unifrac"] = weight_unifrac

-    p = figure(x_range=pd_show_unsorted_pos[SUM_OF_SCORES].tolist(), plot_width=1000, plot_height=400, title=SUM_OF_SCORES + " - lower is better")
+    p = figure(x_range=pd_show_unsorted_pos[SUM_OF_SCORES].tolist(), plot_width=800, plot_height=400, title=SUM_OF_SCORES + " - lower is better")
     p.vbar(x='x', top='top', source=source, width=0.5, bottom=0, color="firebrick")

     col_rankings = column([Div(text="<font color='navy'><u>Hint 1:</u> click on the columns of scores for sorting.</font>", style={"width": "600px", "margin-bottom": "0px"}),
@@ -309,8 +310,8 @@ def create_metrics_table(pd_metrics, labels, sample_ids_list):
     alpha_diversity_metics = 'Alpha diversity'
     all_metrics_labels = [presence_metrics_label, estimates_metrics_label, alpha_diversity_metics]

-    styles = [{'selector': 'td', 'props': [('width', '100pt')]},
-              {'selector': 'th', 'props': [('width', '100pt'), ('text-align', 'left')]},
+    styles = [{'selector': 'td', 'props': [('width', '115pt')]},
+              {'selector': 'th', 'props': [('width', '115pt'), ('text-align', 'left')]},
               {'selector': 'th:nth-child(1)', 'props': [('width', '120pt'), ('font-weight', 'normal')]},
               {'selector': '', 'props': [('width', 'max-content'), ('width', '-moz-max-content'), ('border-top', '1px solid lightgray'), ('border-spacing', '0px')]},
               {'selector': 'expand-toggle:checked ~ * .data', 'props': [('background-color', 'white !important')]}]
@@ -418,10 +419,10 @@ def create_alpha_diversity_tab():
 def create_plots_html(plots_list):
     message_no_spdplot = 'Spider plots of performance require at least 3 profiles.'

-    text = '<img src="spider_plot.png" />' if 'spider_plot' in plots_list else message_no_spdplot
+    text = '<img src="spider_plot_relative.png" />' if 'spider_plot_relative' in plots_list else message_no_spdplot
     plot1 = Panel(child=Div(text=text), title='Relative performance', width=780)

-    text = '<img src="spider_plot_recall_precision.png" />' if 'spider_plot_recall_precision' in plots_list else message_no_spdplot
+    text = '<img src="spider_plot_absolute.png" />' if 'spider_plot_absolute' in plots_list else message_no_spdplot
     plot2 = Panel(child=Div(text=text), title='Absolute performance')

     tabs_plots = Tabs(tabs=[plot1, plot2], width=780, css_classes=['bk-tabs-margin'])
@@ -501,7 +502,7 @@ def create_computing_efficiency_tab(pd_metrics, plots_list, tabs_list):


 def create_html(pd_rankings, ranks_scored, pd_metrics, labels, sample_ids_list, plots_list, output_dir, desc_text):
-    col_rankings = create_rankings_html(pd_rankings, ranks_scored)
+    col_rankings = create_rankings_html(pd_rankings, ranks_scored, labels)

     create_heatmap_bar(output_dir)

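
The .loc[labels] added in get_formatted_pd_rankings is the functional core of this file's changes: pandas groupby() sorts group keys alphabetically, so the summed positions must be reindexed to the user-supplied tool order before the bar chart is built from them. A standalone sketch, with made-up tool names, of why the reindexing matters:

import pandas as pd

# Made-up rankings; 'labels' stands in for the user-supplied tool order.
pd_rankings = pd.DataFrame({'tool': ['ToolB', 'ToolA', 'ToolB', 'ToolA'],
                            'position': [2, 1, 3, 4]})
labels = ['ToolB', 'ToolA']

summed = pd_rankings.groupby(['tool'])['position'].sum()
print(summed.index.tolist())              # ['ToolA', 'ToolB'] - alphabetical
print(summed.loc[labels].index.tolist())  # ['ToolB', 'ToolA'] - user order restored
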

src/plots.py: 53 changes (33 additions, 20 deletions)

@@ -356,7 +356,7 @@ def spider_plot(metrics, labels, rank_to_metric_to_toolvalues, output_dir, file_
         return []
     theta = spl.radar_factory(N, frame='polygon')
     fig, axes = plt.subplots(figsize=(9, 9), nrows=2, ncols=3, subplot_kw=dict(projection='radar'))
-    fig.subplots_adjust(wspace=1.0, hspace=0.0, top=0.87, bottom=0.45)
+    fig.subplots_adjust(wspace=.5, hspace=0.3, top=0.87, bottom=0.45)

     for ax, rank in zip(axes.flat, c.PHYLUM_SPECIES):
         if grid_points:
@@ -366,12 +366,8 @@ def spider_plot(metrics, labels, rank_to_metric_to_toolvalues, output_dir, file_
         ax.set_title(rank, weight='bold', size=9, position=(0.5, 1.1),
                      horizontalalignment='center', verticalalignment='center')

-        if absolute:
-            metric_suffix = 'absolute'
-        else:
-            metric_suffix = ''
         # select only metrics in metrics list
-        metrics_subdict = OrderedDict((metric, rank_to_metric_to_toolvalues[rank][metric + metric_suffix]) for metric in metrics)
+        metrics_subdict = OrderedDict((metric, rank_to_metric_to_toolvalues[rank][metric]) for metric in metrics)
         it = 1
         metric_to_toolindex = []
         for d, color in zip(metrics_subdict.values(), colors):
@@ -408,6 +404,9 @@ def spider_plot(metrics, labels, rank_to_metric_to_toolvalues, output_dir, file_
             xticklabel.set_position((0,.20))
             xticklabel.set_fontsize('x-small')

+    if absolute:
+        metrics = [metric[:-8] for metric in metrics]
+
     ax = axes[0, 0]
     ax.legend(metrics, loc=(2.0 - 0.353 * len(metrics), 1.25), labelspacing=0.1, fontsize='small', ncol=len(metrics), frameon=False)
     fig.savefig(os.path.join(output_dir, file_name + '.pdf'), dpi=100, format='pdf', bbox_inches='tight')
@@ -452,14 +451,25 @@ def plot_braycurtis_l1norm(braycurtis_list, l1norm_list, labels, output_dir):
     plt.close(fig)


-def get_metrics_for_spider_plot(metrics_plot):
-    initial_to_metric = {'w':c.UNIFRAC, 'l':c.L1NORM, 'c':c.RECALL, 'p':c.PRECISION, 'f':c.FP, 't':c.TP}
+def get_metrics_for_spider_plot(metrics_plot, absolute):
     metrics_initial = [x.strip() for x in metrics_plot.split(',')]
     metrics_list = []
+
+    if not absolute:
+        initial_to_metric = {'w':c.UNIFRAC, 'l':c.L1NORM, 'c':c.RECALL, 'p':c.PRECISION, 'f':c.FP, 't':c.TP}
+        for initial in metrics_initial:
+            if initial not in initial_to_metric:
+                logging.getLogger('opal').warning('Invalid metric initial {} provided with option --metrics_plot_rel. Defaults will be used.'.format(initial))
+                return [c.UNIFRAC, c.L1NORM, c.RECALL, c.PRECISION, c.FP]
+            else:
+                metrics_list.append(initial_to_metric[initial])
+        return metrics_list
+
+    initial_to_metric = {'c':c.RECALL+'absolute', 'p':c.PRECISION+'absolute', 'b':c.BRAY_CURTIS+'absolute'}
     for initial in metrics_initial:
         if initial not in initial_to_metric:
-            logging.getLogger('opal').warning('Invalid metric initial {} provided with option --metrics_plot. Defaults will be used.'.format(initial))
-            return [c.UNIFRAC, c.L1NORM, c.RECALL, c.PRECISION, c.FP]
+            logging.getLogger('opal').warning('Invalid metric initial {} provided with option --metrics_plot_abs. Defaults will be used.'.format(initial))
+            return [c.RECALL+'absolute', c.PRECISION+'absolute']
         else:
             metrics_list.append(initial_to_metric[initial])
     return metrics_list
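
The rewritten helper above now branches on the new absolute argument, each branch with its own initial-to-metric map and its own fallback defaults. Below is a self-contained sketch of the absolute branch; the function name is hypothetical and plain strings stand in for the c.* constants:

import logging

# Stand-ins for the c.* constants (not imported here).
RECALL, PRECISION, BRAY_CURTIS = 'Completeness', 'Purity', 'Bray-Curtis'

def metrics_for_absolute_plot(metrics_plot):
    # Mirrors the absolute branch: 'c', 'p', 'b' initials map to
    # suffix-tagged metric names; any unknown initial falls back to defaults.
    initial_to_metric = {'c': RECALL + 'absolute', 'p': PRECISION + 'absolute', 'b': BRAY_CURTIS + 'absolute'}
    metrics_list = []
    for initial in (x.strip() for x in metrics_plot.split(',')):
        if initial not in initial_to_metric:
            logging.getLogger('opal').warning('Invalid metric initial {}. Defaults will be used.'.format(initial))
            return [RECALL + 'absolute', PRECISION + 'absolute']
        metrics_list.append(initial_to_metric[initial])
    return metrics_list

print(metrics_for_absolute_plot('c,p,b'))  # all three absolute metric names
print(metrics_for_absolute_plot('c,x'))    # warns and returns the defaults
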
@@ -576,8 +586,7 @@ def spider_plot_preprocess_metrics(pd_mean, labels):
     return tool_to_rank_to_metric_to_value


-def plot_all(pd_metrics, labels, output_dir, metrics_plot):
-    metrics = [c.UNIFRAC, c.L1NORM, c.RECALL, c.PRECISION, c.FP]
+def plot_all(pd_metrics, labels, output_dir, metrics_plot_rel, metrics_plot_abs):
     rank_to_metric_to_toolvalues = defaultdict(lambda : defaultdict(list))

     pd_copy = pd_metrics.copy()
@@ -593,33 +602,37 @@ def plot_all(pd_metrics, labels, output_dir, metrics_plot):

     tool_to_rank_to_metric_to_value = spider_plot_preprocess_metrics(pd_mean, labels)

-    metrics_for_plot = get_metrics_for_spider_plot(metrics_plot) if metrics_plot else metrics
+    metrics_for_plot_rel = get_metrics_for_spider_plot(metrics_plot_rel, absolute=False) if metrics_plot_rel else [c.UNIFRAC, c.L1NORM, c.RECALL, c.PRECISION, c.FP]
+    metrics_for_plot_abs = get_metrics_for_spider_plot(metrics_plot_abs, absolute=True) if metrics_plot_abs else [c.RECALL+'absolute', c.PRECISION+'absolute']
+
     present_labels = []
     for label in labels:
         if label not in tool_to_rank_to_metric_to_value:
             continue
         else:
             present_labels.append(label)
         for rank in c.PHYLUM_SPECIES:
-            for metric in metrics_for_plot + [c.RECALL+'absolute', c.PRECISION+'absolute']:
+            for metric in metrics_for_plot_rel + metrics_for_plot_abs:
                 if metric in tool_to_rank_to_metric_to_value[label][rank]:
                     rank_to_metric_to_toolvalues[rank][metric].append(tool_to_rank_to_metric_to_value[label][rank][metric])
             rank_to_metric_to_toolvalues[rank][c.UNIFRAC].append(tool_to_rank_to_metric_to_value[label]['rank independent'][c.UNIFRAC])

     colors = [plt.cm.tab10(2), plt.cm.tab10(0), plt.cm.tab10(3), 'k', 'm', 'y']
-    plots_list = spider_plot(metrics_for_plot,
+    colors2 = ['r', 'k', 'olive']
+
+    plots_list = spider_plot(metrics_for_plot_rel,
                              present_labels,
                              rank_to_metric_to_toolvalues,
                              output_dir,
-                             'spider_plot',
-                             colors[:len(metrics_for_plot)])
+                             'spider_plot_relative',
+                             colors[:len(metrics_for_plot_rel)])

-    plots_list += spider_plot([c.RECALL, c.PRECISION],
+    plots_list += spider_plot(metrics_for_plot_abs,
                               present_labels,
                               rank_to_metric_to_toolvalues,
                               output_dir,
-                              'spider_plot_recall_precision',
-                              ['r', 'k'],
+                              'spider_plot_absolute',
+                              colors2[:len(metrics_for_plot_abs)],
                               grid_points=[0.2, 0.4, 0.6, 0.8, 1.0],
                               fill=True,
                               absolute=True)
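
One naming detail worth noting from spider_plot and plot_all above: metric names for the absolute plot carry the literal string 'absolute' appended (c.RECALL+'absolute', etc.) so they do not collide with the relative-plot entries stored in the same rank_to_metric_to_toolvalues dict, and spider_plot strips that 8-character suffix again before drawing the legend. A tiny sketch, with plain strings standing in for the c.* constants:

# Stand-ins for c.RECALL and c.PRECISION (not imported here).
RECALL, PRECISION = 'Completeness', 'Purity'

# Keys for the absolute plot are tagged with the literal suffix 'absolute'.
metrics = [RECALL + 'absolute', PRECISION + 'absolute']

# spider_plot() removes the 8-character suffix when building the legend:
legend = [metric[:-8] for metric in metrics]
print(legend)  # ['Completeness', 'Purity']
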
