-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
MABSelector
operator selection class to allow using MABWiser to…
… select operators (#153) * Add MABSelector to docs * Add tests and extra validation to MABSelector * Remove old example of solving TSP with MABSelector * Fix typos/clarify prose in docstrings * Make op2arm methods functions instead * Make MABWiser an optional dependency * Gracefully handle missing mabwiser dependency * Add tests for output of MABSelector * Add MABSelector to "ALNS Features" example * Add mabwiser to deps groups for docs and examples * Add finalized API for contextual states * Add contextual bandit example to example notebook * Add tests for contextual MABs * Add mabwiser as a required dep for development * Add fixes from review * Update test to check for new exception
- Loading branch information
Showing
8 changed files
with
575 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import numpy as np | ||
from numpy.random import RandomState | ||
|
||
from alns.Outcome import Outcome | ||
from alns.State import ContextualState | ||
from alns.select.OperatorSelectionScheme import OperatorSelectionScheme | ||
|
||
MABWISER_AVAILABLE = True | ||
try: | ||
from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy | ||
except ModuleNotFoundError: | ||
MABWISER_AVAILABLE = False | ||
|
||
|
||
class MABSelector(OperatorSelectionScheme):
    """
    A selector that uses any multi-armed-bandit algorithm from MABWiser.

    This selector is a wrapper around the many multi-armed bandit algorithms
    available in the `MABWiser <https://github.com/fidelity/mabwiser>`_
    library. Since ALNS operator selection can be framed as a
    multi-armed-bandit problem (where each [destroy, repair] operator pair is
    a bandit arm), this wrapper allows you to use a variety of existing
    multi-armed-bandit algorithms as operator selectors instead of
    having to reimplement them.

    Note that if the provided learning policy is a contextual bandit
    algorithm, your state class must provide a `get_context` function that
    returns a context vector for the current state.

    Parameters
    ----------
    scores
        A list of four non-negative elements, representing the rewards when the
        candidate solution results in a new global best (idx 0), is better than
        the current solution (idx 1), the solution is accepted (idx 2), or
        rejected (idx 3).
    num_destroy
        Number of destroy operators.
    num_repair
        Number of repair operators.
    learning_policy
        A MABWiser learning policy that acts as an operator selector. See the
        MABWiser documentation for a list of available learning policies.
    neighborhood_policy
        The neighborhood policy that MABWiser should use. Only available for
        contextual learning policies. See the MABWiser documentation for a
        list of available neighborhood policies.
    seed
        A seed that will be passed to the underlying MABWiser object.
    op_coupling
        Optional boolean matrix that indicates coupling between destroy and
        repair operators. Entry (i, j) is True if destroy operator i can be
        used together with repair operator j, and False otherwise.
    kwargs
        Any additional keyword arguments. These are forwarded unchanged to
        the ``MAB`` constructor.

    Raises
    ------
    ImportError
        When the MABWiser library could not be imported.
    ValueError
        When any score is negative, or fewer than four scores are given.

    References
    ----------
    .. [1] Emily Strong, Bernard Kleynhans, & Serdar Kadioglu (2021).
           MABWiser: Parallelizable Contextual Multi-armed Bandits.
           Int. J. Artif. Intell. Tools, 30(4), 2150021:1–2150021:19.
    """

    def __init__(
        self,
        scores: List[float],
        num_destroy: int,
        num_repair: int,
        learning_policy: "LearningPolicy",
        neighborhood_policy: Optional["NeighborhoodPolicy"] = None,
        seed: Optional[int] = None,
        op_coupling: Optional[np.ndarray] = None,
        **kwargs,
    ):
        if not MABWISER_AVAILABLE:
            raise ImportError("MABSelector requires the MABWiser library. ")

        super().__init__(num_destroy, num_repair, op_coupling)

        if any(score < 0 for score in scores):
            raise ValueError("Negative scores are not understood.")

        if len(scores) < 4:
            # More than four is OK because we only use the first four.
            raise ValueError(f"Expected four scores, found {len(scores)}")

        # Forward the seed argument only when one was given, so that an
        # absent seed does not override whatever default MAB uses.
        if seed is not None:
            kwargs["seed"] = seed

        # The set of valid operator pairs (arms) is equal to the cartesian
        # product of destroy and repair operators, except we leave out any
        # pairs disallowed by op_coupling. Arm names are produced by ops2arm
        # so that they always match the names later passed to partial_fit.
        arms = [
            ops2arm(d_idx, r_idx)
            for d_idx in range(num_destroy)
            for r_idx in range(num_repair)
            if self._op_coupling[d_idx, r_idx]
        ]
        self._mab = MAB(
            arms,
            learning_policy,
            neighborhood_policy,
            **kwargs,
        )
        self._scores = scores

    @property
    def scores(self) -> List[float]:
        """The reward scores passed at construction."""
        return self._scores

    @property
    def mab(self) -> "MAB":
        """The underlying MABWiser ``MAB`` object."""
        return self._mab

    def __call__(  # type: ignore[override]
        self,
        rnd_state: RandomState,
        best: ContextualState,
        curr: ContextualState,
    ) -> Tuple[int, int]:
        """
        Returns the (destroy, repair) operator pair from the underlying MAB
        strategy.

        Parameters
        ----------
        rnd_state
            Random state. Only used to draw an operator pair while the MAB
            object has not yet been fit on any observation.
        best
            The best solution state. Not used by this selector.
        curr
            The current solution state. Its ``get_context()`` is queried
            when the MAB object is contextual.

        Returns
        -------
        A tuple of (destroy, repair) operator indices.
        """
        # NOTE(review): ``_is_initial_fit`` is an underscore-prefixed
        # attribute of the MAB object; confirm it is stable across the
        # MABWiser versions that are supported.
        if self._mab._is_initial_fit:
            has_context = self._mab.is_contextual
            context = (
                np.atleast_2d(curr.get_context()) if has_context else None
            )
            prediction = self._mab.predict(contexts=context)
            return arm2ops(prediction)

        # This can happen when the MAB object has not yet been fit on any
        # observations. In that case we return any feasible operator index
        # pair as a first observation.
        allowed = np.argwhere(self._op_coupling)
        d_idx, r_idx = allowed[rnd_state.randint(len(allowed))]

        # Convert the NumPy integer scalars to plain ints so that both
        # branches return the same types.
        return int(d_idx), int(r_idx)

    def update(  # type: ignore[override]
        self,
        cand: ContextualState,
        d_idx: int,
        r_idx: int,
        outcome: Outcome,
    ):
        """
        Updates the underlying MAB algorithm given the reward of the chosen
        destroy and repair operator combination ``(d_idx, r_idx)``.

        Parameters
        ----------
        cand
            The candidate solution state. Its ``get_context()`` is queried
            when the MAB object is contextual.
        d_idx
            Index of the destroy operator that was applied.
        r_idx
            Index of the repair operator that was applied.
        outcome
            The evaluation outcome. It is used directly as an index into
            ``scores`` to obtain the reward.
        """
        has_context = self._mab.is_contextual
        context = np.atleast_2d(cand.get_context()) if has_context else None

        self._mab.partial_fit(
            [ops2arm(d_idx, r_idx)],
            [self._scores[outcome]],
            contexts=context,
        )
|
||
|
||
def ops2arm(destroy_idx: int, repair_idx: int) -> str:
    """
    Converts a pair of destroy and repair operator indices to the arm
    string under which that operator pair is passed to the MAB object.

    Parameters
    ----------
    destroy_idx
        Index of the destroy operator.
    repair_idx
        Index of the repair operator.

    Returns
    -------
    The arm name, formatted as ``"<destroy_idx>_<repair_idx>"``.

    Examples
    --------
    >>> ops2arm(0, 1)
    '0_1'
    >>> ops2arm(12, 3)
    '12_3'
    """
    return f"{destroy_idx}_{repair_idx}"
|
||
|
||
def arm2ops(arm: str) -> Tuple[int, int]:
    """
    Parses an arm string of the form ``"<destroy>_<repair>"`` back into
    the pair of destroy and repair operator indices it encodes.

    Parameters
    ----------
    arm
        The arm name. It must consist of exactly two integer fields
        separated by a single underscore.

    Returns
    -------
    A tuple of (destroy, repair) operator indices.

    Examples
    --------
    >>> arm2ops("0_1")
    (0, 1)
    >>> arm2ops("12_3")
    (12, 3)
    """
    # Unpacking into exactly two names rejects arms with more or fewer
    # than two fields.
    destroy_field, repair_field = arm.split("_")
    return (int(destroy_field), int(repair_field))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
from typing import List | ||
|
||
import numpy.random as rnd | ||
from mabwiser.mab import LearningPolicy, NeighborhoodPolicy | ||
from numpy.testing import assert_equal, assert_raises | ||
from pytest import mark | ||
|
||
from alns.Outcome import Outcome | ||
from alns.select import MABSelector | ||
from alns.select.MABSelector import arm2ops, ops2arm | ||
from alns.tests.states import Zero, ZeroWithOneContext, ZeroWithZeroContext | ||
|
||
|
||
@mark.parametrize(
    "destroy_idx, repair_idx",
    [(0, 0), (0, 1), (3, 3), (12, 7), (0, 14)],
)
def test_arm_conversion(destroy_idx, repair_idx):
    """
    Tests that converting an operator pair to an arm string and back
    returns the original pair.
    """
    arm = ops2arm(destroy_idx, repair_idx)
    round_tripped = arm2ops(arm)

    assert_equal(round_tripped, (destroy_idx, repair_idx))
|
||
|
||
def test_does_not_raise_on_valid_mab():
    """
    Tests that constructing a selector with valid arguments does not raise,
    for each supported way of passing the neighborhood policy and seed.
    """

    def make_learning_policy():
        return LearningPolicy.EpsilonGreedy(0.15)

    def make_neighborhood_policy():
        return NeighborhoodPolicy.Radius(5)

    # Learning policy only.
    MABSelector([0, 0, 0, 0], 2, 1, make_learning_policy())

    # Learning policy and neighborhood policy.
    MABSelector(
        [0, 0, 0, 0],
        2,
        1,
        make_learning_policy(),
        make_neighborhood_policy(),
    )

    # Neighborhood policy and seed, both passed positionally.
    MABSelector(
        [0, 0, 0, 0],
        2,
        1,
        make_learning_policy(),
        make_neighborhood_policy(),
        1234567,
    )

    # Seed passed by keyword, without a neighborhood policy.
    MABSelector([0, 0, 0, 0], 2, 1, make_learning_policy(), seed=1234567)
|
||
|
||
@mark.parametrize(
    "scores, learning_policy, num_destroy, num_repair",
    [
        # A negative score must be rejected.
        ([5, 3, 2, -1], LearningPolicy.EpsilonGreedy(0.15), 1, 1),
        # Fewer than four scores must be rejected.
        ([5, 3, 2], LearningPolicy.EpsilonGreedy(0.15), 1, 1),
    ],
)
def test_raises_invalid_arguments(
    scores: List[float],
    learning_policy: LearningPolicy,
    num_destroy: int,
    num_repair: int,
):
    """
    Tests that the constructor raises a ValueError on invalid scores.
    """
    with assert_raises(ValueError):
        MABSelector(
            scores=scores,
            num_destroy=num_destroy,
            num_repair=num_repair,
            learning_policy=learning_policy,
        )
|
||
|
||
def test_call_with_only_one_operator_pair():
    """
    Tests that the selector always returns (0, 0) when there is a single
    destroy operator and a single repair operator.
    """
    selector = MABSelector(
        [2, 1, 1, 0], 1, 1, LearningPolicy.EpsilonGreedy(0.15)
    )
    rnd_state = rnd.RandomState()

    # Only one operator pair exists, so every call must select it.
    for _ in range(10):
        assert_equal(selector(rnd_state, Zero(), Zero()), (0, 0))
|
||
|
||
def test_mab_epsilon_greedy():
    """
    Tests that an epsilon-greedy selector with epsilon=0 keeps selecting
    the operator pair that was rewarded with the better outcome.
    """
    rnd_state = rnd.RandomState()

    # epsilon=0 is equivalent to greedy selection
    selector = MABSelector(
        [2, 1, 1, 0], 2, 1, LearningPolicy.EpsilonGreedy(0.0)
    )

    selector.update(Zero(), 0, 0, outcome=Outcome.BETTER)

    # NOTE(review): the result of this first call is never checked; it
    # looks redundant with the loop below.
    selector(rnd_state, Zero(), Zero())

    for _ in range(10):
        assert_equal(selector(rnd_state, Zero(), Zero()), (0, 0))

    # After pair (1, 0) is rewarded with the BEST outcome, that pair is
    # expected to be selected instead.
    selector.update(Zero(), 1, 0, outcome=Outcome.BEST)

    for _ in range(10):
        assert_equal(selector(rnd_state, Zero(), Zero()), (1, 0))
|
||
|
||
@mark.parametrize("alpha", [0.25, 0.5])
def test_mab_ucb1(alpha):
    """
    Tests that a UCB1 selector returns pair (0, 0) after that pair has been
    updated with a BEST outcome, and again after a REJECT outcome.
    """
    selector = MABSelector([2, 1, 1, 0], 2, 1, LearningPolicy.UCB1(alpha))
    rnd_state = rnd.RandomState()

    for outcome in (Outcome.BEST, Outcome.REJECT):
        selector.update(Zero(), 0, 0, outcome=outcome)
        assert_equal(selector(rnd_state, Zero(), Zero()), (0, 0))
|
||
|
||
def test_contextual_mab_requires_context():
    """
    Tests that updating a contextual selector with a state that lacks a
    ``get_context`` method raises an AttributeError.
    """
    selector = MABSelector([2, 1, 1, 0], 2, 1, LearningPolicy.LinGreedy(0))

    # error: "Zero" state has no get_context method
    with assert_raises(AttributeError):
        selector.update(Zero(), 0, 0, outcome=Outcome.BEST)
|
||
|
||
def test_contextual_mab_uses_context():
    """
    Tests that a contextual selector bases its choice on the context of the
    current state: each context should select the operator pair that was
    rewarded with the BEST outcome under that context.

    This function was previously named ``text_...``, which does not match
    pytest's default ``test`` prefix, so it was never collected.
    """
    # NOTE(review): this test has not been collected before, and it depends
    # on the context vectors of the two state fixtures, which are defined
    # outside this module. Confirm that it passes.
    state = rnd.RandomState()
    select = MABSelector(
        [2, 1, 1, 0],
        2,
        1,
        # epsilon=0 is equivalent to greedy
        LearningPolicy.LinGreedy(0),
    )

    # Under the first context, pair (1, 0) gets the BEST outcome.
    select.update(ZeroWithZeroContext(), 0, 0, outcome=Outcome.REJECT)
    select.update(ZeroWithZeroContext(), 0, 0, outcome=Outcome.REJECT)
    select.update(ZeroWithZeroContext(), 1, 0, outcome=Outcome.BEST)

    # Under the second context, pair (0, 0) gets the BEST outcome.
    select.update(ZeroWithOneContext(), 1, 0, outcome=Outcome.REJECT)
    select.update(ZeroWithOneContext(), 1, 0, outcome=Outcome.REJECT)
    select.update(ZeroWithOneContext(), 0, 0, outcome=Outcome.BEST)

    mab_select = select(state, ZeroWithZeroContext(), ZeroWithZeroContext())
    assert_equal(mab_select, (1, 0))

    # Query with the second context here. Querying the first context again
    # could not yield a different pair from a greedy policy.
    mab_select = select(state, ZeroWithOneContext(), ZeroWithOneContext())
    assert_equal(mab_select, (0, 0))
Oops, something went wrong.