-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Title argument for Result.plot_objective * Call-back mixin for a new global best * Warn when overwriting callback * Flexible operator_counts plotting * Operator decay parameters in [0, 1] * Tests
- Loading branch information
Showing
9 changed files
with
194 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from enum import IntEnum, unique | ||
|
||
|
||
@unique | ||
class CallbackFlag(IntEnum): | ||
""" | ||
Callback flags for the mix-in. | ||
""" | ||
ON_BEST = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import warnings | ||
|
||
from .CallbackFlag import CallbackFlag | ||
from .exceptions_warnings import OverwriteWarning | ||
|
||
|
||
class CallbackMixin: | ||
|
||
def __init__(self): | ||
""" | ||
Callback mix-in for ALNS. This allows for some flexibility by having | ||
ALNS call custom functions whenever a special event happens. | ||
""" | ||
self._callbacks = {} | ||
|
||
def on_best(self, func): | ||
""" | ||
Sets a callback function to be called when ALNS finds a new global best | ||
solution state. | ||
Parameters | ||
---------- | ||
func : callable | ||
A function that should take a solution State as its first parameter, | ||
and a numpy RandomState as its second (cf. the operator signature). | ||
It should return a (new) solution State. | ||
Warns | ||
----- | ||
OverwriteWarning | ||
When a callback has already been set for the ON_BEST flag. | ||
""" | ||
self._set_callback(CallbackFlag.ON_BEST, func) | ||
|
||
def has_callback(self, flag): | ||
""" | ||
Determines if a callable has been set for the passed-in flag. | ||
Parameters | ||
---------- | ||
flag : CallbackFlag | ||
Returns | ||
------- | ||
bool | ||
True if a callable is set, False otherwise. | ||
""" | ||
return flag in self._callbacks | ||
|
||
def callback(self, flag): | ||
""" | ||
Returns the callback for the passed-in flag, assuming it exists. | ||
Parameters | ||
---------- | ||
flag : CallbackFlag | ||
The callback flag for which to retrieve a callback. | ||
Returns | ||
------- | ||
callable | ||
Callback for the passed-in flag. | ||
""" | ||
return self._callbacks[flag] | ||
|
||
def _set_callback(self, flag, func): | ||
""" | ||
Sets the passed-in callback func for the passed-in flag. Warns if this | ||
would overwrite an existing callback. | ||
""" | ||
if self.has_callback(flag): | ||
warnings.warn("A callback function has already been set for the" | ||
" `{0}' flag. This callback will now be replaced by" | ||
" the newly passed-in callback.".format(flag), | ||
OverwriteWarning) | ||
|
||
self._callbacks[flag] = func |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .ALNS import ALNS | ||
from .CallbackFlag import CallbackFlag | ||
from .State import State |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from numpy.testing import assert_, assert_no_warnings, assert_warns | ||
|
||
from alns import CallbackFlag | ||
from alns.CallbackMixin import CallbackMixin | ||
from alns.exceptions_warnings import OverwriteWarning | ||
|
||
|
||
def dummy_callback(): | ||
return None | ||
|
||
|
||
def test_insert_extraction_on_best(): | ||
""" | ||
Tests if regular add/return callback works for ON_BEST. | ||
""" | ||
mixin = CallbackMixin() | ||
mixin.on_best(dummy_callback) | ||
|
||
assert_(mixin.has_callback(CallbackFlag.ON_BEST)) | ||
assert_(mixin.callback(CallbackFlag.ON_BEST) is dummy_callback) | ||
|
||
|
||
def test_overwrite_warns_on_best(): | ||
""" | ||
There can only be a single callback for each event point, so inserting two | ||
(or more) should warn the previous callback for ON_BEST is overwritten. | ||
""" | ||
mixin = CallbackMixin() | ||
|
||
with assert_no_warnings(): # first insert is fine.. | ||
mixin.on_best(dummy_callback) | ||
|
||
with assert_warns(OverwriteWarning): # .. but second insert should warn | ||
mixin.on_best(dummy_callback) |
Oops, something went wrong.