diff --git a/src/mobo_qm9.py b/src/mobo_qm9.py index 46a3ae1..9d97593 100644 --- a/src/mobo_qm9.py +++ b/src/mobo_qm9.py @@ -5,8 +5,13 @@ from botorch.utils.multi_objective.box_decompositions import DominatedPartitioning from .data.cm_featurizer import get_coulomb_matrix -from .acquisition_functions import optimize_qEHVI, optimize_qNEHVI +import torch +from botorch.models import ModelListGP, FixedNoiseGP +from gpytorch.kernels import RBFKernel, MaternKernel, TanimotoKernel +from gpytorch.likelihoods import GaussianLikelihood +from botorch.fit import fit_gpytorch_model +from botorch.utils.transforms import Standardize N_TOTAL_POINTS = 138_728 @@ -71,7 +76,7 @@ def get_features_and_targets(self): else: raise NotImplementedError - def get_surrogate_model(self, X, y): + def get_surrogate_model(self, X, y, kernel_type='RBF'): """ Gets the surrogate model for the MOBOQM9 model. @@ -82,7 +87,34 @@ def get_surrogate_model(self, X, y): returns: model: Surrogate model for the MOBOQM9 model. """ - y_copy = y.copy() + X_scaled = Standardize(X) + train_X = torch.tensor(X_scaled, dtype=torch.float32) + train_Y = torch.tensor(y, dtype=torch.float32) + likelihood = GaussianLikelihood() + + input_transform = Standardize(m=train_X.shape[-2]) + train_X_scaled = input_transform(train_X) + + + if kernel_type == 'RBF': + kernel = RBFKernel() + elif kernel_type == 'Matern': + kernel = MaternKernel() + elif kernel_type == 'Tanimoto': + kernel = TanimotoKernel() + else: + raise ValueError("Unsupported kernel type. Supported types are 'RBF', 'Matern', and 'Tanimoto'.") + + models = [FixedNoiseGP(train_X_scaled, train_Y, noise=torch.zeros_like(train_Y), likelihood=likelihood, kernel=kernel)] + + model = ModelListGP(*models) + mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model) + fit_gpytorch_model(mll) + + return model, input_transform + + def correct_sign(self,Y) + y_copy = Y.copy() for idx, mask in enumerate(self.params.target_bools): if not mask: y_copy[:, idx] *= -1 @@ -204,4 +236,4 @@ def validate_params(self): assert len(self.params.targets) == len(self.params.target_bools), "Number of targets must equal number of target booleans." assert self.params.num_total_points > 0, "Number of total points must be greater than zero." assert self.params.num_seed_points > 0, "Number of seed points must be greater than zero." - assert self.params.n_iters > 0, "Number of iterations must be greater than zero." \ No newline at end of file + assert self.params.n_iters > 0, "Number of iterations must be greater than zero."