# Patch a114b4a409dfe2589b7333bc6c6a61b4bb7df93b (Mon Sep 17 00:00:00 2001)
# From: aidengu
# Date: Fri, 19 Jul 2024 15:21:34 -0700
# Subject: [PATCH] adding figure scripts
#
#  figures/main_figure_1a.py                      |  99
#  figures/main_figure_1b.py                      | 101
#  figures/supplementary_figure_2.py              |  78
#  figures/supplementary_figure_x_dice_by_area.py | 122
#  4 files changed, 400 insertions(+)

# ---------------------------------------------------------------------------
# diff --git a/figures/main_figure_1a.py b/figures/main_figure_1a.py
# new file mode 100644, index 0000000..a944b11
#
# Circular (pycirclize) overview of the target hierarchy: one sector per
# top-level category, an outer ring of leaf labels, a ring of child groups
# and an inner ring of Box-Cox transformed per-target counts.
# ---------------------------------------------------------------------------
#%%
import os
import json
import numpy as np
import seaborn as sns
from scipy.stats import boxcox
from pycirclize import Circos
import matplotlib.pyplot as plt

base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
with open(os.path.join(base_dir, 'hierarchy.json'), 'r') as f:
    hierarchy_data = json.load(f)

with open(os.path.join(base_dir, 'target_counts.json'), 'r') as f:
    target_counts = json.load(f)

# Loaded but not used by this figure; kept so the script still fails fast
# when the shared data directory is incomplete.
with open(os.path.join(base_dir, 'modality_counts.json'), 'r') as f:
    modality_counts = json.load(f)

# Sector size = number of leaf targets under the sector plus one spare slot
# (the spare slot is what lets the bar positions below start at start+1 and
# stop at end-1).
sectors = {
    sector_name: sum(len(group['child']) for group in sector_info['child'].values()) + 1
    for sector_name, sector_info in hierarchy_data.items()
}

# color scheme
name2color = {"organ": "#E41A1C", "abnormality": "#377EB8", "histology": "#4DAF4A"}


def generate_shades(base_color, n):
    """Return ``n`` shades of ``base_color``.

    Builds a seaborn light palette with ``n + 2`` entries and drops the first
    and the last one.
    """
    return sns.light_palette(base_color, n + 2)[1:-1]


# One shade per child group of every sector.
color_schemes = {
    sector_name: generate_shades(name2color[sector_name], len(hierarchy_data[sector_name]['child']))
    for sector_name in sectors
}

# Radial extents (r_min, r_max) of the tracks and their font sizes.
parent_track_ratio = (72, 85)
middle_track_ratio = (85, 100)
bar_track_ratio = (45, 70)
parent_track_font_size = 7
middle_track_font_size = 5.5
bar_track_font_size = 7
outer_track_font_size = 9

circos = Circos(sectors, space=8.8)
for sector in circos.sectors:
    groups = hierarchy_data[sector.name.lower()]['child']

    # Tick position -> leaf label. Leaves occupy positions 1..N; position 0
    # and position N+1 are blank so the sector's end ticks carry no text.
    leaf_labels = [leaf for group in groups.values() for leaf in group['child']]
    idx2label = dict(enumerate(leaf_labels, start=1))
    idx2label[len(leaf_labels) + 1] = ''
    idx2label[0] = ''

    track_outer = sector.add_track((100, 101))
    track_outer.xticks_by_interval(
        1,
        tick_length=0,
        outer=True,
        show_bottom_line=False,
        label_orientation="vertical",
        # Bind this sector's mapping as a default argument: ``idx2label`` is
        # rebound on every loop iteration, so a plain closure would see the
        # last sector's labels if the formatter were ever called lazily.
        label_formatter=lambda v, labels=idx2label: labels[int(v)],
        label_size=outer_track_font_size,
        show_endlabel=True,
    )

    track = sector.add_track(parent_track_ratio)
    track.axis(fc=name2color[sector.name], lw=0)
    sector_title = (
        sector.name.capitalize()
        .replace('Mri', 'MRI')
        .replace('Ct', 'CT')
        .replace('Oct', 'OCT')
        .replace('Dermoscopy', "DS")
    )
    track.text(sector_title, color="white", size=parent_track_font_size)

    track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
    last_group = len(groups) - 1
    sect_start = 0
    for i, (group_name, group) in enumerate(groups.items()):
        # Each group spans one slot per leaf; the sector's spare slot is
        # split as half a slot onto the first group and half onto the last.
        sect_size = len(group['child'])
        if i == 0:
            sect_size += 0.5
        if i == last_group:
            sect_size += 0.5
        track1.rect(
            sect_start,
            sect_start + sect_size,
            r_lim=(middle_track_ratio[0], middle_track_ratio[1] - 1),
            ec="black",
            lw=0,
            fc=color_schemes[sector.name][i],
        )
        group_label = (
            group_name.replace('abnormality', 'abn.')
            .replace(' anatomies', '')
            .replace(' disturbance', '')
            .replace('other abn.', 'Other')
            .replace('liver', '')
            .replace('pancreas', '')
            .capitalize()
        )
        track1.text(group_label, sect_start + sect_size / 2, color="black", size=middle_track_font_size)
        sect_start += sect_size

    # One bar per leaf target, centred on tick positions 1..N.
    x = np.linspace(sector.start + 1, sector.end - 1, int(sector.size) - 1)
    y = [target_counts[idx2label[i + 1]] for i in range(len(x))]
    y_box = boxcox(y, 0.35)

    track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
    track2.axis()
    # NOTE(review): these tick positions are hard-coded and are not
    # boxcox(10**k, 0.35) (which is ~11.46 for 10^2) -- confirm how pycirclize
    # scales yticks vs. bar heights so the labels really line up with the bars.
    track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size - 1)
    track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)

fig = circos.plotfig()
fig.savefig('figure_1a.pdf')
plt.show()

# %%

# ---------------------------------------------------------------------------
# diff --git a/figures/main_figure_1b.py b/figures/main_figure_1b.py
# new file mode 100644, index 0000000..568a3c6
#
# Horizontal log-scale bar chart of the number of images per modality.
# ---------------------------------------------------------------------------
#%%
import json
import os  # required by os.path.join below; missing in the original script

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})


# Load data
def load_data(file_path):
    """Read ``file_path`` and return its parsed JSON content."""
    with open(file_path, 'r') as f:
        return json.load(f)


base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
data = load_data(os.path.join(base_dir, 'modality_counts.json'))
separate_submodality = False


# Transform data for plotting
def transform_data(data):
    """Flatten ``{modality: {sub-category: count}}`` into a long DataFrame.

    Returns a frame with the columns 'Modality', 'Sub-category' and 'Count',
    one row per (modality, sub-category) pair.
    """
    rows = [
        (modality, subcat, count)
        for modality, subcats in data.items()
        for subcat, count in subcats.items()
    ]
    return pd.DataFrame(rows, columns=['Modality', 'Sub-category', 'Count'])


df = transform_data(data)


# Calculate total counts by modality and sort
def calculate_totals(df):
    """Sum 'Count' per 'Modality', sorted ascending.

    Returns ``(totals, modalities)`` where ``totals`` is the sorted Series and
    ``modalities`` is its index as a list (same order).
    """
    total_counts_by_modality = df.groupby("Modality")["Count"].sum().sort_values(ascending=True)
    return total_counts_by_modality, total_counts_by_modality.index.tolist()


total_counts_by_modality, sorted_modalities = calculate_totals(df)


# Generate color map
def generate_color_map(total_counts_by_modality):
    """Map every modality in the Series index to an RGBA row of the 'cool' colormap.

    Colors are sampled at evenly spaced points in [0, 1], in index order.
    """
    base_colors = plt.cm.cool(np.linspace(0, 1, len(total_counts_by_modality)))
    return dict(zip(total_counts_by_modality.index, base_colors))


modality_color_map = generate_color_map(total_counts_by_modality)

# Format total count for
# display
def format_total_count(total_count):
    """Format a bar total for display next to its bar.

    Values of 1000 and above are written as ``'<mantissa> x 10$^<exponent>$'``
    with a two-decimal mantissa; smaller values are written as ``str(value)``.

    Returns:
        ``(formatted_total, exponent)``; ``exponent`` is 0 for values < 1000.
    """
    if total_count >= 1000:
        exponent = int(np.floor(np.log10(total_count)))
        # Round first, then renormalise: e.g. 99960 has mantissa 9.996, which
        # would otherwise be printed as "10.00 x 10^4" instead of "1.00 x 10^5".
        mantissa = round(total_count / 10**exponent, 2)
        if mantissa >= 10:
            mantissa /= 10
            exponent += 1
        formatted_total = f'{mantissa:.2f} x 10$^{exponent}$'
    else:
        exponent = 0
        formatted_total = str(total_count)
    return formatted_total, exponent


def _display_label(modality):
    """Capitalize all-lowercase modality names; leave mixed/upper-case ones as is."""
    return modality.capitalize() if modality.islower() else modality


# Plotting function
def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
    """Draw one horizontal log-scale bar per modality, save it as PDF and show it.

    Args:
        df: long frame with 'Modality', 'Sub-category' and 'Count' columns.
        total_counts_by_modality: per-modality totals, indexable by modality.
        sorted_modalities: modalities in plotting order (bottom to top).
        modality_color_map: modality -> RGBA color.
        separate_submodality: when true, stack one segment per sub-category
            (largest first) and save to ``data_dist_modality_bar_subbar.pdf``;
            otherwise draw a single bar per modality and save to
            ``data_dist_modality_bar.pdf``.
    """
    _, ax = plt.subplots(figsize=(10, 12))
    gap = 0.005 if separate_submodality else 0
    shades = np.power(np.linspace(0.75, 1, df.groupby("Sub-category").ngroups), 2)
    bar_style = dict(height=0.8, log=True, edgecolor='white', linewidth=0.5)

    for i, modality in enumerate(sorted_modalities):
        # NOTE(review): the shade multiplies all four RGBA channels, so it
        # also lowers alpha -- confirm that is intended.
        base_color = np.array(modality_color_map[modality])
        # Same label rule in both modes (the original capitalised lowercase
        # names only in the single-bar mode).
        label = _display_label(modality)
        total_count = total_counts_by_modality[modality]
        formatted_total, exponent = format_total_count(total_count)

        if separate_submodality:
            bar_end = 0.0
            counts = df.loc[df["Modality"] == modality, "Count"].sort_values(ascending=False)
            for j, count in enumerate(counts):
                if count > 0:
                    ax.barh(label, count, left=bar_end, color=base_color * shades[j % len(shades)], **bar_style)
                    bar_end += count + gap
            bar_end -= gap  # no gap after the last segment
        else:
            ax.barh(label, total_count, color=base_color * shades[0], **bar_style)
            bar_end = total_count

        # Place the total just right of the bar end.
        ax.text(bar_end + (10**exponent) * 0.05, i, formatted_total, va='center', fontsize=20, ha='left')

    configure_plot(ax, sorted_modalities)

    plt.tight_layout()
    output_path = "./data_dist_modality_bar_subbar.pdf" if separate_submodality else "./data_dist_modality_bar.pdf"
    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.show()


# Configure plot aesthetics
def configure_plot(ax, sorted_modalities):
    """Apply log x-scale, title, tick sizes and despine to the bar chart.

    ``sorted_modalities`` is accepted for call compatibility but is not used.
    """
    ax.set_xscale('log')
    ax.set_title("Number of images per modality", fontsize=28)
    plt.yticks(rotation=0, fontsize=24, va='center')
    ax.tick_params(axis='x', which='major', length=8)
    ax.tick_params(axis='x', which='minor', length=5)
    plt.xticks(fontsize=24)
    sns.despine()


# Main script execution
plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality)

# %%

# ---------------------------------------------------------------------------
# diff --git a/figures/supplementary_figure_2.py b/figures/supplementary_figure_2.py
# new file mode 100644, index 0000000..5593a26
#
# Circular (pycirclize) overview of image counts per modality / sub-modality.
# ---------------------------------------------------------------------------
# %%
import os
import json
import numpy as np
import seaborn as sns
from scipy.stats import boxcox
from pycirclize import Circos
import matplotlib.pyplot as plt

base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
# hierarchy_data and target_counts are loaded but not used by this figure.
with open(os.path.join(base_dir, 'hierarchy.json'), 'r') as f:
    hierarchy_data = json.load(f)

with open(os.path.join(base_dir, 'target_counts.json'), 'r') as f:
    target_counts = json.load(f)

with open(os.path.join(base_dir, 'modality_counts.json'), 'r') as f:
    modality_counts = json.load(f)

# One sector per modality, one unit-wide slot per sub-modality.
sectors = {modality: len(submodalities) for modality, submodalities in modality_counts.items()}

# color scheme
name2color = {
    "MRI": "#005A9E",
    "CT": "#FF7F00",
    "pathology": "#984EA3",
    "ultrasound": "#7BC8F6",
    "X-Ray": "#999999",
    "fundus": "#76B041",
    "dermoscopy": "#FDBF6F",
    "endoscope": "#C0392B",
    "OCT": "#33A02C",
}


def generate_shades(base_color, n):
    """Return ``n`` shades of ``base_color``.

    Builds a seaborn light palette with ``n + 2`` entries and drops the first
    and the last one.
    """
    return sns.light_palette(base_color, n + 2)[1:-1]


# One shade per sub-modality of every sector.
color_schemes = {
    sector_name: generate_shades(name2color[sector_name], len(modality_counts[sector_name]))
    for sector_name in sectors
}

# Radial extents (r_min, r_max) of the tracks and their font sizes.
parent_track_ratio = (72, 85)
middle_track_ratio = (85, 100)
bar_track_ratio = (45, 70)
parent_track_font_size = 7
middle_track_font_size = 5.5
bar_track_font_size = 7

circos = Circos(sectors, space=6)
for sector in circos.sectors:
    submodality_counts = modality_counts[sector.name]

    track = sector.add_track(parent_track_ratio)
    track.axis(fc=name2color[sector.name], lw=0)
    sector_title = (
        sector.name.capitalize()
        .replace('Mri', 'MRI')
        .replace('Ct', 'CT')
        .replace('Oct', 'OCT')
        .replace('Dermoscopy', "DS")
    )
    track.text(sector_title, color="white", size=parent_track_font_size)

    track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
    for slot, submodality in enumerate(submodality_counts):
        # Slot ``slot`` spans [slot, slot + 1]; its label sits at the centre.
        track1.rect(
            slot,
            slot + 1,
            r_lim=(middle_track_ratio[0], middle_track_ratio[1] - 1),
            ec="black",
            lw=0,
            fc=color_schemes[sector.name][slot],
        )
        track1.text(submodality.capitalize(), slot + 1 / 2, color="black", size=middle_track_font_size)

    # One bar per sub-modality, centred in its slot.
    x = np.linspace(sector.start + 0.5, sector.end - 0.5, int(sector.size))
    y = list(submodality_counts.values())
    y_box = boxcox(y, 0.35)

    track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
    track2.axis()
    # NOTE(review): these tick positions are hard-coded and are not
    # boxcox(10**k, 0.35) (which is ~11.46 for 10^2) -- confirm how pycirclize
    # scales yticks vs. bar heights so the labels really line up with the bars.
    track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size - 1)
    track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)

fig = circos.plotfig()
fig.savefig('data_target_modality.pdf')
plt.show()

# %%

# ---------------------------------------------------------------------------
# diff --git a/figures/supplementary_figure_x_dice_by_area.py
#            b/figures/supplementary_figure_x_dice_by_area.py
# new file mode 100644, index 0000000..d73a980
#
# Mean Dice score as a function of target area for BiomedParse, SAM and MedSAM.
# ---------------------------------------------------------------------------
#%%
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import sem

# Define file paths
base_dir = '/mnt/hanoverdev/data/BiomedSeg'
eval_results_path = os.path.join(base_dir, 'biomedparse_eval_results.json')
def _dataset_name_from_key(dataset_key):
    """Strip the fixed 'biomed_' prefix and '_test/grounding_refcoco' suffix.

    The slice is positional: keys are assumed to have exactly that form
    (presumably all keys of the evaluation JSON do -- verify against the file).
    """
    return dataset_key[len('biomed_'):-len('_test/grounding_refcoco')]


# Extract relevant information
def extract_data(parsed_data):
    """Flatten the evaluation results into one row per instance.

    Args:
        parsed_data: mapping of dataset key -> results; every value must hold
            ``["grounding"]["instance_results"]``, a list of instances with
            ``"metadata"]["grounding_info"][0]`` (``mask_file``, ``area``) and
            ``"Dice"][0]``.

    Returns:
        DataFrame with the columns 'dataset', 'file_name' (basename of
        ``mask_file``), 'area' and 'bp_dice'. The columns are present even
        when there are no instances.
    """
    records = []
    for dataset_key, results in parsed_data.items():
        dataset_name = _dataset_name_from_key(dataset_key)
        for instance in results["grounding"]["instance_results"]:
            grounding_info = instance["metadata"]["grounding_info"][0]
            records.append({
                "dataset": dataset_name,
                "file_name": grounding_info["mask_file"].split("/")[-1],
                "area": grounding_info["area"],
                "bp_dice": instance["Dice"][0],
            })
    # Explicit columns keep the schema stable for empty input, so a later
    # merge on 'dataset'/'file_name' does not fail on missing keys.
    return pd.DataFrame(records, columns=["dataset", "file_name", "area", "bp_dice"])


# Merge with SAM and MedSAM data
def merge_with_sam_medsam(df, parsed_data, base_dir):
    """Inner-join the per-instance frame with the SAM and MedSAM Dice CSVs.

    For every dataset key in ``parsed_data`` the files
    ``<base_dir>/<dataset>/test_sam_vit_b_01ec64_dice.csv`` and
    ``<base_dir>/<dataset>/test_medsam_dice.csv`` are read and joined on their
    'image' column; overlapping columns get the suffixes '_sam' / '_medsam'.
    For MSD, Radiography and amos22 datasets the '-' in the dataset name is a
    sub-directory separator on disk.

    Returns:
        ``df`` inner-joined with the comparison rows on
        ('dataset', 'file_name').
    """
    frames = []
    for dataset_key in parsed_data:
        dataset_name = _dataset_name_from_key(dataset_key)
        dataset_dir = dataset_name
        if any(sub in dataset_name for sub in ['MSD', 'Radiography', 'amos22']):
            dataset_dir = dataset_name.replace('-', '/')

        sam_data = pd.read_csv(os.path.join(base_dir, dataset_dir, 'test_sam_vit_b_01ec64_dice.csv'))
        medsam_data = pd.read_csv(os.path.join(base_dir, dataset_dir, 'test_medsam_dice.csv'))

        merged_data = pd.merge(sam_data, medsam_data, on='image', suffixes=('_sam', '_medsam'))
        merged_data = merged_data.rename(columns={'image': 'file_name'})
        # Store the dash form again so it matches df['dataset'].
        merged_data['dataset'] = dataset_dir.replace('/', '-')
        frames.append(merged_data)

    # Concatenate once instead of growing a frame inside the loop.
    comparison_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

    return pd.merge(df, comparison_df, on=['dataset', 'file_name'], how='inner')


# Filter datasets
rad_list = [
    'ACDC', 'COVID-QU-Ex', 'CXR_Masks_and_Labels', 'LGG', 'LIDC-IDRI', 'MMs',
    'MSD-Task01_BrainTumour', 'MSD-Task02_Heart', 'MSD-Task03_Liver', 'MSD-Task04_Hippocampus',
    'MSD-Task05_Prostate', 'MSD-Task06_Lung', 'MSD-Task07_Pancreas', 'MSD-Task08_HepaticVessel',
    'MSD-Task09_Spleen', 'MSD-Task10_Colon', 'PROMISE12', 'QaTa-COV19', 'Radiography-COVID',
    'Radiography-Lung_Opacity', 'Radiography-Normal', 'Radiography-Viral_Pneumonia',
    'amos22-CT', 'amos22-MRI', 'kits23', 'COVID-19_CT'
]


# Plot area to Dice ratio
def plot_area_to_dice(df, total_image_area=1024 * 1024):
    """Plot mean Dice (with standard error) against target area and save it.

    Rows are binned into 14 equal-width bins of 'area' expressed as a
    percentage of ``total_image_area``. 'bp_dice' is always plotted;
    'dice_sam' and 'dice_medsam' are plotted when present. The figure is
    written to ``<base_dir>/area_vs_dice.pdf`` (module-level ``base_dir``)
    and shown.

    Args:
        df: frame with 'area' and 'bp_dice' columns (optionally 'dice_sam',
            'dice_medsam'). It is not modified.
        total_image_area: image area in pixels used for the percentage and as
            the upper area cut-off. NOTE(review): the default assumes
            1024 x 1024 images -- confirm for all datasets.

    Raises:
        ValueError: if no row has ``area <= total_image_area``.
    """
    sns.set_theme(style='ticks')

    max_area_threshold = total_image_area  # Adjust this threshold as needed
    # .copy(): the derived columns below must be written to our own frame,
    # not to a slice of the caller's (pandas SettingWithCopyWarning).
    filtered_df = df[df['area'] <= max_area_threshold].copy()
    if filtered_df.empty:
        raise ValueError("no rows with area <= total_image_area; nothing to plot")

    filtered_df['area_percentage'] = (filtered_df['area'] / total_image_area) * 100
    area_min = filtered_df['area_percentage'].min()
    area_max = filtered_df['area_percentage'].max()

    bins = np.linspace(area_min, area_max, 15)
    # include_lowest: bins are right-closed, so without it the rows sitting
    # exactly on the minimum fall outside the first bin and are dropped.
    filtered_df['area_bin'] = pd.cut(filtered_df['area_percentage'], bins, include_lowest=True)

    # observed=False keeps empty bins, so every result has one entry per
    # category and stays aligned with ``index.categories.mid`` below.
    grouped = filtered_df.groupby('area_bin', observed=False)

    colors = sns.color_palette("colorblind", 3)
    plt.figure(figsize=(14, 10))

    # (column, legend label, optional?) in plotting order.
    series = [
        ('bp_dice', 'BiomedParse', False),
        ('dice_sam', 'SAM', True),
        ('dice_medsam', 'MedSAM', True),
    ]
    for (column, label, optional), color in zip(series, colors):
        if optional and column not in filtered_df.columns:
            continue
        avg_dice = grouped[column].mean()
        sem_dice = grouped[column].apply(sem)
        plt.errorbar(avg_dice.index.categories.mid, avg_dice, yerr=sem_dice, fmt='-o', label=label, color=color, capsize=5)

    plt.xlabel('Area (% of total image)', fontsize=20)
    plt.ylabel('Dice Score', fontsize=20)
    plt.grid(False)
    plt.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=3, frameon=False)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.xlim(area_min, area_max)
    sns.despine()

    plt.tight_layout()
    plt.savefig(os.path.join(base_dir, 'area_vs_dice.pdf'), dpi=300)
    plt.show()


# Driver: same statement order as before; guarded so the functions above can
# be imported without touching the data directory.
if __name__ == "__main__":
    # Load data
    with open(eval_results_path, 'r') as f:
        parsed_data = json.load(f)

    df = extract_data(parsed_data)
    df = merge_with_sam_medsam(df, parsed_data, base_dir)

    # Save to CSV (all datasets, before the radiology filter)
    df.to_csv(os.path.join(base_dir, 'dice_by_size.csv'), index=False)

    df = df[df['dataset'].isin(rad_list)]

    plot_area_to_dice(df)

# %%