diff --git a/src/mobo_qm9.py b/src/mobo_qm9.py index 6b47b43..d8f468e 100644 --- a/src/mobo_qm9.py +++ b/src/mobo_qm9.py @@ -1,9 +1,14 @@ from typing import NamedTuple, Literal, List import numpy as np from loguru import logger - from .data.cm_featurizer import get_coulomb_matrix +import torch +from botorch.models import SingleTaskGP +from gpytorch.kernels import RBFKernel, MaternKernel, TanimotoKernel +from gpytorch.likelihoods import GaussianLikelihood +from botorch.fit import fit_gpytorch_model +from botorch.utils.transforms import Standardize N_TOTAL_POINTS = 138_728 @@ -68,7 +73,7 @@ def get_features_and_targets(self): else: raise NotImplementedError - def get_surrogate_model(self, X, y): + def get_surrogate_model(self, X, y, kernel_type='RBF'): """ Gets the surrogate model for the MOBOQM9 model. @@ -79,6 +84,26 @@ def get_surrogate_model(self, X, y): returns: model: Surrogate model for the MOBOQM9 model. """ + X_scaled = Standardize(X) + train_X = torch.tensor(X_scaled, dtype=torch.float32) + train_Y = torch.tensor(y, dtype=torch.float32) + likelihood = GaussianLikelihood() + + if kernel_type == 'RBF': + kernel = RBFKernel() + elif kernel_type == 'Matern': + kernel = MaternKernel() + elif kernel_type == 'Tanimoto': + kernel = TanimotoKernel() + else: + raise ValueError("Unsupported kernel type. Supported types are 'RBF', 'Matern', and 'Tanimoto'.") + + model = SingleTaskGP(train_X, train_Y, likelihood=likelihood, kernel=kernel) + mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model) + fit_gpytorch_model(mll) + + return model + y_copy = y.copy() for idx, mask in enumerate(self.params.target_bools): if not mask: @@ -157,4 +182,4 @@ def validate_params(self): assert len(self.params.targets) == len(self.params.target_bools), "Number of targets must equal number of target booleans." assert self.params.num_total_points > 0, "Number of total points must be greater than zero." assert self.params.num_seed_points > 0, "Number of seed points must be greater than zero." - assert self.params.n_iters > 0, "Number of iterations must be greater than zero." \ No newline at end of file + assert self.params.n_iters > 0, "Number of iterations must be greater than zero."