Add MABSelector operator selection class to allow using MABWiser to select operators (#153)

* Add MABSelector to docs

* Add tests and extra validation to MABSelector

* Remove old example of solving TSP with MABSelector

* Fix typos/clarify prose in docstrings

* Make op2arm methods functions instead

* Make MABWiser an optional dependency

* Gracefully handle missing mabwiser dependency

* Add tests for output of MABSelector

* Add MABSelector to "ALNS Features" example

* Add mabwiser to deps groups for docs and examples

* Add finalized API for contextual states

* Add contextual bandit example to example notebook

* Add tests for contextual MABs

* Add mabwiser as a required dep for development

* Add fixes from review

* Update test to check for new exception
P-bibs authored Jun 10, 2023
1 parent 6bb423d commit 963080d
Showing 8 changed files with 575 additions and 27 deletions.
20 changes: 20 additions & 0 deletions alns/State.py
@@ -1,5 +1,7 @@
from typing import Protocol

import numpy as np


class State(Protocol):
"""
@@ -11,3 +13,21 @@ def objective(self) -> float:
"""
Computes the state's associated objective value.
"""


class ContextualState(Protocol):
"""
Protocol for a solution state that also provides context. Solutions should
define an ``objective()`` function as well as a ``get_context()``
function.
"""

def objective(self) -> float:
"""
Computes the state's associated objective value.
"""

def get_context(self) -> np.ndarray:
"""
Computes a context vector for the current state.
"""
191 changes: 191 additions & 0 deletions alns/select/MABSelector.py
@@ -0,0 +1,191 @@
from typing import List, Optional, Tuple

import numpy as np
from numpy.random import RandomState

from alns.Outcome import Outcome
from alns.State import ContextualState
from alns.select.OperatorSelectionScheme import OperatorSelectionScheme

MABWISER_AVAILABLE = True
try:
from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy
except ModuleNotFoundError:
MABWISER_AVAILABLE = False


class MABSelector(OperatorSelectionScheme):
"""
A selector that uses any multi-armed-bandit algorithm from MABWiser.
This selector is a wrapper around the many multi-armed bandit algorithms
available in the `MABWiser <https://github.com/fidelity/mabwiser>`_
library. Since ALNS operator selection can be framed as a
multi-armed-bandit problem (where each [destroy, repair] operator pair is
a bandit arm), this wrapper allows you to use a variety of existing
multi-armed-bandit algorithms as operator selectors instead of
having to reimplement them.

Note that if the provided learning policy is a contextual bandit
algorithm, your state class must provide a ``get_context`` function that
returns a context vector for the current state.

Parameters
----------
scores
A list of four non-negative elements, representing the rewards when the
candidate solution results in a new global best (idx 0), is better than
the current solution (idx 1), is accepted (idx 2), or is rejected
(idx 3).
num_destroy
Number of destroy operators.
num_repair
Number of repair operators.
learning_policy
A MABWiser learning policy that acts as an operator selector. See the
MABWiser documentation for a list of available learning policies.
neighborhood_policy
The neighborhood policy that MABWiser should use. Only available for
contextual learning policies. See the MABWiser documentation for a
list of available neighborhood policies.
seed
A seed that will be passed to the underlying MABWiser object.
op_coupling
Optional boolean matrix that indicates coupling between destroy and
repair operators. Entry (i, j) is True if destroy operator i can be
used together with repair operator j, and False otherwise.

References
----------
.. [1] Emily Strong, Bernard Kleynhans, & Serdar Kadioglu (2021).
MABWiser: Parallelizable Contextual Multi-armed Bandits.
Int. J. Artif. Intell. Tools, 30(4), 2150021:1–2150021:19.
"""

def __init__(
self,
scores: List[float],
num_destroy: int,
num_repair: int,
learning_policy: "LearningPolicy",
neighborhood_policy: Optional["NeighborhoodPolicy"] = None,
seed: Optional[int] = None,
op_coupling: Optional[np.ndarray] = None,
**kwargs,
):
if not MABWISER_AVAILABLE:
raise ImportError("MABSelector requires the MABWiser library.")

super().__init__(num_destroy, num_repair, op_coupling)

if any(score < 0 for score in scores):
raise ValueError("Negative scores are not understood.")

if len(scores) < 4:
# More than four is OK because we only use the first four.
raise ValueError(f"Expected four scores, found {len(scores)}")

# forward the seed argument if not null
if seed is not None:
kwargs["seed"] = seed

# the set of valid operator pairs (arms) is equal to the cartesian
# product of destroy and repair operators, except we leave out any
# pairs disallowed by op_coupling
arms = [
f"{d_idx}_{r_idx}"
for d_idx in range(num_destroy)
for r_idx in range(num_repair)
if self._op_coupling[d_idx, r_idx]
]
self._mab = MAB(
arms,
learning_policy,
neighborhood_policy,
**kwargs,
)
self._scores = scores

@property
def scores(self) -> List[float]:
return self._scores

@property
def mab(self) -> "MAB":
return self._mab

def __call__( # type: ignore[override]
self,
rnd_state: RandomState,
best: ContextualState,
curr: ContextualState,
) -> Tuple[int, int]:
"""
Returns the (destroy, repair) operator pair from the underlying MAB
strategy.
"""
if self._mab._is_initial_fit:
has_context = self._mab.is_contextual
context = (
np.atleast_2d(curr.get_context()) if has_context else None
)
prediction = self._mab.predict(contexts=context)
return arm2ops(prediction)
else:
# This can happen when the MAB object has not yet been fit on any
# observations. In that case we return any feasible operator index
# pair as a first observation.
allowed = np.argwhere(self._op_coupling)
idx = rnd_state.randint(len(allowed))
return (allowed[idx][0], allowed[idx][1])

def update( # type: ignore[override]
self,
cand: ContextualState,
d_idx: int,
r_idx: int,
outcome: Outcome,
):
"""
Updates the underlying MAB algorithm given the reward of the chosen
destroy and repair operator combination ``(d_idx, r_idx)``.
"""
has_context = self._mab.is_contextual
context = np.atleast_2d(cand.get_context()) if has_context else None

self._mab.partial_fit(
[ops2arm(d_idx, r_idx)],
[self._scores[outcome]],
contexts=context,
)


def ops2arm(destroy_idx: int, repair_idx: int) -> str:
"""
Converts a tuple of destroy and repair operator indices to an arm
string that can be passed to the underlying MAB object.

Examples
--------
>>> ops2arm(0, 1)
'0_1'
>>> ops2arm(12, 3)
'12_3'
"""
return f"{destroy_idx}_{repair_idx}"


def arm2ops(arm: str) -> Tuple[int, int]:
"""
Converts an arm string returned by the underlying MAB object to a tuple
of destroy and repair operator indices.

Examples
--------
>>> arm2ops("0_1")
(0, 1)
>>> arm2ops("12_3")
(12, 3)
"""
[destroy, repair] = arm.split("_")
return int(destroy), int(repair)
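
As a usage illustration, the sketch below wires a MABSelector into the rest of the package. It assumes the standard ALNS workflow (construct an ALNS instance, register operators, then call iterate with an acceptance and stopping criterion); the destroy and repair operators and the initial state are placeholders, not part of this commit.

import numpy.random as rnd
from mabwiser.mab import LearningPolicy

from alns import ALNS
from alns.accept import HillClimbing
from alns.select import MABSelector
from alns.stop import MaxIterations

def my_destroy(state, rnd_state):  # placeholder destroy operator
    ...

def my_repair(state, rnd_state):  # placeholder repair operator
    ...

alns = ALNS(rnd.RandomState(seed=42))
alns.add_destroy_operator(my_destroy)
alns.add_repair_operator(my_repair)

# Rewards for: new global best, better than current, accepted, rejected.
select = MABSelector(
    scores=[5, 2, 1, 0.5],
    num_destroy=1,
    num_repair=1,
    learning_policy=LearningPolicy.EpsilonGreedy(epsilon=0.15),
)

# initial_state: a user-provided State (or ContextualState) instance.
result = alns.iterate(initial_state, select, HillClimbing(), MaxIterations(1000))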
1 change: 1 addition & 0 deletions alns/select/__init__.py
@@ -1,4 +1,5 @@
from .AlphaUCB import AlphaUCB
from .MABSelector import MABSelector
from .OperatorSelectionScheme import OperatorSelectionScheme
from .RandomSelect import RandomSelect
from .RouletteWheel import RouletteWheel
158 changes: 158 additions & 0 deletions alns/select/tests/test_mab.py
@@ -0,0 +1,158 @@
from typing import List

import numpy.random as rnd
from mabwiser.mab import LearningPolicy, NeighborhoodPolicy
from numpy.testing import assert_equal, assert_raises
from pytest import mark

from alns.Outcome import Outcome
from alns.select import MABSelector
from alns.select.MABSelector import arm2ops, ops2arm
from alns.tests.states import Zero, ZeroWithOneContext, ZeroWithZeroContext


@mark.parametrize(
"destroy_idx, repair_idx",
[
(0, 0),
(0, 1),
(3, 3),
(12, 7),
(0, 14),
],
)
def test_arm_conversion(destroy_idx, repair_idx):
expected = (destroy_idx, repair_idx)
actual = arm2ops(ops2arm(destroy_idx, repair_idx))

assert_equal(actual, expected)


def test_does_not_raise_on_valid_mab():
MABSelector([0, 0, 0, 0], 2, 1, LearningPolicy.EpsilonGreedy(0.15))
MABSelector(
[0, 0, 0, 0],
2,
1,
LearningPolicy.EpsilonGreedy(0.15),
NeighborhoodPolicy.Radius(5),
)
MABSelector(
[0, 0, 0, 0],
2,
1,
LearningPolicy.EpsilonGreedy(0.15),
NeighborhoodPolicy.Radius(5),
1234567,
)
MABSelector(
[0, 0, 0, 0], 2, 1, LearningPolicy.EpsilonGreedy(0.15), seed=1234567
)


@mark.parametrize(
"scores, learning_policy, num_destroy, num_repair",
[
(
[5, 3, 2, -1],
LearningPolicy.EpsilonGreedy(0.15),
1,
1,
), # negative score
(
[5, 3, 2],
LearningPolicy.EpsilonGreedy(0.15),
1,
1,
), # len(score) < 4
],
)
def test_raises_invalid_arguments(
scores: List[float],
learning_policy: LearningPolicy,
num_destroy: int,
num_repair: int,
):
with assert_raises(ValueError):
MABSelector(scores, num_destroy, num_repair, learning_policy)


def test_call_with_only_one_operator_pair():
# Only one operator pair, so the algorithm should select (0, 0).
select = MABSelector(
[2, 1, 1, 0], 1, 1, LearningPolicy.EpsilonGreedy(0.15)
)
state = rnd.RandomState()

for _ in range(10):
selected = select(state, Zero(), Zero())
assert_equal(selected, (0, 0))


def test_mab_epsilon_greedy():
state = rnd.RandomState()

# epsilon=0 is equivalent to greedy selection
select = MABSelector([2, 1, 1, 0], 2, 1, LearningPolicy.EpsilonGreedy(0.0))

select.update(Zero(), 0, 0, outcome=Outcome.BETTER)
selected = select(state, Zero(), Zero())
for _ in range(10):
selected = select(state, Zero(), Zero())
assert_equal(selected, (0, 0))

select.update(Zero(), 1, 0, outcome=Outcome.BEST)
for _ in range(10):
selected = select(state, Zero(), Zero())
assert_equal(selected, (1, 0))


@mark.parametrize("alpha", [0.25, 0.5])
def test_mab_ucb1(alpha):
state = rnd.RandomState()
select = MABSelector([2, 1, 1, 0], 2, 1, LearningPolicy.UCB1(alpha))

select.update(Zero(), 0, 0, outcome=Outcome.BEST)
mab_select = select(state, Zero(), Zero())
assert_equal(mab_select, (0, 0))

select.update(Zero(), 0, 0, outcome=Outcome.REJECT)
mab_select = select(state, Zero(), Zero())
assert_equal(mab_select, (0, 0))


def test_contextual_mab_requires_context():
select = MABSelector(
[2, 1, 1, 0],
2,
1,
LearningPolicy.LinGreedy(0),
)
# error: "Zero" state has no get_context method
with assert_raises(AttributeError):
select.update(Zero(), 0, 0, outcome=Outcome.BEST)


def test_contextual_mab_uses_context():
state = rnd.RandomState()
select = MABSelector(
[2, 1, 1, 0],
2,
1,
# epsilon=0 is equivalent to greedy
LearningPolicy.LinGreedy(0),
)

select.update(ZeroWithZeroContext(), 0, 0, outcome=Outcome.REJECT)
select.update(ZeroWithZeroContext(), 0, 0, outcome=Outcome.REJECT)
select.update(ZeroWithZeroContext(), 1, 0, outcome=Outcome.BEST)

select.update(ZeroWithOneContext(), 1, 0, outcome=Outcome.REJECT)
select.update(ZeroWithOneContext(), 1, 0, outcome=Outcome.REJECT)
select.update(ZeroWithOneContext(), 0, 0, outcome=Outcome.BEST)

mab_select = select(state, ZeroWithZeroContext(), ZeroWithZeroContext())
assert_equal(mab_select, (1, 0))

mab_select = select(state, ZeroWithOneContext(), ZeroWithOneContext())
assert_equal(mab_select, (0, 0))
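
The ZeroWithZeroContext and ZeroWithOneContext fixtures imported at the top of this test file are not shown in this diff; a plausible sketch, assuming they are zero-objective dummy states that differ only in the context vector they report, is:

import numpy as np

class ZeroWithZeroContext:
    def objective(self) -> float:
        return 0.0

    def get_context(self) -> np.ndarray:
        return np.array([0.0])

class ZeroWithOneContext:
    def objective(self) -> float:
        return 0.0

    def get_context(self) -> np.ndarray:
        return np.array([1.0])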