diff --git a/src/sandbox/main.py b/src/sandbox/main.py
deleted file mode 100755
index 55b3ace..0000000
--- a/src/sandbox/main.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from copy import deepcopy
-from itertools import product
-import logging
-from random import randint, random, seed
-import sys
-from pathlib import Path
-from typing import Callable, Iterable
-from matplotlib import pyplot as plt
-import numpy as np
-import gym
-
-
-path = Path(__file__)
-sys.path.append(str(path.parents[1].absolute()))
-from sandbox.action_selection_rules.ucb import UCB
-
-from sandbox.algorithms.algorithm import Algorithm
-import sandbox.enviroments
-from sandbox.action_selection_rules.epsilon_greedy import EpsilonGreedyActionSelection
-from sandbox.algorithms.dqn import DQNAlgorithm, MyQNetwork, policy
-from sandbox.algorithms.q_learning.qlearning import QLearning
-from sandbox.wrappers.discrete_env_wrapper import DiscreteEnvironment
-from sandbox.wrappers.stats_wrapper import PlotType, StatsWrapper
-from sandbox.algorithms.bandits_algorithm.bandits_algorithm import BanditsAlgorithm
-from sandbox.enviroments.multi_armed_bandit.env import NormalDistribution
-
-
-class Comparer:
-    def __init__(self, algorithms: list[Algorithm], envs: list[gym.Env], get_label: Callable[[Algorithm], str]) -> None:
-        self.algorithms = algorithms
-        self.envs = envs
-        self.get_label = get_label
-
-    def run(self, plot_types: list[PlotType]):
-        _, axs = plt.subplots(len(self.envs), len(plot_types), figsize=(10, 10), squeeze=False)
-        algo_colors = [(random(), random(), random()) for _ in self.algorithms]
-        for i, env in enumerate(self.envs):
-            env_axs = axs[i]
-            for algo, color in zip(self.algorithms, algo_colors):
-                env = deepcopy(env)
-                env = StatsWrapper(env)
-                _ = algo.run(5000, env)
-                env.plot(types=plot_types, ax=env_axs, color=color)
-            for ax in env_axs:
-                ax.legend([self.get_label(a) for a in self.algorithms])
-        plt.show()
-
-
-def main():
-    # NOTE: change logging level to info if you don't want to see ansi renders of env
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s\n%(message)s')
-    cmp = Comparer(
-        algorithms=[BanditsAlgorithm(UCB()), BanditsAlgorithm(EpsilonGreedyActionSelection(0.01))],
-        envs=[gym.make(
-            "custom/multiarmed-bandits-v0",
-            reward_distributions=[NormalDistribution(random(), random()) for _ in range(5)]
-            )
-        ],
-        get_label=lambda algo: type(algo._select_action).__name__
-    )
-    cmp.run([PlotType.RewardsVsEpNumber, PlotType.EpisodeLengthHist])
-
-def enjoy(env, action_selection, steps) -> None:
-    state = env.reset()
-    for step in range(steps):
-        action = action_selection(state)
-        state, reward, done, info = env.step(action)
-        if done:
-            state = env.reset()
-        img = env.render()
-
-
-
-if __name__ == "__main__":
-    main()
diff --git a/src/sandbox/utils/comparator.py b/src/sandbox/utils/comparator.py
index a86d824..2e113fd 100644
--- a/src/sandbox/utils/comparator.py
+++ b/src/sandbox/utils/comparator.py
@@ -8,10 +8,12 @@ class Comparator:
-    def __init__(self, algorithms: list[Algorithm], envs: list[gym.Env], get_label: Callable[[Algorithm], str]) -> None:
+    def __init__(self, algorithms: list[Algorithm], envs: list[gym.Env], get_label: Callable[[Algorithm], str],
+                 live_plotting: bool = True) -> None:
         self.algorithms = algorithms
         self.envs = envs
         self.get_label = get_label
+        self.live_plotting = live_plotting
 
     def run(self, plot_types: list[PlotType]):
         _, axs = plt.subplots(len(self.envs), len(plot_types), figsize=(10, 10), squeeze=False)
@@ -20,9 +22,10 @@ def run(self, plot_types: list[PlotType]):
             env_axs = axs[i]
             for algo, color in zip(self.algorithms, algo_colors):
                 env = deepcopy(env)
-                env = StatsWrapper(env)
-                _ = algo.run(5000, env)
+                env = StatsWrapper(env, self.live_plotting)
+                _ = algo.run(100, env)
                 env.plot(types=plot_types, ax=env_axs, color=color)
             for ax in env_axs:
                 ax.legend([self.get_label(a) for a in self.algorithms])
+        plt.ioff()
         plt.show()
diff --git a/src/sandbox/wrappers/discrete_env_wrapper.py b/src/sandbox/wrappers/discrete_env_wrapper.py
index 9b6bf4c..649d3eb 100644
--- a/src/sandbox/wrappers/discrete_env_wrapper.py
+++ b/src/sandbox/wrappers/discrete_env_wrapper.py
@@ -1,6 +1,7 @@
 import gym
 from gym.core import ObsType, ActType
 
+
 class DiscreteEnvironment(gym.Wrapper[ObsType, ActType]):
 
     @property
diff --git a/src/sandbox/wrappers/stats_wrapper.py b/src/sandbox/wrappers/stats_wrapper.py
index 0305c6b..84d6232 100644
--- a/src/sandbox/wrappers/stats_wrapper.py
+++ b/src/sandbox/wrappers/stats_wrapper.py
@@ -6,33 +6,46 @@
 import gym
 from matplotlib.axes import Axes
 import matplotlib.pyplot as plt
+from matplotlib.animation import FuncAnimation
+from matplotlib import style
 import numpy as np
+import matplotlib
+matplotlib.use("TkAgg")
 
 
 @dataclass
 class Statistic:
     episode_reward: float
     steps_count: int
-    
+
     def increment(self, step_reward, length_increment):
         return Statistic(
-            self.episode_reward + step_reward, 
+            self.episode_reward + step_reward,
            self.steps_count + length_increment
         )
 
 
 STATS_KEY = 'episode_stats'
 
+
 class PlotType(IntEnum):
     RewardsVsEpNumber = auto()
     EpisodeLengthvsTime = auto()
     EpisodeLengthHist = auto()
     CumulatedReward = auto()
 
+
+
 class StatsWrapper(gym.Wrapper):
-    def __init__(self, env: gym.Env):
+    def __init__(self, env: gym.Env, real_time=True):
         super().__init__(env)
         self.stats: list[Statistic] = []
         self._current_statistic = Statistic(0, 0)
+        self._real_time = real_time
+        if real_time:
+            plt.ion()
+            fig, self.ax = plt.subplots(figsize=(12, 7),
+                                        constrained_layout=True)
+
     def step(self, action):
         observation, reward, done, info = super().step(action)
@@ -44,19 +57,26 @@ def step(self, action):
 
     def reset(self, **kwargs):
         self.stats.append(self._current_statistic)
+        if self._real_time: self._animate()
         logging.info(f'Episode stats: {self._current_statistic}')
         self._current_statistic = Statistic(0, 0)
         return super().reset(**kwargs)
-    
-    def plot(self, types: PlotType = None, ax: list[Axes]=None, color=None):
+
+    def _animate(self):
+        x = [s.episode_reward for s in self.stats]
+        self.ax.plot(x, color='orange')
+        plt.pause(0.05)
+
+
+    def plot(self, types: PlotType = None, ax: list[Axes] = None, color=None):
         types = types or list(PlotType)
         episode_rewards = [s.episode_reward for s in self.stats]
         steps_count = [s.steps_count for s in self.stats]
         if ax is None:
             ax = plt.subplots(figsize=(10, 10),
-                            nrows=len(types),
-                            ncols=1,
-                            constrained_layout=True, squeeze=False)[1]
+                              nrows=len(types),
+                              ncols=1,
+                              constrained_layout=True, squeeze=False)[1]
 
         for i, type in enumerate(types):
             match type:
@@ -94,7 +114,3 @@ def plot(self, types: PlotType = None, ax: list[Axes]=None, color=None):
                         '-', c=color or 'orange', linewidth=2)
-
-
-
-
diff --git a/src/tasks/1-bandits.py b/src/tasks/1-bandits.py
index 9989b63..bb352c1 100644
--- a/src/tasks/1-bandits.py
+++ b/src/tasks/1-bandits.py
@@ -1,5 +1,6 @@
 import logging
 from random import random
+
 from sandbox.action_selection_rules.greedy import GreedyActionSelection
 from sandbox.enviroments.multi_armed_bandit.env import NormalDistribution
 from sandbox.utils.comparator import Comparator
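
Usage sketch (illustrative, not part of the patch): assuming the package layout shown in the deleted main.py above, the snippet below exercises the live_plotting flag added to Comparator and, through it, the new real_time behaviour of StatsWrapper. Module paths, the custom env id, and constructor arguments are taken from the diff; the epsilon value and the number of arms are arbitrary.

# Illustrative sketch only -- mirrors the deleted main.py and drives the new
# live_plotting flag introduced in comparator.py by this patch.
import logging
from random import random

import gym

import sandbox.enviroments  # imported for side effects in the old main.py (env registration, presumably)
from sandbox.action_selection_rules.epsilon_greedy import EpsilonGreedyActionSelection
from sandbox.action_selection_rules.ucb import UCB
from sandbox.algorithms.bandits_algorithm.bandits_algorithm import BanditsAlgorithm
from sandbox.enviroments.multi_armed_bandit.env import NormalDistribution
from sandbox.utils.comparator import Comparator
from sandbox.wrappers.stats_wrapper import PlotType

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s\n%(message)s')

cmp = Comparator(
    algorithms=[BanditsAlgorithm(UCB()), BanditsAlgorithm(EpsilonGreedyActionSelection(0.01))],
    envs=[gym.make(
        "custom/multiarmed-bandits-v0",
        reward_distributions=[NormalDistribution(random(), random()) for _ in range(5)],
    )],
    get_label=lambda algo: type(algo._select_action).__name__,
    live_plotting=False,  # new flag from this patch; skips the interactive TkAgg updates during training
)
cmp.run([PlotType.RewardsVsEpNumber, PlotType.EpisodeLengthHist])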