Skip to content

Commit

Permalink
adding figure scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
aiden-ygu committed Jul 19, 2024
1 parent b098058 commit a114b4a
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 0 deletions.
99 changes: 99 additions & 0 deletions figures/main_figure_1a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#%%
import os
import json
import numpy as np
import seaborn as sns
from scipy.stats import boxcox
from pycirclize import Circos
import matplotlib.pyplot as plt

base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
with open(os.path.join(base_dir,'hierarchy.json'), 'r') as f:
hierarchy_data = json.load(f)

with open(os.path.join(base_dir,'target_counts.json'), 'r') as f:
target_counts = json.load(f)

with open(os.path.join(base_dir,'modality_counts.json'), 'r') as f:
modality_counts = json.load(f)

# color scheme
sectors = {k: 0 for k in hierarchy_data.keys()}
for sector_name in hierarchy_data:
for k,v in hierarchy_data[sector_name]['child'].items():
sectors[sector_name] += len(v['child'])
sectors[sector_name] += 1

name2color = {"organ": "#E41A1C", "abnormality": "#377EB8", "histology": "#4DAF4A"}

def generate_shades(base_color, n):
return sns.light_palette(base_color, n + 2)[1:-1]

color_schemes = {}
for sector in sectors:
child_colors = generate_shades(name2color[sector], len(hierarchy_data[sector]['child']))
color_schemes[sector] = child_colors

parent_track_ratio = (72, 85)
middle_track_ratio = (85, 100)
bar_track_ratio = (45, 70)
parent_track_font_size = 7
middle_track_font_size = 5.5
bar_track_font_size = 7
outer_track_font_size = 9

circos = Circos(sectors, space=8.8)
for sector in circos.sectors:
idx2label = {}
idx = 1
for k,v in hierarchy_data[sector.name.lower()]['child'].items():
for k1,v1 in v['child'].items():
idx2label[idx] = k1
idx += 1
idx2label[idx] = ''
idx2label[0] = ''

track_outer = sector.add_track((100, 101))
track_outer.xticks_by_interval(
1,
tick_length=0,
outer=True,
show_bottom_line=False,
label_orientation="vertical",
label_formatter=lambda v: idx2label[int(v)],
label_size=outer_track_font_size,
show_endlabel=True
)

track = sector.add_track(parent_track_ratio)
track.axis(fc=name2color[sector.name], lw=0)
track.text(sector.name.capitalize().replace('Mri', 'MRI').replace('Ct', 'CT').replace('Oct', 'OCT').replace('Dermoscopy', "DS"), color="white", size=parent_track_font_size)

track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
sect_start = 0
color_idx = 0
for i, (k,v) in enumerate(hierarchy_data[sector.name.lower()]['child'].items()):
sect_size = len(v['child']) if i != len(hierarchy_data[sector.name.lower()]['child'])-1 else len(v['child'])+1
if i == 0:
sect_size += 0.5
if i == len(hierarchy_data[sector.name.lower()]['child'])-1:
sect_size -= 0.5
track1.rect(sect_start, sect_start+sect_size, r_lim=(middle_track_ratio[0], middle_track_ratio[1]-1), ec="black", lw=0,fc=color_schemes[sector.name][color_idx])
color_idx += 1
track1.text(k.replace('abnormality', 'abn.').replace(' anatomies', '').replace(' disturbance', '').replace('other abn.', 'Other').replace('liver', '').replace('pancreas', '').capitalize(), sect_start+sect_size/2, color="black", size=middle_track_font_size)
sect_start += sect_size

x = np.linspace(sector.start+1 , sector.end-1 , int(sector.size)-1)
y = [target_counts[idx2label[i+1]] for i in range(0,len(x))]
y_box = boxcox(y, 0.35)

track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
track2.axis()
track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size-1)
track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)

fig = circos.plotfig()
fig.savefig('figure_1a.pdf')
plt.show()

# %%
101 changes: 101 additions & 0 deletions figures/main_figure_1b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#%%
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json
import seaborn as sns

plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})

# Load data
def load_data(file_path):
with open(file_path, 'r') as f:
return json.load(f)
base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
data = load_data(os.path.join(base_dir, 'modality_counts.json'))
separate_submodality = False

# Transform data for plotting
def transform_data(data):
df = pd.DataFrame([(modality, subcat, count) for modality, subcats in data.items() for subcat, count in subcats.items()], columns=['Modality', 'Sub-category', 'Count'])
return df

df = transform_data(data)

# Calculate total counts by modality and sort
def calculate_totals(df):
total_counts_by_modality = df.groupby("Modality")["Count"].sum().sort_values(ascending=True)
sorted_modalities = total_counts_by_modality.index.tolist()
return total_counts_by_modality, sorted_modalities

total_counts_by_modality, sorted_modalities = calculate_totals(df)

# Generate color map
def generate_color_map(total_counts_by_modality):
base_colors = plt.cm.cool(np.linspace(0, 1, len(total_counts_by_modality)))
modality_color_map = {modality: base_colors[i] for i, modality in enumerate(total_counts_by_modality.index)}
return modality_color_map

modality_color_map = generate_color_map(total_counts_by_modality)

# Format total count for display
def format_total_count(total_count):
if total_count >= 1000:
exponent = int(np.floor(np.log10(total_count)))
mantissa = total_count / 10**exponent
formatted_total = f'{mantissa:.2f} x 10$^{exponent}$'
else:
exponent = 0
formatted_total = str(total_count)
return formatted_total, exponent

# Plotting function
def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
fig, ax = plt.subplots(figsize=(10, 12))
current_bottom = np.zeros(len(sorted_modalities))
gap = 0.005 if separate_submodality else 0
shades = np.power(np.linspace(0.75, 1, df.groupby("Sub-category").ngroups), 2)

if separate_submodality:
for i, modality in enumerate(sorted_modalities):
subdf = df[df["Modality"] == modality].sort_values(by='Count', ascending=False)
for j, (index, row) in enumerate(subdf.iterrows()):
count = row['Count']
if count > 0:
color = np.array(modality_color_map[modality]) * shades[j % len(shades)]
ax.barh(modality, count, left=current_bottom[i], color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
current_bottom[i] += count + gap
current_bottom[i] -= gap
total_count = total_counts_by_modality[modality]
formatted_total, exponent = format_total_count(total_count)
ax.text(current_bottom[i] + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')
else:
for i, modality in enumerate(sorted_modalities):
total_count = total_counts_by_modality[modality]
color = np.array(modality_color_map[modality] * shades[0])
if modality.islower():
modality = modality.capitalize()
ax.barh(modality, total_count, color=color, height=0.8, log=True, edgecolor='white', linewidth=0.5)
formatted_total, exponent = format_total_count(total_count)
ax.text(total_count + (10**exponent)*0.05, i, formatted_total, va='center', fontsize=20, ha='left')

configure_plot(ax, sorted_modalities)

plt.tight_layout()
plt.savefig("./data_dist_modality_bar_subbar.pdf" if separate_submodality else "./data_dist_modality_bar.pdf", bbox_inches="tight", pad_inches=0)
plt.show()

# Configure plot aesthetics
def configure_plot(ax, sorted_modalities):
ax.set_xscale('log')
ax.set_title("Number of images per modality", fontsize=28)
plt.yticks(rotation=0, fontsize=24, va='center')
ax.tick_params(axis='x', which='major', length=8)
ax.tick_params(axis='x', which='minor', length=5)
plt.xticks(fontsize=24)
sns.despine()

# Main script execution
plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality)

# %%
78 changes: 78 additions & 0 deletions figures/supplementary_figure_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# %%
import os
import json
import numpy as np
import seaborn as sns
from scipy.stats import boxcox
from pycirclize import Circos
import matplotlib.pyplot as plt

base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
with open(os.path.join(base_dir,'hierarchy.json'), 'r') as f:
hierarchy_data = json.load(f)

with open(os.path.join(base_dir,'target_counts.json'), 'r') as f:
target_counts = json.load(f)

with open(os.path.join(base_dir,'modality_counts.json'), 'r') as f:
modality_counts = json.load(f)

# color scheme
sectors = {k: len(v) for k,v in modality_counts.items()}
name2color = {
"MRI": "#005A9E",
"CT": "#FF7F00",
"pathology": "#984EA3",
"ultrasound": "#7BC8F6",
"X-Ray": "#999999",
"fundus": "#76B041",
"dermoscopy": "#FDBF6F",
"endoscope": "#C0392B",
"OCT": "#33A02C",
}

def generate_shades(base_color, n):
return sns.light_palette(base_color, n + 2)[1:-1]

color_schemes = {}
for sector in sectors:
child_colors = generate_shades(name2color[sector], len(modality_counts[sector]))
color_schemes[sector] = child_colors

parent_track_ratio = (72, 85)
middle_track_ratio = (85, 100)
bar_track_ratio = (45, 70)
parent_track_font_size = 7
middle_track_font_size = 5.5
bar_track_font_size = 7

circos = Circos(sectors, space=6)
for sector in circos.sectors:
track = sector.add_track(parent_track_ratio)
track.axis(fc=name2color[sector.name], lw=0)
track.text(sector.name.capitalize().replace('Mri', 'MRI').replace('Ct', 'CT').replace('Oct', 'OCT').replace('Dermoscopy', "DS"), color="white", size=parent_track_font_size)

track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
sect_start = 0
color_idx = 0
for k,v in modality_counts[sector.name].items():
sect_size = 1
track1.rect(sect_start, sect_start+sect_size, r_lim=(middle_track_ratio[0], middle_track_ratio[1]-1) , ec="black", lw=0,fc=color_schemes[sector.name][color_idx])
color_idx += 1
track1.text(k.capitalize(), sect_start+sect_size/2, color="black", size=middle_track_font_size)
sect_start += sect_size

x = np.linspace(sector.start+0.5, sector.end-0.5, int(sector.size))
y = [v for k,v in modality_counts[sector.name].items()]
y_box = boxcox(y, 0.35)

track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
track2.axis()
track2.yticks([1.14, 2.29, 3.43, 4.58], ["10$^2$", "10$^3$", "10$^4$", "10$^5$"], label_size=bar_track_font_size-1)
track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)

fig = circos.plotfig()
fig.savefig('data_target_modality.pdf')
plt.show()

# %%
Loading

0 comments on commit a114b4a

Please sign in to comment.