diff --git a/baybe/surrogates/gaussian_process.py b/baybe/surrogates/gaussian_process.py index 7e8ba4ab7..c154aa8d7 100644 --- a/baybe/surrogates/gaussian_process.py +++ b/baybe/surrogates/gaussian_process.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, ClassVar from attr import define, field @@ -12,7 +12,6 @@ from baybe.surrogates.base import Surrogate if TYPE_CHECKING: - from botorch.models import SingleTaskGP from torch import Tensor @@ -31,7 +30,9 @@ class GaussianProcessSurrogate(Surrogate): kernel: Kernel = field(factory=MaternKernel) """The kernel used by the Gaussian Process.""" - _model: Optional[SingleTaskGP] = field(init=False, default=None) + # TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently + # omitted due to: https://github.com/python-attrs/cattrs/issues/531 + _model = field(init=False, default=None) """The actual model.""" def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]: