Skip to content

Commit

Permalink
added surrogate model
Browse files Browse the repository at this point in the history
  • Loading branch information
akhilsnair2017 committed Mar 27, 2024
1 parent eb10957 commit 4c6dbaf
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions src/mobo_qm9.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import NamedTuple, Literal, List
import numpy as np
from loguru import logger

from .data.cm_featurizer import get_coulomb_matrix

import torch
from botorch.models import SingleTaskGP
from gpytorch.kernels import RBFKernel, MaternKernel, TanimotoKernel
from gpytorch.likelihoods import GaussianLikelihood
from botorch.fit import fit_gpytorch_model
from botorch.utils.transforms import Standardize

N_TOTAL_POINTS = 138_728

Expand Down Expand Up @@ -68,7 +73,7 @@ def get_features_and_targets(self):
else:
raise NotImplementedError

def get_surrogate_model(self, X, y):
def get_surrogate_model(self, X, y, kernel_type='RBF'):
"""
Gets the surrogate model for the MOBOQM9 model.
Expand All @@ -79,6 +84,26 @@ def get_surrogate_model(self, X, y):
returns:
model: Surrogate model for the MOBOQM9 model.
"""
X_scaled = Standardize(X)
train_X = torch.tensor(X_scaled, dtype=torch.float32)
train_Y = torch.tensor(y, dtype=torch.float32)
likelihood = GaussianLikelihood()

if kernel_type == 'RBF':
kernel = RBFKernel()
elif kernel_type == 'Matern':
kernel = MaternKernel()
elif kernel_type == 'Tanimoto':
kernel = TanimotoKernel()
else:
raise ValueError("Unsupported kernel type. Supported types are 'RBF', 'Matern', and 'Tanimoto'.")

model = SingleTaskGP(train_X, train_Y, likelihood=likelihood, kernel=kernel)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_model(mll)

return model

y_copy = y.copy()
for idx, mask in enumerate(self.params.target_bools):
if not mask:
Expand Down Expand Up @@ -157,4 +182,4 @@ def validate_params(self):
assert len(self.params.targets) == len(self.params.target_bools), "Number of targets must equal number of target booleans."
assert self.params.num_total_points > 0, "Number of total points must be greater than zero."
assert self.params.num_seed_points > 0, "Number of seed points must be greater than zero."
assert self.params.n_iters > 0, "Number of iterations must be greater than zero."
assert self.params.n_iters > 0, "Number of iterations must be greater than zero."

0 comments on commit 4c6dbaf

Please sign in to comment.