From ba3cbb56b96008326c646522c74c2c9798a7f3a4 Mon Sep 17 00:00:00 2001 From: AdrianSosic Date: Mon, 15 Apr 2024 22:36:55 +0200 Subject: [PATCH] Remove SingleTaskGP type hint Due to serialization issues: https://github.com/python-attrs/cattrs/issues/531 --- baybe/surrogates/gaussian_process.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/baybe/surrogates/gaussian_process.py b/baybe/surrogates/gaussian_process.py index 7e8ba4ab73..c154aa8d7a 100644 --- a/baybe/surrogates/gaussian_process.py +++ b/baybe/surrogates/gaussian_process.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from attr import define, field @@ -12,7 +12,6 @@ from baybe.surrogates.base import Surrogate if TYPE_CHECKING: - from botorch.models import SingleTaskGP from torch import Tensor @@ -31,7 +30,9 @@ class GaussianProcessSurrogate(Surrogate): kernel: Kernel = field(factory=MaternKernel) """The kernel used by the Gaussian Process.""" - _model: Optional[SingleTaskGP] = field(init=False, default=None) + # TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently + # omitted due to: https://github.com/python-attrs/cattrs/issues/531 + _model = field(init=False, default=None) """The actual model.""" def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: