Skip to content

Commit

Permalink
Callbacks (#18)
Browse files Browse the repository at this point in the history
* Title argument for Result.plot_objective

* Call-back mixin for a new global best

* Warn when overwriting callback

* Flexible operator_counts plotting

* Operator decay parameters in [0, 1]

* Tests
  • Loading branch information
N-Wouda authored Feb 20, 2020
1 parent 5c98493 commit c669e90
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 55 deletions.
19 changes: 14 additions & 5 deletions alns/ALNS.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import numpy.random as rnd

from .CallbackFlag import CallbackFlag
from .CallbackMixin import CallbackMixin
from .Result import Result
from .State import State # pylint: disable=unused-import
from .Statistics import Statistics
Expand All @@ -13,7 +15,7 @@
from .select_operator import select_operator


class ALNS:
class ALNS(CallbackMixin):

def __init__(self, rnd_state=rnd.RandomState()):
"""
Expand All @@ -35,6 +37,8 @@ def __init__(self, rnd_state=rnd.RandomState()):
Gendreau (Ed.), *Handbook of Metaheuristics* (2 ed., pp. 399-420).
Springer.
"""
super().__init__()

self._destroy_operators = OrderedDict()
self._repair_operators = OrderedDict()

Expand Down Expand Up @@ -116,7 +120,8 @@ def iterate(self, initial_solution, weights, operator_decay, criterion,
is better than the current solution (idx 1), the solution is
accepted (idx 2), or rejected (idx 3).
operator_decay : float
The operator decay parameter, as a float in the unit interval.
The operator decay parameter, as a float in the unit interval,
[0, 1] (inclusive).
criterion : AcceptanceCriterion
The acceptance criterion to use for candidate states. See also
the `alns.criteria` module for an overview.
Expand All @@ -134,8 +139,8 @@ def iterate(self, initial_solution, weights, operator_decay, criterion,
Returns
-------
Result
A result object, containing the best and last solutions, and some
additional results.
A result object, containing the best solution and some additional
statistics.
References
----------
Expand Down Expand Up @@ -178,6 +183,10 @@ class of vehicle routing problems with backhauls. *European Journal of
candidate, criterion)

if current.objective() < best.objective():
if self.has_callback(CallbackFlag.ON_BEST):
callback = self.callback(CallbackFlag.ON_BEST)
current = callback(current, self._rnd_state)

best = current

# The weights are updated as convex combinations of the current
Expand Down Expand Up @@ -274,7 +283,7 @@ def _validate_parameters(self, weights, operator_decay, iterations):
if len(self.destroy_operators) == 0 or len(self.repair_operators) == 0:
raise ValueError("Missing at least one destroy or repair operator.")

if not (0 < operator_decay < 1):
if not (0 <= operator_decay <= 1):
raise ValueError("Operator decay parameter outside unit interval"
" is not understood.")

Expand Down
9 changes: 9 additions & 0 deletions alns/CallbackFlag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import IntEnum, unique


@unique
class CallbackFlag(IntEnum):
"""
Callback flags for the mix-in.
"""
ON_BEST = 0
77 changes: 77 additions & 0 deletions alns/CallbackMixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import warnings

from .CallbackFlag import CallbackFlag
from .exceptions_warnings import OverwriteWarning


class CallbackMixin:

def __init__(self):
"""
Callback mix-in for ALNS. This allows for some flexibility by having
ALNS call custom functions whenever a special event happens.
"""
self._callbacks = {}

def on_best(self, func):
"""
Sets a callback function to be called when ALNS finds a new global best
solution state.
Parameters
----------
func : callable
A function that should take a solution State as its first parameter,
and a numpy RandomState as its second (cf. the operator signature).
It should return a (new) solution State.
Warns
-----
OverwriteWarning
When a callback has already been set for the ON_BEST flag.
"""
self._set_callback(CallbackFlag.ON_BEST, func)

def has_callback(self, flag):
"""
Determines if a callable has been set for the passed-in flag.
Parameters
----------
flag : CallbackFlag
Returns
-------
bool
True if a callable is set, False otherwise.
"""
return flag in self._callbacks

def callback(self, flag):
"""
Returns the callback for the passed-in flag, assuming it exists.
Parameters
----------
flag : CallbackFlag
The callback flag for which to retrieve a callback.
Returns
-------
callable
Callback for the passed-in flag.
"""
return self._callbacks[flag]

def _set_callback(self, flag, func):
"""
Sets the passed-in callback func for the passed-in flag. Warns if this
would overwrite an existing callback.
"""
if self.has_callback(flag):
warnings.warn("A callback function has already been set for the"
" `{0}' flag. This callback will now be replaced by"
" the newly passed-in callback.".format(flag),
OverwriteWarning)

self._callbacks[flag] = func
51 changes: 25 additions & 26 deletions alns/Result.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def statistics(self):

return self._statistics

def plot_objectives(self, ax=None, **kwargs):
def plot_objectives(self, ax=None, title=None, **kwargs):
"""
Plots the collected objective values at each iteration.
Expand All @@ -67,15 +67,20 @@ def plot_objectives(self, ax=None, **kwargs):
ax : Axes
Optional axes argument. If not passed, a new figure and axes are
constructed.
title : str
Optional title argument. When not passed, a default is set.
kwargs : dict
Optional arguments passed to ``ax.plot``.
"""
if ax is None:
_, ax = plt.subplots()

if title is None:
title = "Objective value at each iteration"

ax.plot(self.statistics.objectives, **kwargs)

ax.set_title("Objective value at each iteration")
ax.set_title(title)
ax.set_ylabel("Objective value")
ax.set_xlabel("Iteration (#)")

Expand All @@ -94,15 +99,21 @@ def plot_operator_counts(self, figure=None, title=None, legend=None,
title : str
Optional figure title. When not passed, no title is set.
legend : list
Optional legend entries. When passed, this should be a list of
four strings. When not passed, a sensible default is set.
Optional legend entries. When passed, this should be a list of at
most four strings. The first string describes the number of times
a best solution was found, the second a better, the third a solution
was accepted but did not improve upon the current or global best,
and the fourth the number of times a solution was rejected. If less
than four strings are passed, only the first len(legend) count types
are plotted. When not passed, a sensible default is set and all
counts are shown.
kwargs : dict
Optional arguments passed to each call of ``ax.barh``.
Raises
------
ValueError
When the passed-in legend list is not of appropriate length.
When the legend contains more than four elements.
"""
if figure is None:
figure, (d_ax, r_ax) = plt.subplots(nrows=2)
Expand All @@ -118,46 +129,34 @@ def plot_operator_counts(self, figure=None, title=None, legend=None,
if title is not None:
figure.suptitle(title)

if legend is not None:
if len(legend) < 4:
raise ValueError("Legend not understood. Expected 4 items,"
" found {0}.".format(len(legend)))
else:
if legend is None:
legend = ["Best", "Better", "Accepted", "Rejected"]
elif len(legend) > 4:
raise ValueError("Legend not understood. Expected at most 4 items,"
" found {0}.".format(len(legend)))

self._plot_operator_counts(d_ax,
self.statistics.destroy_operator_counts,
"Destroy operators",
len(legend),
**kwargs)

self._plot_operator_counts(r_ax,
self.statistics.repair_operator_counts,
"Repair operators",
len(legend),
**kwargs)

# It is not really a problem if the legend is longer than four items,
# but we will only use the first four.
figure.legend(legend[:4], ncol=4, loc="lower center")
figure.legend(legend, ncol=len(legend), loc="lower center")

plt.draw_if_interactive()

@staticmethod
def _plot_operator_counts(ax, operator_counts, title, **kwargs):
def _plot_operator_counts(ax, operator_counts, title, num_types, **kwargs):
"""
Internal helper that plots the passed-in operator_counts on the given
ax object.
Parameters
----------
ax: Axes
An axes object, to be populated with data.
operator_counts : dict
A dictionary of operator counts.
title : str
Plot title.
**kwargs
Optional keyword arguments, to be passed to ``ax.barh``.
Note
----
This code takes loosely after an example from the matplotlib gallery
Expand All @@ -170,7 +169,7 @@ def _plot_operator_counts(ax, operator_counts, title, **kwargs):

ax.set_xlim(right=np.sum(operator_counts, axis=1).max())

for idx in range(4):
for idx in range(num_types):
widths = operator_counts[:, idx]
starts = cumulative_counts[:, idx] - widths

Expand Down
1 change: 1 addition & 0 deletions alns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .ALNS import ALNS
from .CallbackFlag import CallbackFlag
from .State import State
35 changes: 21 additions & 14 deletions alns/tests/test_alns.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ def objective(self):
return self._value


# CALLBACKS --------------------------------------------------------------------

def test_on_best_is_called():
"""
Tests if the callback is invoked when a new global best is found.
"""
alns = get_alns_instance([lambda state, rnd: Zero()],
[lambda state, rnd: Zero()])

# Called when a new global best is found. In this case, that happens once:
# in the only iteration below. It returns a state with value 10, which
# should then also be returned by the entire algorithm.
alns.on_best(lambda *args: ValueState(10))

result = alns.iterate(One(), [1, 1, 1, 1], .5, HillClimbing(), 1)
assert_equal(result.best_state.objective(), 10)

# OPERATORS --------------------------------------------------------------------


Expand Down Expand Up @@ -171,20 +188,6 @@ def test_raises_explosive_operator_decay():
alns.iterate(One(), [1, 1, 1, 1], 1.2, HillClimbing())


def test_raises_boundary_operator_decay():
"""
The boundary cases, zero and one, should both raise.
"""
alns = get_alns_instance([lambda state, rnd: None],
[lambda state, rnd: None])

with assert_raises(ValueError):
alns.iterate(One(), [1, 1, 1, 1], 0, HillClimbing())

with assert_raises(ValueError):
alns.iterate(One(), [1, 1, 1, 1], 1, HillClimbing())


def test_raises_insufficient_weights():
"""
We need (at least) four weights to be passed-in, one for each updating
Expand Down Expand Up @@ -239,6 +242,10 @@ def test_does_not_raise():

alns.iterate(Zero(), [1, 1, 1, 1], .5, HillClimbing(), 100)

# 0 and 1 are both acceptable decay parameters (since v1.2.0).
alns.iterate(Zero(), [1, 1, 1, 1], 0., HillClimbing(), 100)
alns.iterate(Zero(), [1, 1, 1, 1], 1., HillClimbing(), 100)


# EXAMPLES ---------------------------------------------------------------------

Expand Down
34 changes: 34 additions & 0 deletions alns/tests/test_callback_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from numpy.testing import assert_, assert_no_warnings, assert_warns

from alns import CallbackFlag
from alns.CallbackMixin import CallbackMixin
from alns.exceptions_warnings import OverwriteWarning


def dummy_callback():
return None


def test_insert_extraction_on_best():
"""
Tests if regular add/return callback works for ON_BEST.
"""
mixin = CallbackMixin()
mixin.on_best(dummy_callback)

assert_(mixin.has_callback(CallbackFlag.ON_BEST))
assert_(mixin.callback(CallbackFlag.ON_BEST) is dummy_callback)


def test_overwrite_warns_on_best():
"""
There can only be a single callback for each event point, so inserting two
(or more) should warn the previous callback for ON_BEST is overwritten.
"""
mixin = CallbackMixin()

with assert_no_warnings(): # first insert is fine..
mixin.on_best(dummy_callback)

with assert_warns(OverwriteWarning): # .. but second insert should warn
mixin.on_best(dummy_callback)
Loading

0 comments on commit c669e90

Please sign in to comment.