Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix ESS stopping criterion #389

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 169 additions & 62 deletions nessai/samplers/importancesampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,93 @@ def __getstate__(self):
return state


class StoppingCriterion:
    """Object for storing information about a stopping criterion.

    Includes the tolerance, current value and whether the tolerance
    has been reached based on how it should be checked.

    Parameters
    ----------
    name : str
        Name of the stopping criterion. Also used as the attribute name
        looked up in :py:meth:`update_value_from_sampler`.
    check : {"leq", "geq"}
        Indicates whether to check if the value is less than or equal (leq)
        or greater than or equal (geq) to the tolerance. Matching is
        case-insensitive; the lower-case form is stored.
    aliases : List[str]
        List of aliases (alternative names) for the criterion.
    tolerance : Optional[float]
        Tolerance for the criterion. Does not have to be specified.
    value : Optional[float]
        Current value. Does not have to be specified.

    Raises
    ------
    ValueError
        If :code:`check` is not one of "leq" or "geq".
    """

    def __init__(
        self,
        name: str,
        check: str,
        aliases: list,
        tolerance: Optional[float] = None,
        value: Optional[float] = None,
    ):
        # Normalise once so that validation and the comparison performed in
        # `reached_tolerance` agree; otherwise e.g. "LEQ" would pass the
        # validation but be treated as "geq".
        normalised_check = check.lower()
        if normalised_check not in ["leq", "geq"]:
            raise ValueError(
                f"Invalid value for `check`: {check}. "
                f"Choose from ['leq', 'geq']."
            )
        self._name = name
        self._check = normalised_check
        self._aliases = aliases
        self._tolerance = tolerance
        self._value = value

    @property
    def name(self) -> str:
        """Name of the stopping criterion."""
        return self._name

    @property
    def check(self) -> str:
        """Type of check, either "leq" or "geq" (always lower-case)."""
        return self._check

    @property
    def aliases(self) -> List[str]:
        """Aliases (alternative names) for the criterion."""
        return self._aliases

    @property
    def tolerance(self) -> Optional[float]:
        """Current tolerance, None if it has not been set."""
        return self._tolerance

    @property
    def value(self) -> Optional[float]:
        """Current value, None if it has not been set."""
        return self._value

    def update_tolerance(self, tolerance) -> None:
        """Update the tolerance."""
        self._tolerance = tolerance

    def update_value(self, value) -> None:
        """Update the current value."""
        self._value = value

    def update_value_from_sampler(self, sampler) -> None:
        """Update the current value from an attribute of the sampler.

        The attribute with the same name as the criterion is read from
        :code:`sampler`.

        Raises
        ------
        RuntimeError
            If the attribute is missing or is None.
        """
        value = getattr(sampler, self.name, None)
        if value is None:
            raise RuntimeError(f"{self.name} has not been computed!")
        self.update_value(value)

    @property
    def reached_tolerance(self) -> bool:
        """Indicates if the stopping criterion has been reached.

        Raises
        ------
        RuntimeError
            If the value or the tolerance has not been set.
        """
        # Comparing against None would raise an opaque TypeError, so fail
        # with a message that names the criterion instead.
        if self._value is None or self._tolerance is None:
            raise RuntimeError(
                f"Cannot check {self.name}: "
                "the value or tolerance has not been set!"
            )
        if self._check == "leq":
            return self._value <= self._tolerance
        return self._value >= self._tolerance

    def summary(self) -> str:
        """Return a string with the name, value and tolerance.

        The tolerance is given in brackets. Values that have not been set
        are shown as "None" rather than raising an error.
        """
        value = "None" if self._value is None else f"{self._value:.4g}"
        if self._tolerance is None:
            tolerance = "None"
        else:
            tolerance = f"{self._tolerance:.2g}"
        return f"{self.name}: {value} ({tolerance})"

class ImportanceNestedSampler(BaseNestedSampler):
"""

Expand Down Expand Up @@ -326,17 +413,46 @@ class ImportanceNestedSampler(BaseNestedSampler):
If False, this can help reduce the disk usage.
"""

stopping_criterion_aliases = dict(
ratio=["ratio", "ratio_all"],
ratio_ns=["ratio_ns"],
Z_err=["Z_err", "evidence_error"],
log_dZ=["log_dZ", "log_evidence"],
ess=[
"ess",
],
fractional_error=["fractional_error"],
)
"""Dictionary of available stopping criteria and their aliases."""
stopping_criteria = {
"ratio": StoppingCriterion(
name="ratio",
check="leq",
aliases=["ratio", "ratio_all"],
tolerance=0.0,
value=np.inf,
),
"ratio_ns": StoppingCriterion(
name="ratio_ns",
check="leq",
aliases=["ratio_ns"],
tolerance=0.0,
value=np.inf,
),
"Z_err": StoppingCriterion(
name="Z_err",
check="leq",
aliases=["Z_err", "evidence_error"],
value=np.inf,
),
"log_dZ": StoppingCriterion(
name="log_dZ",
check="leq",
aliases=["log_dZ", "log_evidence"],
value=np.inf,
),
"ess": StoppingCriterion(
name="ess",
check="geq",
aliases=["ess"],
value=0.0,
),
"fractional_error": StoppingCriterion(
name="fractional_error",
check="leq",
aliases=["fractional_error"],
value=0.0,
),
}

def __init__(
self,
Expand All @@ -362,7 +478,7 @@ def __init__(
min_remove: int = 1,
max_samples: Optional[int] = None,
stopping_criterion: str = "ratio",
tolerance: float = 0.0,
tolerance: Optional[float] = None,
n_update: Optional[int] = None,
plot_pool: bool = False,
plot_level_cdf: bool = False,
Expand Down Expand Up @@ -455,7 +571,7 @@ def __init__(
self.log_dZ = np.inf
self.ratio = np.inf
self.ratio_ns = np.inf
self.ess = 0.0
self.ess = 0
self.Z_err = np.inf

self._final_samples = None
Expand Down Expand Up @@ -651,14 +767,14 @@ def reached_tolerance(self) -> bool:
Checks if any or all of the criteria have been met, this depends on the
value of :code:`check_criteria`.
"""
flags = [
self.stopping_criteria[name].reached_tolerance
for name in self.criteria_to_check
]
if self._stop_any:
return any(
[c <= t for c, t in zip(self.criterion, self.tolerance)]
)
return any(flags)
else:
return all(
[c <= t for c, t in zip(self.criterion, self.tolerance)]
)
return all(flags)

@staticmethod
def add_fields():
Expand All @@ -676,34 +792,21 @@ def configure_stopping_criterion(
stopping_criterion = [stopping_criterion]

if isinstance(tolerance, list):
self.tolerance = [float(t) for t in tolerance]
tolerance = [float(t) for t in tolerance]
else:
self.tolerance = [float(tolerance)]

self.stopping_criterion = []
for c in stopping_criterion:
for criterion, aliases in self.stopping_criterion_aliases.items():
if c in aliases:
self.stopping_criterion.append(criterion)
if not self.stopping_criterion:
raise ValueError(
f"Unknown stopping criterion: {stopping_criterion}"
)
for c, c_use in zip(stopping_criterion, self.stopping_criterion):
if c != c_use:
logger.info(
f"Stopping criterion specified ({c}) is "
f"an alias for {c_use}. Using {c_use}."
)
if len(self.stopping_criterion) != len(self.tolerance):
raise ValueError(
"Number of stopping criteria must match tolerances"
)
self.criterion = len(self.tolerance) * [np.inf]

logger.info(f"Stopping criteria: {self.stopping_criterion}")
logger.info(f"Tolerance: {self.tolerance}")
tolerance = [float(tolerance)]

self.criteria_to_check = []
for name, tol in zip(stopping_criterion, tolerance):
for sc in self.stopping_criteria.values():
if name in sc.aliases:
sc.update_tolerance(tol)
self.criteria_to_check.append(name)
break
else:
raise ValueError(f"Unknown stopping criterion: {name}")

logger.info(f"Stopping criteria to check: {self.criteria_to_check}")
if check_criteria not in {"any", "all"}:
raise ValueError("check_criteria must be any or all")
if check_criteria == "any":
Expand Down Expand Up @@ -841,7 +944,7 @@ def initialise_history(self) -> None:
samples_entropy=[],
proposal_entropy=[],
stopping_criteria={
k: [] for k in self.stopping_criterion_aliases.keys()
name: [] for name in self.stopping_criteria.keys()
},
)
)
Expand Down Expand Up @@ -871,10 +974,8 @@ def update_history(self) -> None:
self.model.likelihood_evaluations
)

for k in self.stopping_criterion_aliases.keys():
self.history["stopping_criteria"][k].append(
getattr(self, k, np.nan)
)
for name, sc in self.stopping_criteria.items():
self.history["stopping_criteria"][name].append(sc.value)

def determine_threshold_quantile(
self,
Expand Down Expand Up @@ -1427,12 +1528,18 @@ def compute_stopping_criterion(self) -> List[float]:
self.ess = self.state.effective_n_posterior_samples
self.Z_err = np.exp(self.log_evidence_error)
self.fractional_error = self.state.evidence_error / self.state.evidence
cond = [getattr(self, sc) for sc in self.stopping_criterion]

logger.info(
f"Stopping criteria ({self.stopping_criterion}): {cond} "
f"- Tolerance: {self.tolerance}"
)
cond = {}
for name, sc in self.stopping_criteria.items():
sc.update_value_from_sampler(self)
if name in self.criteria_to_check:
cond[name] = sc.value

status = [
self.stopping_criteria[sc].summary()
for sc in self.criteria_to_check
]
logger.info(f"Stopping criteria: {status}")
return cond

def checkpoint(self, periodic: bool = False, force: bool = False):
Expand Down Expand Up @@ -1583,7 +1690,7 @@ def nested_sampling_loop(self):

logger.info(
f"Finished nested sampling loop after {self.iteration} iterations "
f"with {self.stopping_criterion} = {self.criterion}"
f"with {self.criterion}"
)
self.finalise()
logger.info(f"Training time: {self.training_time}")
Expand Down Expand Up @@ -2021,17 +2128,17 @@ def plot_state(
ax[m].legend()
m += 1

for (i, sc), tol in zip(
enumerate(self.stopping_criterion), self.tolerance
):
for i, sc_name in enumerate(self.criteria_to_check):
ax[m].plot(
its,
self.history["stopping_criteria"][sc],
label=sc,
self.history["stopping_criteria"][sc_name],
label=sc_name,
c=f"C{i}",
ls=config.plotting.line_styles[i],
)
ax[m].axhline(tol, ls=":", c=f"C{i}")
ax[m].axhline(
self.stopping_criteria[sc_name].tolerance, ls=":", c=f"C{i}"
)
ax[m].legend()
ax[m].set_ylabel("Stopping criterion")

Expand Down
Loading