From cb7f4619df10782e20524f066ddb7e291cdb7010 Mon Sep 17 00:00:00 2001 From: Thomas Morris Date: Sun, 21 Apr 2024 14:04:13 -0700 Subject: [PATCH] better error checking for dofs --- src/blop/agent.py | 6 +- src/blop/dofs.py | 152 ++++++++++++++++++++++------------------- src/blop/objectives.py | 36 ++++------ 3 files changed, 97 insertions(+), 97 deletions(-) diff --git a/src/blop/agent.py b/src/blop/agent.py index b8be285..231ab23 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -25,7 +25,6 @@ from botorch.models.transforms.input import Normalize from databroker import Broker from ophyd import Signal -from tqdm import tqdm from . import plotting, utils from .bayesian import acquisition, models @@ -402,8 +401,9 @@ def learn( new_table = yield from self.acquire(center_inputs) new_table.loc[:, "acq_func"] = "sample_center_on_init" - for i in tqdm(range(iterations), desc="Learning..."): - print(f"running iteration {i + 1} / {iterations}") + for i in range(iterations): + if self.verbose: + print(f"running iteration {i + 1} / {iterations}") for single_acq_func in np.atleast_1d(acq_func): res = self.ask(n=n, acq_func_identifier=single_acq_func, upsample=upsample, route=route, **acq_func_kwargs) new_table = yield from self.acquire(res["points"]) diff --git a/src/blop/dofs.py b/src/blop/dofs.py index 733a817..44faff4 100644 --- a/src/blop/dofs.py +++ b/src/blop/dofs.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Sequence from dataclasses import dataclass, field, fields from operator import attrgetter -from typing import Tuple +from typing import Tuple, Union import numpy as np import pandas as pd @@ -17,14 +17,15 @@ "transform": "str", "search_domain": "object", "trust_domain": "object", - "units": "str", + "domain": "object", "active": "bool", "read_only": "bool", + "units": "str", "tags": "object", } DOF_TYPES = ["continuous", "binary", "ordinal", "categorical"] -TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "sigmoid": (0.0, 1.0), "tanh": (-1.0, 1.0)} +TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "logit": (0.0, 1.0), "arctanh": (-1.0, 1.0)} class ReadOnlyError(Exception): @@ -43,14 +44,6 @@ def _validate_dofs(dofs): return list(dofs) -def _validate_dof_transform(transform): - if transform is None: - return (-np.inf, np.inf) - - if transform not in TRANSFORM_DOMAINS: - raise ValueError(f"'transform' must be a callable with one argument, or one of {TRANSFORM_DOMAINS}") - - def _validate_continuous_dof_domains(search_domain, trust_domain, domain): """ A DOF MUST have a search domain, and it MIGHT have a trust domain or a transform domain. @@ -59,33 +52,37 @@ def _validate_continuous_dof_domains(search_domain, trust_domain, domain): search_domain \\subseteq trust_domain \\subseteq domain """ - if len(search_domain) != 2: - raise ValueError("'search_domain' must be a 2-tuple of numbers.") + try: + search_domain = tuple((float(search_domain[0]), float(search_domain[1]))) + assert len(search_domain) == 2 + except: # noqa + raise ValueError("If type='continuous', then 'search_domain' must be a tuple of two numbers.") if search_domain[0] >= search_domain[1]: raise ValueError("The lower search bound must be strictly less than the upper search bound.") if domain is not None: - if (search_domain[0] < domain[0]) or (search_domain[1] > domain[1]): - raise ValueError(f"The search domain {search_domain} is outside the transform domain {domain}.") + if (search_domain[0] <= domain[0]) or (search_domain[1] >= domain[1]): + raise ValueError(f"The search domain {search_domain} must be a strict subset of the domain {domain}.") if trust_domain is not None: if (search_domain[0] < trust_domain[0]) or (search_domain[1] > trust_domain[1]): - raise ValueError(f"The search domain {search_domain} is outside the trust domain {trust_domain}.") + raise ValueError(f"The search domain {search_domain} must be a subset of the trust domain {trust_domain}.") if (trust_domain is not None) and (domain is not None): if (trust_domain[0] < domain[0]) or (trust_domain[1] > domain[1]): - raise ValueError(f"The trust domain {trust_domain} is outside the transform domain {domain}.") + raise ValueError(f"The trust domain {trust_domain} must be a subset of the trust domain {domain}.") -def _validate_discrete_dof_domains(search_domain, trust_domain, domain): +def _validate_discrete_dof_domains(search_domain, trust_domain): """ A DOF MUST have a search domain, and it MIGHT have a trust domain or a transform domain Check that all the domains are kosher by enforcing that: search_domain \\subseteq trust_domain \\subseteq domain """ - ... + if not trust_domain.issuperset(search_domain): + raise ValueError(f"The trust domain {trust_domain} not a superset of the search domain {search_domain}.") @dataclass @@ -126,9 +123,9 @@ class DOF: name: str = None description: str = "" - type: str = "continuous" - search_domain: Tuple[float, float] = None - trust_domain: Tuple[float, float] = None + type: str = None + search_domain: Union[Tuple[float, float], Sequence] = None + trust_domain: Union[Tuple[float, float], Sequence] = None units: str = None read_only: bool = False active: bool = True @@ -151,27 +148,33 @@ def __repr__(self): # Some post-processing. This is specific to dataclasses def __post_init__(self): - if self.type not in DOF_TYPES: - raise ValueError(f"'type' must be one of {DOF_TYPES}") - if (self.name is None) ^ (self.device is None): if self.name is None: self.name = self.device.name else: - raise ValueError("DOF() accepts exactly one of either a name or an ophyd device.") + raise ValueError("You must specify exactly one of 'name' or 'device'.") - # if our input is continuous - if self.type == "continuous": - _validate_dof_transform(self.transform) + if self.search_domain is None: + if not self.read_only: + raise ValueError("You must specify search_domain if read_only=False.") - if self.trust_domain is None: - self.trust_domain = TRANSFORM_DOMAINS[self.transform] if self.transform is not None else (-np.inf, np.inf) + if self.type is None: + if isinstance(self.search_domain, tuple): + self.type = "continuous" + elif isinstance(self.search_domain, set): + if len(self.search_domain) == 2: + self.type = "binary" + else: + self.type = "categorical" - if self.search_domain is None: - if not self.read_only: - raise ValueError("You must specify search_domain if the device is not read-only.") - else: - _validate_continuous_dof_domains(self.search_domain, self.trust_domain, self.domain) + if self.type not in DOF_TYPES: + raise ValueError(f"'type' must be one of {DOF_TYPES}") + + # our input is usually continuous + if self.type == "continuous": + _validate_continuous_dof_domains(self._search_domain, self._trust_domain, self.domain) + + self.search_domain = tuple((float(self.search_domain[0]), float(self.search_domain[1]))) if self.device is None: center = float(self._untransform(np.mean([self._transform(np.array(self.search_domain))]))) @@ -179,7 +182,7 @@ def __post_init__(self): # otherwise it must be discrete else: - _validate_discrete_dof_domains(self.search_domain, self.trust_domain, self.domain) + _validate_discrete_dof_domains(self._search_domain, self._trust_domain) if self.type == "binary": if self.search_domain is None: @@ -203,33 +206,48 @@ def __post_init__(self): self.device.kind = "hinted" @property - def domain(self): + def _search_domain(self): """ - The total domain of the DOF. + Compute the search domain of the DOF. """ - if self.transform is None: + if self.read_only: + value = self.readback if self.type == "continuous": - return (-np.inf, np.inf) + return tuple(value, value) else: - return self.search_domain - return TRANSFORM_DOMAINS[self.transform] + return {value} + else: + return self.search_domain @property - def _search_domain(self): - if self.read_only: - _readback = self.readback - return np.array([_readback, _readback]) - return np.array(self.search_domain) + def _trust_domain(self): + """ + If trust_domain is None, then we return the total domain. + """ + return self.trust_domain or self.domain + + @property + def domain(self): + """ + The total domain; the user can't control this. This is what we fall back on as the trust_domain if none is supplied. + If the DOF is continuous: + If there is a transform, return the domain of the transform + Else, return (-inf, inf) + If the DOF is discrete: + If there is a trust domain, return the trust domain + Else, return the search domain + """ + if self.type == "continuous": + if self.transform is None: + return (-np.inf, np.inf) + else: + return TRANSFORM_DOMAINS[self.transform] + else: + return self.trust_domain or self.search_domain def _trust(self, x): return (self.trust_domain[0] <= x) & (x <= self.trust_domain[1]) - @property - def _trust_domain(self): - if self.trust_domain is None: - return self.domain - return self.trust_domain - def _transform(self, x, normalize=True): if not isinstance(x, torch.Tensor): x = torch.tensor(x, dtype=torch.double) @@ -238,9 +256,9 @@ def _transform(self, x, normalize=True): if self.transform == "log": x = torch.log(x) - if self.transform == "sigmoid": + if self.transform == "logit": x = (x / (1 - x)).log() - if self.transform == "tanh": + if self.transform == "arctanh": x = torch.arctanh(x) if normalize and not self.read_only: @@ -261,19 +279,11 @@ def _untransform(self, x): return x if self.transform == "log": return torch.exp(x) - if self.transform == "sigmoid": + if self.transform == "logit": return 1 / (1 + torch.exp(-x)) - if self.transform == "tanh": + if self.transform == "arctanh": return torch.tanh(x) - # @property - # def _transformed_search_domain(self): - # return self._transform(np.array(self._search_domain), normalize=False) - - # @property - # def _transformed_trust_domain(self): - # return self._transform(np.array(self._trust_domain), normalize=False) - @property def readback(self): # there is probably a better way to do this @@ -284,11 +294,13 @@ def summary(self) -> pd.Series: series = pd.Series(index=list(DOF_FIELD_TYPES.keys()), dtype="object") for attr in series.index: value = getattr(self, attr) - if attr in ["search_domain", "trust_domain"]: - if (self.type == "continuous") and not self.read_only: - if value is not None: + if attr in ["search_domain", "trust_domain", "domain"]: + if (self.type == "continuous") and not self.read_only and value is not None: + if attr in ["search_domain", "trust_domain"]: + value = f"[{value[0]:.02e}, {value[1]:.02e}]" + else: value = f"({value[0]:.02e}, {value[1]:.02e})" - series[attr] = value if value is not None else "" + series[attr] = value if value else "" return series @property diff --git a/src/blop/objectives.py b/src/blop/objectives.py index 022ee27..10e4784 100644 --- a/src/blop/objectives.py +++ b/src/blop/objectives.py @@ -13,23 +13,23 @@ OBJ_FIELD_TYPES = { "description": "object", - "kind": "str", + # "kind": "str", "type": "str", "target": "object", - "active": "bool", "transform": "str", + "domain": "str", "trust_domain": "object", - "active": "bool", - "weight": "bool", + "weight": "float", "units": "object", "noise_bounds": "object", "noise": "float", - "n": "int", + "n_valid": "int", "latent_groups": "object", + "active": "bool", } SUPPORTED_OBJ_TYPES = ["continuous", "binary", "ordinal", "categorical"] -TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "sigmoid": (0.0, 1.0), "tanh": (-1.0, 1.0)} +TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "logit": (0.0, 1.0), "arctanh": (-1.0, 1.0)} class DuplicateNameError(ValueError): @@ -160,9 +160,9 @@ def _transform(self, y): if self.transform == "log": y = y.log() - if self.transform == "sigmoid": + if self.transform == "logit": y = (y / (1 - y)).log() - if self.transform == "tanh": + if self.transform == "arctanh": y = torch.arctanh(y) if self.target == "min": @@ -179,9 +179,9 @@ def _untransform(self, y): if self.transform == "log": y = y.exp() - if self.transform == "sigmoid": + if self.transform == "logit": y = 1 / (1 + torch.exp(-y)) - if self.transform == "tanh": + if self.transform == "arctanh": y = torch.tanh(y) return y @@ -212,21 +212,9 @@ def summary(self) -> pd.Series: series[attr] = value if value is not None else "" return series - # @property - # def trust_lower_bound(self): - # if self.trust_domain is None: - # return 0 if self.log else -np.inf - # return float(self.trust_domain[0]) - - # @property - # def trust_upper_bound(self): - # if self.trust_domain is None: - # return np.inf - # return float(self.trust_domain[1]) - @property def noise(self) -> float: - return self.model.likelihood.noise.item() if hasattr(self, "model") else None + return self.model.likelihood.noise.item() if hasattr(self, "model") else np.nan @property def snr(self) -> float: @@ -245,7 +233,7 @@ def targeting_constraint(self, x: torch.Tensor) -> torch.Tensor: m = p.mean s = p.variance.sqrt() - sish = s + 0.1 * m.std() + sish = s + 0.1 * m.std() # for numerical stability return ( 0.5 * (approximate_erf((b - m) / (np.sqrt(2) * sish)) - approximate_erf((a - m) / (np.sqrt(2) * sish)))[..., -1]