Skip to content

Commit

Permalink
Update plotting code
Browse files Browse the repository at this point in the history
  • Loading branch information
jungtaekkim committed Mar 28, 2024
1 parent d956ec1 commit e8df4d0
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 1 deletion.
112 changes: 112 additions & 0 deletions src/plot_bayesian_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import numpy as np
import os
import matplotlib.pyplot as plt


path_results = '../results'
path_figures = '../figures'

num_layerss = [2, 4, 6, 8]
seeds = np.arange(42, 421, 42)
num_init = 10
num_iter = 500


def plot(means, sems, num_layers, str_objective, show_figure, save_figure):
cm = plt.get_cmap('tab10')

if str_objective == 'trans':
str_label = r'\textrm{Transmittance}'
str_ylabel = r'\textrm{Maximum transmittance}'
color = cm(0)
elif str_objective == 'effec':
str_label = r'\textrm{Shielding effectiveness}'
str_ylabel = r'\textrm{Maximum shielding effectiveness (dB)}'
color = cm(1)
else:
raise ValueError

plt.rc('text', usetex=True)

fig = plt.figure(figsize=(8, 6))
ax = fig.gca()

bx = np.arange(0, means.shape[0])

ax.plot(bx, means, label=str_label, linewidth=4, linestyle='solid', color=color)
ax.fill_between(bx, means - sems, means + sems, alpha=0.3, color=color)

ax.set_xlabel(r'\textrm{Iteration}', fontsize=24)
ax.set_ylabel(str_ylabel, fontsize=24)

ax.set_xlim([0, np.max(bx)])
plt.tick_params(axis='both', which='major', labelsize=20)

plt.grid()
plt.tight_layout()

if save_figure:
plt.savefig(os.path.join(path_figures, f'maxima_layers_{num_layers}_{str_objective}.pdf'),
format='pdf', transparent=True, bbox_inches='tight')
plt.savefig(os.path.join(path_figures, f'maxima_layers_{num_layers}_{str_objective}.png'),
format='png', transparent=True, bbox_inches='tight')

if show_figure:
plt.show()

plt.close('all')

def get_maxima(values, num_init):
new_values = []

for value in values:
new_value = [np.max(value[:num_init])]

for val in value[num_init:]:
if new_value[-1] < val:
new_value.append(val)
else:
new_value.append(new_value[-1])
new_values.append(new_value)
return np.array(new_values)

def get_means_sems(values):
means = np.mean(values, axis=0)
sems = np.std(values, axis=0, ddof=1) / np.sqrt(values.shape[0])
return means, sems


if __name__ == '__main__':
show_figure = True
save_figure = True

for num_layers in num_layerss:
transparencies_all = []
shielding_effectivenesses_all = []

for seed in seeds:
str_file = f'mobo_layers_{num_layers}_init_{num_init}_iter_{num_iter}_seed_{seed:04d}.npy'

results = np.load(os.path.join(path_results, str_file), allow_pickle=True)
results = results[()]

negative_transparencies = results['negative_transparencies']
negative_shielding_effectivenesses = results['negative_shielding_effectivenesses']

transparencies = -1.0 * negative_transparencies
shielding_effectivenesses = -1.0 * negative_shielding_effectivenesses

print(transparencies.shape)
print(shielding_effectivenesses.shape)

transparencies_all.append(transparencies)
shielding_effectivenesses_all.append(shielding_effectivenesses)

trans = get_maxima(transparencies_all, num_init)
effec = get_maxima(shielding_effectivenesses_all, num_init)

means_trans, sems_trans = get_means_sems(trans)
means_effec, sems_effec = get_means_sems(effec)

plot(means_trans, sems_trans, num_layers, 'trans', show_figure, save_figure)
plot(means_effec, sems_effec, num_layers, 'effec', show_figure, save_figure)
86 changes: 86 additions & 0 deletions src/plot_pareto_frontiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import numpy as np
import os
import matplotlib.pyplot as plt


path_results = '../results'
path_figures = '../figures'

num_layerss = [2, 4, 6, 8]
seeds = np.arange(42, 421, 42)
num_init = 10
num_iter = 500


def is_pareto_frontier(objs):
assert isinstance(objs, np.ndarray)
assert len(objs.shape) == 2
assert objs.shape[1] == 2

is_pareto = np.ones(objs.shape[0], dtype=bool)

for i, c in enumerate(objs):
if is_pareto[i]:
is_pareto[is_pareto] = np.any(objs[is_pareto] > c, axis=1)
is_pareto[i] = True
return is_pareto

def plot(Y, num_layers):
plt.rc('text', usetex=True)

fig = plt.figure(figsize=(8, 6))
ax = fig.gca()
cm = plt.get_cmap('tab10')

pareto_frontier = Y[is_pareto_frontier(Y)]
indices = np.argsort(pareto_frontier[:, 0])
pareto_frontier = pareto_frontier[indices]

ax.plot(Y[:, 0], Y[:, 1], linestyle='none', color=cm(2), marker='.', markersize=14)
ax.plot(pareto_frontier[:, 0], pareto_frontier[:, 1], linestyle='solid', linewidth=4, color=cm(3))

ax.set_xlabel(r'\textrm{Transmittance}', fontsize=24)
ax.set_ylabel(r'\textrm{Shielding Effectiveness (dB)}', fontsize=24)

plt.tick_params(axis='both', which='major', labelsize=20)

plt.grid()
plt.tight_layout()

if save_figure:
plt.savefig(os.path.join(path_figures, f'pareto_layers_{num_layers}.pdf'), format='pdf', transparent=True, bbox_inches='tight')
plt.savefig(os.path.join(path_figures, f'pareto_layers_{num_layers}.png'), format='png', transparent=True, bbox_inches='tight')

if show_figure:
plt.show()

plt.close('all')


if __name__ == '__main__':
show_figure = True
save_figure = False

for num_layers in num_layerss:
transparencies_all = []
shielding_effectivenesses_all = []

for seed in seeds:
str_file = f'mobo_layers_{num_layers}_init_{num_init}_iter_{num_iter}_seed_{seed:04d}.npy'

results = np.load(os.path.join(path_results, str_file), allow_pickle=True)
results = results[()]

negative_transparencies = results['negative_transparencies']
negative_shielding_effectivenesses = results['negative_shielding_effectivenesses']

transparencies = -1.0 * negative_transparencies
shielding_effectivenesses = -1.0 * negative_shielding_effectivenesses

transparencies_all += list(transparencies)
shielding_effectivenesses_all += list(shielding_effectivenesses)

Y = np.array([transparencies_all, shielding_effectivenesses_all]).T
print(Y.shape)

plot(Y, num_layers)
2 changes: 1 addition & 1 deletion src/plot_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_color(material):
elif material == 'W':
str_color = 'orange'
elif material == 'TiO2':
str_color = 'lawngreen'
str_color = 'blue'
elif material == 'TiN':
str_color = 'darkgreen'
elif material == 'Al2O3':
Expand Down

0 comments on commit e8df4d0

Please sign in to comment.