Skip to content

Commit

Permalink
Merge branch 'main' of github.com:AC-BO-Hackathon/project-mobo-qm9
Browse files Browse the repository at this point in the history
  • Loading branch information
mamunm committed Mar 27, 2024
2 parents acb1ca2 + b439707 commit 555da9c
Showing 1 changed file with 36 additions and 4 deletions.
40 changes: 36 additions & 4 deletions src/mobo_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from botorch.utils.multi_objective.box_decompositions import DominatedPartitioning

from .data.cm_featurizer import get_coulomb_matrix
from .acquisition_functions import optimize_qEHVI, optimize_qNEHVI

import torch
from botorch.models import ModelListGP, FixedNoiseGP
from gpytorch.kernels import RBFKernel, MaternKernel, TanimotoKernel
from gpytorch.likelihoods import GaussianLikelihood
from botorch.fit import fit_gpytorch_model
from botorch.utils.transforms import Standardize

N_TOTAL_POINTS = 138_728

Expand Down Expand Up @@ -71,7 +76,7 @@ def get_features_and_targets(self):
else:
raise NotImplementedError

def get_surrogate_model(self, X, y):
def get_surrogate_model(self, X, y, kernel_type='RBF'):
"""
Gets the surrogate model for the MOBOQM9 model.
Expand All @@ -82,7 +87,34 @@ def get_surrogate_model(self, X, y):
returns:
model: Surrogate model for the MOBOQM9 model.
"""
y_copy = y.copy()
X_scaled = Standardize(X)
train_X = torch.tensor(X_scaled, dtype=torch.float32)
train_Y = torch.tensor(y, dtype=torch.float32)
likelihood = GaussianLikelihood()

input_transform = Standardize(m=train_X.shape[-2])
train_X_scaled = input_transform(train_X)


if kernel_type == 'RBF':
kernel = RBFKernel()
elif kernel_type == 'Matern':
kernel = MaternKernel()
elif kernel_type == 'Tanimoto':
kernel = TanimotoKernel()
else:
raise ValueError("Unsupported kernel type. Supported types are 'RBF', 'Matern', and 'Tanimoto'.")

models = [FixedNoiseGP(train_X_scaled, train_Y, noise=torch.zeros_like(train_Y), likelihood=likelihood, kernel=kernel)]

model = ModelListGP(*models)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_model(mll)

return model, input_transform

def correct_sign(self,Y)
y_copy = Y.copy()
for idx, mask in enumerate(self.params.target_bools):
if not mask:
y_copy[:, idx] *= -1
Expand Down Expand Up @@ -204,4 +236,4 @@ def validate_params(self):
assert len(self.params.targets) == len(self.params.target_bools), "Number of targets must equal number of target booleans."
assert self.params.num_total_points > 0, "Number of total points must be greater than zero."
assert self.params.num_seed_points > 0, "Number of seed points must be greater than zero."
assert self.params.n_iters > 0, "Number of iterations must be greater than zero."
assert self.params.n_iters > 0, "Number of iterations must be greater than zero."

0 comments on commit 555da9c

Please sign in to comment.