Skip to content

Commit

Permalink
better error checking for dofs
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Apr 21, 2024
1 parent 97c7244 commit cb7f461
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 97 deletions.
6 changes: 3 additions & 3 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from botorch.models.transforms.input import Normalize
from databroker import Broker
from ophyd import Signal
from tqdm import tqdm

from . import plotting, utils
from .bayesian import acquisition, models
Expand Down Expand Up @@ -402,8 +401,9 @@ def learn(
new_table = yield from self.acquire(center_inputs)
new_table.loc[:, "acq_func"] = "sample_center_on_init"

for i in tqdm(range(iterations), desc="Learning..."):
print(f"running iteration {i + 1} / {iterations}")
for i in range(iterations):
if self.verbose:
print(f"running iteration {i + 1} / {iterations}")
for single_acq_func in np.atleast_1d(acq_func):
res = self.ask(n=n, acq_func_identifier=single_acq_func, upsample=upsample, route=route, **acq_func_kwargs)
new_table = yield from self.acquire(res["points"])
Expand Down
152 changes: 82 additions & 70 deletions src/blop/dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field, fields
from operator import attrgetter
from typing import Tuple
from typing import Tuple, Union

import numpy as np
import pandas as pd
Expand All @@ -17,14 +17,15 @@
"transform": "str",
"search_domain": "object",
"trust_domain": "object",
"units": "str",
"domain": "object",
"active": "bool",
"read_only": "bool",
"units": "str",
"tags": "object",
}

DOF_TYPES = ["continuous", "binary", "ordinal", "categorical"]
TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "sigmoid": (0.0, 1.0), "tanh": (-1.0, 1.0)}
TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "logit": (0.0, 1.0), "arctanh": (-1.0, 1.0)}


class ReadOnlyError(Exception):
Expand All @@ -43,14 +44,6 @@ def _validate_dofs(dofs):
return list(dofs)


def _validate_dof_transform(transform):
if transform is None:
return (-np.inf, np.inf)

if transform not in TRANSFORM_DOMAINS:
raise ValueError(f"'transform' must be a callable with one argument, or one of {TRANSFORM_DOMAINS}")


def _validate_continuous_dof_domains(search_domain, trust_domain, domain):
"""
A DOF MUST have a search domain, and it MIGHT have a trust domain or a transform domain.
Expand All @@ -59,33 +52,37 @@ def _validate_continuous_dof_domains(search_domain, trust_domain, domain):
search_domain \\subseteq trust_domain \\subseteq domain
"""

if len(search_domain) != 2:
raise ValueError("'search_domain' must be a 2-tuple of numbers.")
try:
search_domain = tuple((float(search_domain[0]), float(search_domain[1])))
assert len(search_domain) == 2
except: # noqa
raise ValueError("If type='continuous', then 'search_domain' must be a tuple of two numbers.")

if search_domain[0] >= search_domain[1]:
raise ValueError("The lower search bound must be strictly less than the upper search bound.")

if domain is not None:
if (search_domain[0] < domain[0]) or (search_domain[1] > domain[1]):
raise ValueError(f"The search domain {search_domain} is outside the transform domain {domain}.")
if (search_domain[0] <= domain[0]) or (search_domain[1] >= domain[1]):
raise ValueError(f"The search domain {search_domain} must be a strict subset of the domain {domain}.")

if trust_domain is not None:
if (search_domain[0] < trust_domain[0]) or (search_domain[1] > trust_domain[1]):
raise ValueError(f"The search domain {search_domain} is outside the trust domain {trust_domain}.")
raise ValueError(f"The search domain {search_domain} must be a subset of the trust domain {trust_domain}.")

if (trust_domain is not None) and (domain is not None):
if (trust_domain[0] < domain[0]) or (trust_domain[1] > domain[1]):
raise ValueError(f"The trust domain {trust_domain} is outside the transform domain {domain}.")
raise ValueError(f"The trust domain {trust_domain} must be a subset of the trust domain {domain}.")


def _validate_discrete_dof_domains(search_domain, trust_domain, domain):
def _validate_discrete_dof_domains(search_domain, trust_domain):
"""
A DOF MUST have a search domain, and it MIGHT have a trust domain or a transform domain
Check that all the domains are kosher by enforcing that:
search_domain \\subseteq trust_domain \\subseteq domain
"""
...
if not trust_domain.issuperset(search_domain):
raise ValueError(f"The trust domain {trust_domain} not a superset of the search domain {search_domain}.")


@dataclass
Expand Down Expand Up @@ -126,9 +123,9 @@ class DOF:

name: str = None
description: str = ""
type: str = "continuous"
search_domain: Tuple[float, float] = None
trust_domain: Tuple[float, float] = None
type: str = None
search_domain: Union[Tuple[float, float], Sequence] = None
trust_domain: Union[Tuple[float, float], Sequence] = None
units: str = None
read_only: bool = False
active: bool = True
Expand All @@ -151,35 +148,41 @@ def __repr__(self):

# Some post-processing. This is specific to dataclasses
def __post_init__(self):
if self.type not in DOF_TYPES:
raise ValueError(f"'type' must be one of {DOF_TYPES}")

if (self.name is None) ^ (self.device is None):
if self.name is None:
self.name = self.device.name
else:
raise ValueError("DOF() accepts exactly one of either a name or an ophyd device.")
raise ValueError("You must specify exactly one of 'name' or 'device'.")

# if our input is continuous
if self.type == "continuous":
_validate_dof_transform(self.transform)
if self.search_domain is None:
if not self.read_only:
raise ValueError("You must specify search_domain if read_only=False.")

if self.trust_domain is None:
self.trust_domain = TRANSFORM_DOMAINS[self.transform] if self.transform is not None else (-np.inf, np.inf)
if self.type is None:
if isinstance(self.search_domain, tuple):
self.type = "continuous"
elif isinstance(self.search_domain, set):
if len(self.search_domain) == 2:
self.type = "binary"
else:
self.type = "categorical"

if self.search_domain is None:
if not self.read_only:
raise ValueError("You must specify search_domain if the device is not read-only.")
else:
_validate_continuous_dof_domains(self.search_domain, self.trust_domain, self.domain)
if self.type not in DOF_TYPES:
raise ValueError(f"'type' must be one of {DOF_TYPES}")

# our input is usually continuous
if self.type == "continuous":
_validate_continuous_dof_domains(self._search_domain, self._trust_domain, self.domain)

self.search_domain = tuple((float(self.search_domain[0]), float(self.search_domain[1])))

if self.device is None:
center = float(self._untransform(np.mean([self._transform(np.array(self.search_domain))])))
self.device = Signal(name=self.name, value=center)

# otherwise it must be discrete
else:
_validate_discrete_dof_domains(self.search_domain, self.trust_domain, self.domain)
_validate_discrete_dof_domains(self._search_domain, self._trust_domain)

if self.type == "binary":
if self.search_domain is None:
Expand All @@ -203,33 +206,48 @@ def __post_init__(self):
self.device.kind = "hinted"

@property
def domain(self):
def _search_domain(self):
"""
The total domain of the DOF.
Compute the search domain of the DOF.
"""
if self.transform is None:
if self.read_only:
value = self.readback
if self.type == "continuous":
return (-np.inf, np.inf)
return tuple(value, value)
else:
return self.search_domain
return TRANSFORM_DOMAINS[self.transform]
return {value}
else:
return self.search_domain

@property
def _search_domain(self):
if self.read_only:
_readback = self.readback
return np.array([_readback, _readback])
return np.array(self.search_domain)
def _trust_domain(self):
"""
If trust_domain is None, then we return the total domain.
"""
return self.trust_domain or self.domain

@property
def domain(self):
"""
The total domain; the user can't control this. This is what we fall back on as the trust_domain if none is supplied.
If the DOF is continuous:
If there is a transform, return the domain of the transform
Else, return (-inf, inf)
If the DOF is discrete:
If there is a trust domain, return the trust domain
Else, return the search domain
"""
if self.type == "continuous":
if self.transform is None:
return (-np.inf, np.inf)
else:
return TRANSFORM_DOMAINS[self.transform]
else:
return self.trust_domain or self.search_domain

def _trust(self, x):
return (self.trust_domain[0] <= x) & (x <= self.trust_domain[1])

@property
def _trust_domain(self):
if self.trust_domain is None:
return self.domain
return self.trust_domain

def _transform(self, x, normalize=True):
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, dtype=torch.double)
Expand All @@ -238,9 +256,9 @@ def _transform(self, x, normalize=True):

if self.transform == "log":
x = torch.log(x)
if self.transform == "sigmoid":
if self.transform == "logit":
x = (x / (1 - x)).log()
if self.transform == "tanh":
if self.transform == "arctanh":
x = torch.arctanh(x)

if normalize and not self.read_only:
Expand All @@ -261,19 +279,11 @@ def _untransform(self, x):
return x
if self.transform == "log":
return torch.exp(x)
if self.transform == "sigmoid":
if self.transform == "logit":
return 1 / (1 + torch.exp(-x))
if self.transform == "tanh":
if self.transform == "arctanh":
return torch.tanh(x)

# @property
# def _transformed_search_domain(self):
# return self._transform(np.array(self._search_domain), normalize=False)

# @property
# def _transformed_trust_domain(self):
# return self._transform(np.array(self._trust_domain), normalize=False)

@property
def readback(self):
# there is probably a better way to do this
Expand All @@ -284,11 +294,13 @@ def summary(self) -> pd.Series:
series = pd.Series(index=list(DOF_FIELD_TYPES.keys()), dtype="object")
for attr in series.index:
value = getattr(self, attr)
if attr in ["search_domain", "trust_domain"]:
if (self.type == "continuous") and not self.read_only:
if value is not None:
if attr in ["search_domain", "trust_domain", "domain"]:
if (self.type == "continuous") and not self.read_only and value is not None:
if attr in ["search_domain", "trust_domain"]:
value = f"[{value[0]:.02e}, {value[1]:.02e}]"
else:
value = f"({value[0]:.02e}, {value[1]:.02e})"
series[attr] = value if value is not None else ""
series[attr] = value if value else ""
return series

@property
Expand Down
36 changes: 12 additions & 24 deletions src/blop/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@

OBJ_FIELD_TYPES = {
"description": "object",
"kind": "str",
# "kind": "str",
"type": "str",
"target": "object",
"active": "bool",
"transform": "str",
"domain": "str",
"trust_domain": "object",
"active": "bool",
"weight": "bool",
"weight": "float",
"units": "object",
"noise_bounds": "object",
"noise": "float",
"n": "int",
"n_valid": "int",
"latent_groups": "object",
"active": "bool",
}

SUPPORTED_OBJ_TYPES = ["continuous", "binary", "ordinal", "categorical"]
TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "sigmoid": (0.0, 1.0), "tanh": (-1.0, 1.0)}
TRANSFORM_DOMAINS = {"log": (0.0, np.inf), "logit": (0.0, 1.0), "arctanh": (-1.0, 1.0)}


class DuplicateNameError(ValueError):
Expand Down Expand Up @@ -160,9 +160,9 @@ def _transform(self, y):

if self.transform == "log":
y = y.log()
if self.transform == "sigmoid":
if self.transform == "logit":
y = (y / (1 - y)).log()
if self.transform == "tanh":
if self.transform == "arctanh":
y = torch.arctanh(y)

if self.target == "min":
Expand All @@ -179,9 +179,9 @@ def _untransform(self, y):

if self.transform == "log":
y = y.exp()
if self.transform == "sigmoid":
if self.transform == "logit":
y = 1 / (1 + torch.exp(-y))
if self.transform == "tanh":
if self.transform == "arctanh":
y = torch.tanh(y)

return y
Expand Down Expand Up @@ -212,21 +212,9 @@ def summary(self) -> pd.Series:
series[attr] = value if value is not None else ""
return series

# @property
# def trust_lower_bound(self):
# if self.trust_domain is None:
# return 0 if self.log else -np.inf
# return float(self.trust_domain[0])

# @property
# def trust_upper_bound(self):
# if self.trust_domain is None:
# return np.inf
# return float(self.trust_domain[1])

@property
def noise(self) -> float:
return self.model.likelihood.noise.item() if hasattr(self, "model") else None
return self.model.likelihood.noise.item() if hasattr(self, "model") else np.nan

@property
def snr(self) -> float:
Expand All @@ -245,7 +233,7 @@ def targeting_constraint(self, x: torch.Tensor) -> torch.Tensor:
m = p.mean
s = p.variance.sqrt()

sish = s + 0.1 * m.std()
sish = s + 0.1 * m.std() # for numerical stability

return (
0.5 * (approximate_erf((b - m) / (np.sqrt(2) * sish)) - approximate_erf((a - m) / (np.sqrt(2) * sish)))[..., -1]
Expand Down

0 comments on commit cb7f461

Please sign in to comment.