Skip to content

Commit

Permalink
🎨 switch colors and show model tag for color
Browse files Browse the repository at this point in the history
- based on seaborn example of _ColorPalette
  • Loading branch information
Henry committed Nov 27, 2023
1 parent e206483 commit e62f80b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion project/01_2_performance_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@
"outputs": [],
"source": [
"COLORS_TO_USE = vaep.plotting.defaults.assign_colors(list(k.upper() for k in ORDER_MODELS))\n",
"sns.color_palette(COLORS_TO_USE)"
"vaep.plotting.defaults.ModelColorVisualizer(ORDER_MODELS, COLORS_TO_USE)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion project/01_2_performance_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def build_text(s):

# %%
COLORS_TO_USE = vaep.plotting.defaults.assign_colors(list(k.upper() for k in ORDER_MODELS))
sns.color_palette(COLORS_TO_USE)
vaep.plotting.defaults.ModelColorVisualizer(ORDER_MODELS, COLORS_TO_USE)

# %%
TOP_N_ORDER = ORDER_MODELS[:args.plot_to_n]
Expand Down
30 changes: 29 additions & 1 deletion vaep/plotting/defaults.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import matplotlib as mpl
import seaborn as sns

logger = logging.getLogger(__name__)
Expand All @@ -22,10 +23,11 @@
# other_colors = sns.color_palette()[8:]
other_colors = sns.color_palette("husl", 20)
color_model_mapping['IMPSEQ'] = other_colors[0]
color_model_mapping['QRILC'] = other_colors[1]
color_model_mapping['IMPSEQROB'] = other_colors[1]
color_model_mapping['MICE-NORM'] = other_colors[2]
color_model_mapping['SEQKNN'] = other_colors[3]
color_model_mapping['QRILC'] = other_colors[4]
color_model_mapping['IMPSEQROB'] = other_colors[4]
color_model_mapping['GSIMP'] = other_colors[5]
color_model_mapping['MSIMPUTE'] = other_colors[6]
color_model_mapping['MSIMPUTE_MNAR'] = other_colors[7]
Expand All @@ -49,6 +51,32 @@ def assign_colors(models):
return ret_colors


class ModelColorVisualizer:

def __init__(self, models, palette):
self.models = models
self.palette = map(mpl.colors.colorConverter.to_rgb, palette)

def as_hex(self):
"""Return a color palette with hex codes instead of RGB values."""
hex = [mpl.colors.rgb2hex(rgb) for rgb in self.palette]
return hex

def _repr_html_(self):
"""Rich display of the color palette in an HTML frontend."""
s = 55
n = len(self.models)
html = f'<svg width="{s*2}" height="{s*n/2}">'
for i, (m, c) in enumerate(zip(self.models, self.as_hex())):
html += (
f'<rect x="0" y="{i * s /2}" width="{s*2}" height="{s/2}" style="fill:{c};'
'stroke-width:2;stroke:rgb(255,255,255)" metadata="tt"/>'
)
html += f'<text x="{4}" y="{(i * s / 2) + 20}" font-size="12" fill="black">{m}</text>'
html += '</svg>'
return html


labels_dict = {"NA not interpolated valid_collab collab MSE": 'MSE',
'batch_size': 'bs',
'n_hidden_layers': "No. of hidden layers",
Expand Down

0 comments on commit e62f80b

Please sign in to comment.