-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
400 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#%% | ||
import os | ||
import json | ||
import numpy as np | ||
import seaborn as sns | ||
from scipy.stats import boxcox | ||
from pycirclize import Circos | ||
import matplotlib.pyplot as plt | ||
|
||
# Directory holding the pre-computed JSON statistics that drive this figure.
base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'

with open(os.path.join(base_dir, 'hierarchy.json'), 'r') as hierarchy_file:
    hierarchy_data = json.load(hierarchy_file)

with open(os.path.join(base_dir, 'target_counts.json'), 'r') as target_file:
    target_counts = json.load(target_file)

with open(os.path.join(base_dir, 'modality_counts.json'), 'r') as modality_file:
    modality_counts = json.load(modality_file)

# Sector sizes: the number of leaf entries found under every second-level
# group of a sector, plus one extra slot per sector.
sectors = {}
for sector_name, sector_node in hierarchy_data.items():
    leaf_total = sum(len(group['child']) for group in sector_node['child'].values())
    sectors[sector_name] = leaf_total + 1

# Base colour of each top-level sector.
name2color = {
    "organ": "#E41A1C",
    "abnormality": "#377EB8",
    "histology": "#4DAF4A",
}
|
||
def generate_shades(base_color, n):
    """Return shades of ``base_color`` taken from a seaborn light palette.

    A palette of ``n + 2`` colours is requested and its first and last
    entries are dropped, so only the interior colours are returned.

    Args:
        base_color: Colour passed to ``sns.light_palette``.
        n: Number of shades wanted.

    Returns:
        The interior slice of the palette.
    """
    palette = sns.light_palette(base_color, n + 2)
    return palette[1:-1]
|
||
# One list of shades per sector, with one shade per second-level group.
color_schemes = {
    name: generate_shades(name2color[name], len(hierarchy_data[name]['child']))
    for name in sectors
}

# Radial extents (as passed to ``add_track``) and font sizes of the tracks.
parent_track_ratio = (72, 85)
middle_track_ratio = (85, 100)
bar_track_ratio = (45, 70)
parent_track_font_size = 7
middle_track_font_size = 5.5
bar_track_font_size = 7
outer_track_font_size = 9

# Substitutions applied, in this order, to the displayed labels.
parent_label_fixes = (('Mri', 'MRI'), ('Ct', 'CT'), ('Oct', 'OCT'), ('Dermoscopy', "DS"))
group_label_fixes = (
    ('abnormality', 'abn.'),
    (' anatomies', ''),
    (' disturbance', ''),
    ('other abn.', 'Other'),
    ('liver', ''),
    ('pancreas', ''),
)

circos = Circos(sectors, space=8.8)
for sector in circos.sectors:
    groups = hierarchy_data[sector.name.lower()]['child']
    last_group = len(groups) - 1

    # Tick position -> leaf label. Position 0 and the position after the last
    # leaf map to an empty label.
    idx2label = {0: ''}
    for group in groups.values():
        for leaf_name in group['child']:
            idx2label[len(idx2label)] = leaf_name
    idx2label[len(idx2label)] = ''

    # Outer track: carries only the leaf labels, one per unit tick.
    track_outer = sector.add_track((100, 101))
    track_outer.xticks_by_interval(
        1,
        tick_length=0,
        outer=True,
        show_bottom_line=False,
        label_orientation="vertical",
        label_formatter=lambda tick: idx2label[int(tick)],
        label_size=outer_track_font_size,
        show_endlabel=True,
    )

    # Parent track: solid band in the sector colour with the sector name.
    parent_label = sector.name.capitalize()
    for old, new in parent_label_fixes:
        parent_label = parent_label.replace(old, new)
    track = sector.add_track(parent_track_ratio)
    track.axis(fc=name2color[sector.name], lw=0)
    track.text(parent_label, color="white", size=parent_track_font_size)

    # Middle track: one shaded rectangle per second-level group.
    track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
    rect_r_lim = (middle_track_ratio[0], middle_track_ratio[1] - 1)
    sect_start = 0
    for i, (group_name, group) in enumerate(groups.items()):
        # The first and the last group are each widened by half a unit.
        sect_size = len(group['child'])
        if i == 0:
            sect_size += 0.5
        if i == last_group:
            sect_size += 0.5
        track1.rect(
            sect_start,
            sect_start + sect_size,
            r_lim=rect_r_lim,
            ec="black",
            lw=0,
            fc=color_schemes[sector.name][i],
        )
        group_label = group_name
        for old, new in group_label_fixes:
            group_label = group_label.replace(old, new)
        track1.text(
            group_label.capitalize(),
            sect_start + sect_size / 2,
            color="black",
            size=middle_track_font_size,
        )
        sect_start += sect_size

    # Bar track: Box-Cox (lambda = 0.35) transformed count of every leaf.
    x = np.linspace(sector.start + 1, sector.end - 1, int(sector.size) - 1)
    y = [target_counts[idx2label[pos]] for pos in range(1, len(x) + 1)]
    y_box = boxcox(y, 0.35)

    track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
    track2.axis()
    # NOTE(review): the tick positions are hard-coded rather than derived from
    # the Box-Cox transform above - confirm they line up with the bar scale.
    track2.yticks(
        [1.14, 2.29, 3.43, 4.58],
        ["10$^2$", "10$^3$", "10$^4$", "10$^5$"],
        label_size=bar_track_font_size - 1,
    )
    track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)

fig = circos.plotfig()
fig.savefig('figure_1a.pdf')
plt.show()
|
||
# %% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#%% | ||
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
|
||
# Default spine visibility: keep the bottom and left spines, hide the others.
plt.rc('axes.spines', bottom=True, left=True, right=False, top=False)
|
||
# Load data | ||
def load_data(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialised JSON content.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
# Directory holding the pre-computed JSON statistics.
base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'
# Nested mapping as consumed by transform_data: modality -> sub-category -> count.
# `os.path.join` requires `os` to be imported at the top of this script.
data = load_data(os.path.join(base_dir, 'modality_counts.json'))
# When True, plot_data draws one stacked segment per sub-category instead of
# a single bar per modality.
separate_submodality = False
|
||
# Transform data for plotting | ||
def transform_data(data):
    """Flatten nested count data into a long-format DataFrame.

    Args:
        data: Mapping of modality -> mapping of sub-category -> count.

    Returns:
        DataFrame with one row per (modality, sub-category) pair and the
        columns ``Modality``, ``Sub-category`` and ``Count``.
    """
    rows = []
    for modality, subcats in data.items():
        for subcat, count in subcats.items():
            rows.append((modality, subcat, count))
    return pd.DataFrame(rows, columns=['Modality', 'Sub-category', 'Count'])
|
||
df = transform_data(data) | ||
|
||
# Calculate total counts by modality and sort | ||
def calculate_totals(df):
    """Sum the counts of every modality and order them ascending.

    Args:
        df: DataFrame with at least the columns ``Modality`` and ``Count``.

    Returns:
        Tuple ``(total_counts_by_modality, sorted_modalities)``: a Series of
        per-modality totals sorted in ascending order, and the modality names
        in that same order as a list.
    """
    per_modality = df.groupby("Modality")["Count"].sum()
    total_counts_by_modality = per_modality.sort_values(ascending=True)
    return total_counts_by_modality, total_counts_by_modality.index.tolist()
|
||
total_counts_by_modality, sorted_modalities = calculate_totals(df) | ||
|
||
# Generate color map | ||
def generate_color_map(total_counts_by_modality):
    """Assign every modality a colour sampled from matplotlib's ``cool`` map.

    The colormap is sampled at evenly spaced positions in [0, 1], one per
    entry, following the order of the Series index.

    Args:
        total_counts_by_modality: Series whose index holds the modality names.

    Returns:
        Dict mapping each modality name to its sampled colour.
    """
    modalities = total_counts_by_modality.index
    sample_points = np.linspace(0, 1, len(total_counts_by_modality))
    return dict(zip(modalities, plt.cm.cool(sample_points)))
|
||
modality_color_map = generate_color_map(total_counts_by_modality) | ||
|
||
# Format total count for display | ||
def format_total_count(total_count):
    """Format a count for display next to its bar.

    Counts of 1000 or more are written in scientific notation as
    ``'<mantissa> x 10$^{<exponent>}$'``, with a two-decimal mantissa and the
    exponent as a mathtext superscript. Smaller counts are rendered with
    ``str``.

    Args:
        total_count: Count to format.

    Returns:
        Tuple ``(formatted_total, exponent)``. ``exponent`` is the power of
        ten shown in the label, or 0 when the plain representation is used.
    """
    if total_count < 1000:
        return str(total_count), 0

    exponent = int(np.floor(np.log10(total_count)))
    mantissa = total_count / 10**exponent
    # Rounding to two decimals can push the mantissa up to 10.00 (e.g. 99960);
    # renormalise so the label reads 1.00 x 10^5 instead of 10.00 x 10^4.
    if float(f'{mantissa:.2f}') >= 10:
        mantissa /= 10
        exponent += 1
    # Brace the exponent: an unbraced mathtext ``^`` raises only the next
    # character, which would break for exponents of two or more digits.
    return f'{mantissa:.2f} x 10$^{{{exponent}}}$', exponent
|
||
# Plotting function | ||
def plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality):
    """Draw the horizontal log-scale bar chart, save it as a PDF and show it.

    Args:
        df: Long-format DataFrame with the columns ``Modality``,
            ``Sub-category`` and ``Count``.
        total_counts_by_modality: Per-modality totals, indexable by modality.
        sorted_modalities: Modality names in the order the bars are drawn.
        modality_color_map: Mapping of modality name -> base colour.
        separate_submodality: If true, draw one stacked segment per
            sub-category; otherwise draw a single bar per modality.
    """
    fig, ax = plt.subplots(figsize=(10, 12))
    bar_style = dict(height=0.8, log=True, edgecolor='white', linewidth=0.5)
    text_style = dict(va='center', fontsize=20, ha='left')

    # Multipliers applied to the base colour, one per distinct sub-category.
    n_subcategories = df.groupby("Sub-category").ngroups
    shades = np.linspace(0.75, 1, n_subcategories) ** 2

    if separate_submodality:
        gap = 0.005
        current_bottom = np.zeros(len(sorted_modalities))
        for i, modality in enumerate(sorted_modalities):
            by_size = df[df["Modality"] == modality].sort_values(by='Count', ascending=False)
            for j, count in enumerate(by_size['Count']):
                if count <= 0:
                    continue
                color = np.array(modality_color_map[modality]) * shades[j % len(shades)]
                ax.barh(modality, count, left=current_bottom[i], color=color, **bar_style)
                current_bottom[i] += count + gap
            # Remove one gap again so the label starts at the end of the bar.
            current_bottom[i] -= gap
            total_count = total_counts_by_modality[modality]
            formatted_total, exponent = format_total_count(total_count)
            ax.text(current_bottom[i] + (10**exponent)*0.05, i, formatted_total, **text_style)
    else:
        for i, modality in enumerate(sorted_modalities):
            total_count = total_counts_by_modality[modality]
            color = np.array(modality_color_map[modality] * shades[0])
            # All-lowercase names are capitalised for display only.
            label = modality.capitalize() if modality.islower() else modality
            ax.barh(label, total_count, color=color, **bar_style)
            formatted_total, exponent = format_total_count(total_count)
            ax.text(total_count + (10**exponent)*0.05, i, formatted_total, **text_style)

    configure_plot(ax, sorted_modalities)

    plt.tight_layout()
    if separate_submodality:
        output_path = "./data_dist_modality_bar_subbar.pdf"
    else:
        output_path = "./data_dist_modality_bar.pdf"
    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.show()
|
||
# Configure plot aesthetics | ||
def configure_plot(ax, sorted_modalities):
    """Apply the shared axis styling to the bar chart.

    Args:
        ax: Axes to style.
        sorted_modalities: Not used by this function.
    """
    ax.set_xscale('log')
    ax.set_title("Number of images per modality", fontsize=28)
    plt.yticks(rotation=0, fontsize=24, va='center')
    # Major x ticks are drawn longer than minor ones.
    for which, length in (('major', 8), ('minor', 5)):
        ax.tick_params(axis='x', which=which, length=length)
    plt.xticks(fontsize=24)
    sns.despine()
|
||
# Main script execution | ||
plot_data(df, total_counts_by_modality, sorted_modalities, modality_color_map, separate_submodality) | ||
|
||
# %% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# %% | ||
import os | ||
import json | ||
import numpy as np | ||
import seaborn as sns | ||
from scipy.stats import boxcox | ||
from pycirclize import Circos | ||
import matplotlib.pyplot as plt | ||
|
||
# Directory holding the pre-computed JSON statistics that drive this figure.
base_dir = '/mnt/hanoverdev/data/BiomedSeg/figures_data'

with open(os.path.join(base_dir, 'hierarchy.json'), 'r') as hierarchy_file:
    hierarchy_data = json.load(hierarchy_file)

with open(os.path.join(base_dir, 'target_counts.json'), 'r') as target_file:
    target_counts = json.load(target_file)

with open(os.path.join(base_dir, 'modality_counts.json'), 'r') as modality_file:
    modality_counts = json.load(modality_file)

# One sector per modality, sized by its number of sub-category entries.
sectors = {}
for modality, subcounts in modality_counts.items():
    sectors[modality] = len(subcounts)

# Base colour of each modality sector.
name2color = {
    "MRI": "#005A9E",
    "CT": "#FF7F00",
    "pathology": "#984EA3",
    "ultrasound": "#7BC8F6",
    "X-Ray": "#999999",
    "fundus": "#76B041",
    "dermoscopy": "#FDBF6F",
    "endoscope": "#C0392B",
    "OCT": "#33A02C",
}
|
||
def generate_shades(base_color, n):
    """Return shades of ``base_color`` taken from a seaborn light palette.

    A palette of ``n + 2`` colours is requested and its first and last
    entries are dropped, so only the interior colours are returned.

    Args:
        base_color: Colour passed to ``sns.light_palette``.
        n: Number of shades wanted.

    Returns:
        The interior slice of the palette.
    """
    palette = sns.light_palette(base_color, n + 2)
    return palette[1:-1]
|
||
# One list of shades per modality, with one shade per sub-category entry.
color_schemes = {
    name: generate_shades(name2color[name], len(modality_counts[name]))
    for name in sectors
}

# Radial extents (as passed to ``add_track``) and font sizes of the tracks.
parent_track_ratio = (72, 85)
middle_track_ratio = (85, 100)
bar_track_ratio = (45, 70)
parent_track_font_size = 7
middle_track_font_size = 5.5
bar_track_font_size = 7

# Substitutions applied, in this order, to the displayed sector names.
parent_label_fixes = (('Mri', 'MRI'), ('Ct', 'CT'), ('Oct', 'OCT'), ('Dermoscopy', "DS"))

circos = Circos(sectors, space=6)
for sector in circos.sectors:
    submodalities = modality_counts[sector.name]

    # Parent track: solid band in the sector colour with the sector name.
    parent_label = sector.name.capitalize()
    for old, new in parent_label_fixes:
        parent_label = parent_label.replace(old, new)
    track = sector.add_track(parent_track_ratio)
    track.axis(fc=name2color[sector.name], lw=0)
    track.text(parent_label, color="white", size=parent_track_font_size)

    # Middle track: one unit-wide shaded rectangle per sub-category.
    track1 = sector.add_track(middle_track_ratio, r_pad_ratio=0.1)
    rect_r_lim = (middle_track_ratio[0], middle_track_ratio[1] - 1)
    for slot, sub_name in enumerate(submodalities):
        track1.rect(
            slot,
            slot + 1,
            r_lim=rect_r_lim,
            ec="black",
            lw=0,
            fc=color_schemes[sector.name][slot],
        )
        track1.text(sub_name.capitalize(), slot + 0.5, color="black", size=middle_track_font_size)

    # Bar track: Box-Cox (lambda = 0.35) transformed count of each entry.
    x = np.linspace(sector.start + 0.5, sector.end - 0.5, int(sector.size))
    y = list(submodalities.values())
    y_box = boxcox(y, 0.35)

    track2 = sector.add_track(bar_track_ratio, r_pad_ratio=0.1)
    track2.axis()
    # NOTE(review): the tick positions are hard-coded rather than derived from
    # the Box-Cox transform above - confirm they line up with the bar scale.
    track2.yticks(
        [1.14, 2.29, 3.43, 4.58],
        ["10$^2$", "10$^3$", "10$^4$", "10$^5$"],
        label_size=bar_track_font_size - 1,
    )
    track2.bar(x, y_box, color=name2color[sector.name], alpha=0.5, align="center", lw=0)

fig = circos.plotfig()
fig.savefig('data_target_modality.pdf')
plt.show()
|
||
# %% |
Oops, something went wrong.