Merge pull request #70 from kkovary/68-graph-gp-serializer
custom SIGP model save & load
Ryan-Rhys authored Oct 3, 2024
2 parents 6e55384 + b6ee916 commit 9fc64b7
Showing 1 changed file with 198 additions and 7 deletions.
205 changes: 198 additions & 7 deletions gauche/gp.py
@@ -1,12 +1,56 @@
"""
Implements the SIGP model, a custom Gaussian Process for non-tensorial inputs.
Key components:
- load_class: Dynamically loads classes from modules.
- NonTensorialInputs: Container for non-tensor data (e.g., graphs).
- SIGP: Custom ExactGP allowing non-tensorial inputs.
Provides functionality for handling non-tensorial inputs, saving/loading models,
and working with graph kernels in Gaussian Process models.
Extends GPyTorch's ExactGP for structured, non-tensor inputs like graphs.
"""

import importlib
from copy import copy, deepcopy
from functools import lru_cache
from typing import Any, Optional

import torch
from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import _GaussianLikelihoodBase
from gpytorch.models import ExactGP
from gpytorch.models.exact_prediction_strategies import prediction_strategy
from loguru import logger


def load_class(module_name: str, class_name: str) -> Any:
"""
Dynamically load a class from a given module with error handling.
Args:
module_name (str): The name of the module.
class_name (str): The name of the class to load.
Returns:
class: The dynamically loaded class.
Raises:
ImportError: If the module cannot be loaded.
AttributeError: If the class cannot be found in the module.
"""
try:
module = importlib.import_module(module_name)
return getattr(module, class_name)
except ImportError as err:
raise ImportError(
f"Module {module_name} could not be loaded: {str(err)}"
) from err
except AttributeError as err:
raise AttributeError(
f"Class {class_name} not found in {module_name}: {str(err)}"
) from err
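
As a quick illustration, load_class resolves a class from its string name at runtime; this is the same mechanism load (further down) uses to rebuild the likelihood, kernel, and mean from a checkpoint. A small sketch using classes that ship with GPyTorch:

likelihood_cls = load_class("gpytorch.likelihoods", "GaussianLikelihood")
mean_cls = load_class("gpytorch.means", "ConstantMean")
likelihood, mean = likelihood_cls(), mean_cls()

# Unknown names fail loudly with a descriptive error.
try:
    load_class("gpytorch.means", "NotARealMean")
except AttributeError as err:
    print(err)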


class NonTensorialInputs:
@@ -52,9 +96,11 @@ def __init__(self, train_inputs, train_targets, likelihood):
super(ExactGP, self).__init__()
if train_inputs is not None:
self.train_inputs = tuple(
    (
        i.unsqueeze(-1)
        if torch.is_tensor(i) and i.ndimension() == 1
        else i
    )
    for i in train_inputs
)
self.train_targets = train_targets
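
The wrapped conditional above simply promotes 1-D training tensors to column vectors, while non-tensorial inputs (such as a NonTensorialInputs container of graphs) pass through untouched. A small illustrative sketch of that check:

import torch

x = torch.tensor([1.0, 2.0, 3.0])  # shape (3,)
x = x.unsqueeze(-1) if torch.is_tensor(x) and x.ndimension() == 1 else x
print(x.shape)  # torch.Size([3, 1])

graphs = ["g0", "g1"]  # stand-in for non-tensor data; torch.is_tensor short-circuits
graphs = graphs.unsqueeze(-1) if torch.is_tensor(graphs) and graphs.ndimension() == 1 else graphs
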
@@ -71,9 +117,11 @@ def __call__(self, *args, **kwargs):
)

inputs = [
    (
        i.unsqueeze(-1)
        if torch.is_tensor(i) and i.ndimension() == 1
        else i
    )
    for i in args
]

@@ -182,3 +230,146 @@ def __call__(self, *args, **kwargs):
*batch_shape, *test_shape
).contiguous()
return full_output.__class__(predictive_mean, predictive_covar)

@classmethod
def save(
cls,
model: "SIGP",
optimizer: Optional[torch.optim.Optimizer] = None,
filename: str = "model.pth",
) -> None:
"""
Saves the model state, optimizer state, training data, and other configurations to a file.
Args:
model (SIGP): The model instance to save.
optimizer (Optional[torch.optim.Optimizer]): The optimizer associated with the model. Default is None.
filename (str): The filename where the model state will be saved. Default is "model.pth".
Returns:
None
"""
logger.info(f"Saving model state to {filename}")
model_state = {
"version": "0.1.0",
"model_state_dict": model.state_dict(),
"optimizer_class": (
optimizer.__class__.__name__ if optimizer is not None else None
),
"optimizer_state_dict": (
optimizer.state_dict() if optimizer is not None else None
),
"train_inputs": model.train_inputs,
"train_targets": model.train_targets,
"likelihood_state_dict": model.likelihood.state_dict(),
"covariance_state_dict": model.covariance.state_dict(),
"mean_module_state_dict": model.mean.state_dict(),
"model_class": model.__class__.__name__,
"likelihood_class": model.likelihood.__class__.__name__,
"covar_module_class": model.covariance.__class__.__name__,
"mean_module_class": model.mean.__class__.__name__,
"covar_module_args": (
{"node_label": model.covariance.node_label}
if hasattr(model.covariance, "node_label")
else {}
),
}
torch.save(model_state, filename)
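
For illustration, a minimal end-to-end sketch of saving a fitted model. The GraphGP subclass, the WeisfeilerLehmanKernel choice with node_label="element", the learning rate, and the X_train/y_train data (a NonTensorialInputs of graphs and a target tensor) are assumptions for the example, not part of this file; the keyword names in GraphGP.__init__ mirror what load (below) passes when rebuilding the model:

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gauche.gp import SIGP
from gauche.kernels.graph_kernels import WeisfeilerLehmanKernel


class GraphGP(SIGP):
    # Keyword names match what SIGP.load passes when reconstructing the model.
    def __init__(self, train_x, train_y, likelihood, covar_module, mean_module):
        super().__init__(train_x, train_y, likelihood)
        self.mean = mean_module
        self.covariance = covar_module

    def forward(self, x):
        # x is assumed to expose its wrapped list of graphs as x.data.
        mean = self.mean(torch.zeros(len(x.data), 1)).float()
        covariance = self.covariance(x)
        return MultivariateNormal(mean, covariance)


# X_train: NonTensorialInputs of training graphs, y_train: target tensor (assumed).
likelihood = GaussianLikelihood()
model = GraphGP(
    train_x=X_train,
    train_y=y_train,
    likelihood=likelihood,
    covar_module=WeisfeilerLehmanKernel(node_label="element"),
    mean_module=ConstantMean(),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
# ... fit with the usual gpytorch.mlls.ExactMarginalLogLikelihood training loop ...
SIGP.save(model, optimizer=optimizer, filename="model.pth")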

@classmethod
def load(cls, filename: str = "model.pth") -> "SIGP":
"""
Load the model state and other configurations from a file for inference.
Args:
filename (str): The filename from which to load the model state. Default is "model.pth".
Returns:
SIGP: The loaded model instance.
Raises:
ValueError: If the model class specified in the file is not found.
"""
logger.info(f"Loading model state from {filename}")
model_state = torch.load(filename)

if model_state.get("version", "0.0.0") != "0.1.0":
logger.warning(
f"Loading model version {model_state.get('version', '0.0.0')}. Current version is 0.1.0."
)

# Resolve the model class by name from this module's globals(); the subclass must be visible there
ModelClass = globals().get(model_state["model_class"], None)
if ModelClass is None:
raise ValueError(
f"Model class {model_state['model_class']} not found."
)

LikelihoodClass = load_class(
"gpytorch.likelihoods", model_state["likelihood_class"]
)
if LikelihoodClass is None:
raise ValueError(
f"Likelihood class {model_state['likelihood_class']} not found."
)
CovarModuleClass = load_class(
"gauche.kernels.graph_kernels", model_state["covar_module_class"]
)
if CovarModuleClass is None:
raise ValueError(
f"Covariance module class {model_state['covar_module_class']} not found."
)
MeanModuleClass = load_class(
"gpytorch.means", model_state["mean_module_class"]
)

model = ModelClass(
train_x=model_state["train_inputs"],
train_y=model_state["train_targets"],
likelihood=LikelihoodClass(),
covar_module=CovarModuleClass(**model_state["covar_module_args"]),
mean_module=MeanModuleClass(),
)
model.load_state_dict(model_state["model_state_dict"])
model.likelihood.load_state_dict(model_state["likelihood_state_dict"])
model.covariance.load_state_dict(model_state["covariance_state_dict"])
model.mean.load_state_dict(model_state["mean_module_state_dict"])

model.eval()
return model
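
A minimal sketch of loading the checkpoint back for inference, assuming the hypothetical GraphGP subclass and X_test (a NonTensorialInputs of held-out graphs) from the sketches above. Because load resolves the model class via this module's globals(), the example makes the subclass visible there before loading:

import torch
import gauche.gp
from gauche.gp import SIGP

# Make the subclass resolvable by name from gauche.gp's globals().
gauche.gp.GraphGP = GraphGP

model = SIGP.load("model.pth")  # rebuilt from the checkpoint and returned in eval() mode

with torch.no_grad():
    posterior = model(X_test)                 # latent GP posterior over the test graphs
    predictive = model.likelihood(posterior)  # fold in observation noise
    mean, variance = predictive.mean, predictive.variance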

@staticmethod
def load_optimizer(
filename: str = "model.pth",
) -> Optional[torch.optim.Optimizer]:
"""
Load the optimizer state from a file.
Args:
filename (str): The filename from which to load the optimizer state. Default is "model.pth".
Returns:
Optional[torch.optim.Optimizer]: The loaded optimizer if available, None otherwise.
"""
logger.info(f"Loading optimizer state from {filename}")
model_state = torch.load(filename)

optimizer_class_name = model_state.get("optimizer_class")
optimizer_state_dict = model_state.get("optimizer_state_dict")

if optimizer_class_name is None or optimizer_state_dict is None:
logger.warning("No optimizer information found in the saved file.")
return None

OptimizerClass = getattr(torch.optim, optimizer_class_name, None)
if OptimizerClass is None:
logger.warning(
f"Optimizer class {optimizer_class_name} not found in torch.optim."
)
return None

# Create a dummy optimizer with a single parameter
optimizer = OptimizerClass([torch.nn.Parameter(torch.empty(1))])
optimizer.load_state_dict(optimizer_state_dict)

return optimizer
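
And a small sketch of pulling the optimizer state back out of the same checkpoint. Because the restored optimizer is built around a single placeholder parameter, it is mainly useful for recovering saved hyperparameters; re-attaching an optimizer to the loaded model's parameters before further training is left to the caller:

from gauche.gp import SIGP

optimizer = SIGP.load_optimizer("model.pth")
if optimizer is not None:
    # Inspect the saved hyperparameters, e.g. the learning rate.
    lr = optimizer.state_dict()["param_groups"][0]["lr"]
    print(f"Saved optimizer: {type(optimizer).__name__}, lr={lr}")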
