Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mamunm committed Mar 27, 2024
1 parent 376f80b commit a04d2ea
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 30 deletions.
1 change: 0 additions & 1 deletion experiments/test_1/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
params = MOBOQM9Parameters(featurizer="CM",
kernel="RBF",
surrogate_model="GaussianProcess",
acq_func="qEHVI",
targets=["gap", "mu"],
target_bools=[True, True],
num_total_points=100,
Expand Down
80 changes: 51 additions & 29 deletions src/mobo_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
import pandas as pd

from .data.cm_featurizer import get_coulomb_matrix
from .acquisition_functions import optimize_qEHVI, optimize_qNEHVI
Expand All @@ -23,7 +24,6 @@ class MOBOQM9Parameters(NamedTuple):
featurizer: Featurizer to use.
kernel: Kernel to use.
surrogate_model: Surrogate model to use.
acq_func: Acquisition function to use.
targets: List of targets to optimize.
target_bools: List of booleans indicating whether to minimize or
maximize each target.
Expand All @@ -35,7 +35,6 @@ class MOBOQM9Parameters(NamedTuple):
featurizer: Literal["ECFP", "CM", "ACSF"]
kernel: Literal["RBF", "Matern"]
surrogate_model: Literal["GaussianProcess", "RandomForest"]
acq_func: Literal["qEHVI", "qNEHVI", "random"]
targets: List[str]
target_bools: List[bool]
num_candidates: int = 1
Expand All @@ -60,8 +59,24 @@ def __init__(self, params: MOBOQM9Parameters):
self.params.num_total_points)
self.features, self.targets = self.get_features_and_targets()
self.train_indices = self.get_train_indices()
self.dataframe = None
self.dataframe = pd.DataFrame.from_dict(self.from_target_dict())
self.acq_met = {"qEHVI": False, "qNEHVI": False, "random": False}

def form_target_dict(self):
    """
    Assembles the target dictionary for the MOBOQM9 model.

    The dictionary holds an "iteration" entry, then one entry per name in
    self.params.targets mapped to the matching column of self.targets,
    then one "target_<name>" entry for each of qEHVI, qNEHVI and random.
    Every entry that is not a target column is initialised to None.

    returns:
        target_dict: Target dictionary for the MOBOQM9 model.
    """
    # Column i of self.targets belongs to the i-th configured target name.
    per_target_columns = {
        name: self.targets[:, col]
        for col, name in enumerate(self.params.targets)
    }
    # Key order matters for any consumer that relies on insertion order:
    # "iteration" first, target columns next, acquisition placeholders last.
    return {
        "iteration": None,
        **per_target_columns,
        "target_qEHVI": None,
        "target_qNEHVI": None,
        "target_random": None,
    }

def get_features_and_targets(self):
"""
Gets the features and targets for the MOBOQM9 model.
Expand All @@ -77,16 +92,19 @@ def get_features_and_targets(self):
else:
raise NotImplementedError

def get_surrogate_model(self):
def get_surrogate_model(self, acq):
"""
Gets the surrogate model for the MOBOQM9 model.
args:
acq: Acquisition function to use.
returns:
model: Surrogate model for the MOBOQM9 model.
"""
features = torch.tensor(self.features[self.train_indices],
features = torch.tensor(self.features[self.train_indices["acq"]],
dtype=torch.double)
targets = torch.tensor(self.correct_sign(self.targets[self.train_indices]),
targets = torch.tensor(self.correct_sign(self.targets[self.train_indices["acq"]]),
dtype=torch.double)
var = torch.full_like(targets, 1e-6)

Expand Down Expand Up @@ -120,28 +138,29 @@ def correct_sign(self, Y):
y_copy[:, idx] *= -1
return y_copy

def optimize_acquisition_function(self, model):
def optimize_acquisition_function(self, model, acq):
"""
Optimizes the acquisition function for the MOBOQM9 model.
args:
model: Surrogate model for the MOBOQM9 model.
acq: Acquisition function to use.
returns:
candidates: Candidates for the MOBOQM9 model.
"""
y_train = self.correct_sign(self.targets[self.train_indices])
y_train = self.correct_sign(self.targets[self.train_indices["acq"]])
y_train = torch.tensor(y_train, dtype=torch.double)
x_train = torch.tensor(self.features[self.train_indices], dtype=torch.double)
x_test = torch.tensor(self.features[~self.train_indices], dtype=torch.double)
x_train = torch.tensor(self.features[self.train_indices["acq"]], dtype=torch.double)
x_test = torch.tensor(self.features[~self.train_indices["acq"]], dtype=torch.double)
reference = y_train.mean(0)[0]
if self.params.acq_func == "qEHVI":
if acq == "qEHVI":
return optimize_qEHVI(model=model,
reference=reference,
y_train=y_train,
x_test=x_test,
n_candidates=self.params.num_candidates)
elif self.params.acq_func == "qNEHVI":
elif acq == "qNEHVI":
return optimize_qNEHVI(model=model,
reference=reference,
x_train=x_train,
Expand All @@ -156,16 +175,19 @@ def run_optimization(self):
"""
for iter in range(self.params.n_iters):
logger.info(f"MOBOQM9 iteration {iter + 1} of {self.params.n_iters}.")
model = self.get_surrogate_model()
if self.params.acq_func == "random":
for _ in range(self.params.num_candidates):
idx = np.random.choice(np.where(~self.train_indices)[0])
self.train_indices[idx] = True
else:
candidates = self.optimize_acquisition_function(model)
self.update_train_indices(candidates)
if self.stopping_criteria_met():
break
for acq in ["qEHVI", "qNEHVI", "random"]:
if self.acq_met[acq]:
continue
model = self.get_surrogate_model(acq)
if acq == "random":
for _ in range(self.params.num_candidates):
idx = np.random.choice(np.where(~self.train_indices)[0])
self.train_indices[acq][idx] = True
else:
candidates = self.optimize_acquisition_function(model)
self.update_train_indices(candidates, acq)
self.stopping_criteria_met(acq)

logger.info("MOBOQM9 optimization finished.")

def get_train_indices(self):
Expand All @@ -180,17 +202,17 @@ def get_train_indices(self):
self.params.num_seed_points)
mask = np.zeros(len(self.total_indices), dtype=bool)
mask[temp_indices] = True
return mask
return {"qEHVI": mask, "qNEHVI": mask, "random": mask}

def stopping_criteria_met(self):
def stopping_criteria_met(self, acq):
"""
Checks if the MOBOQM9 optimization has met the stopping criteria.
returns:
bool: True if the MOBOQM9 optimization has met the stopping criteria.
"""
y_global = torch.tensor(self.targets)
y_current = torch.tensor(self.targets[self.train_indices])
y_current = torch.tensor(self.targets[self.train_indices[acq]])
ref_points = y_global.min(0)[0]
bd_global = DominatedPartitioning(
ref_point=ref_points,
Expand All @@ -202,20 +224,21 @@ def stopping_criteria_met(self):
Y=y_current,
)
volume_current = bd_current.compute_hypervolume().item()
return volume_global == volume_current
self.acq_met[acq] = (volume_global == volume_current)


def update_train_indices(self, candidates):
def update_train_indices(self, candidates, acq):
"""
Updates the train indices for the MOBOQM9 model.
args:
candidates: Candidates for the MOBOQM9 model.
acq: Acquisition function to use.
"""
for cand in candidates:
for idx, feat in enumerate(self.features):
if np.allclose(feat, cand):
self.train_indices[idx] = True
self.train_indices[acq][idx] = True

def validate_params(self):
"""
Expand All @@ -227,7 +250,6 @@ def validate_params(self):
assert self.params.featurizer in ["ECFP", "CM", "ACSF"], "Featurizer must be one of ECFP, CM, or ACSF."
assert self.params.kernel in ["RBF", "Matern", "Tanimoto"], "Kernel must be one of RBF, Matern."
assert self.params.surrogate_model in ["GaussianProcess", "RandomForest"], "Surrogate model must be one of GaussianProcess, or RandomForest."
assert self.params.acq_func in ["qEHVI", "qNEHVI", "random"], "Acquisition function must be one of qEHVI, or qNEHVI, or random."
assert len(self.params.targets) == len(self.params.target_bools), "Number of targets must equal number of target booleans."
assert self.params.num_total_points > 0, "Number of total points must be greater than zero."
assert self.params.num_seed_points > 0, "Number of seed points must be greater than zero."
Expand Down

0 comments on commit a04d2ea

Please sign in to comment.