Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 0 additions & 75 deletions src/sandbox/main.py

This file was deleted.

9 changes: 6 additions & 3 deletions src/sandbox/utils/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@


class Comparator:
def __init__(self, algorithms: list[Algorithm], envs: list[gym.Env], get_label: Callable[[Algorithm], str],
             live_plotting: bool = True) -> None:
    """Remember what to compare and how to present it.

    Args:
        algorithms: Algorithms that ``run`` executes on every environment.
        envs: Environments each algorithm is run against.
        get_label: Maps an algorithm to the text used for its legend entry.
        live_plotting: Passed on to ``StatsWrapper`` when ``run`` wraps an
            environment; defaults to ``True``.
    """
    self.algorithms = algorithms
    self.envs = envs
    self.get_label = get_label
    self.live_plotting = live_plotting

def run(self, plot_types: list[PlotType]):
_, axs = plt.subplots(len(self.envs), len(plot_types), figsize=(10, 10), squeeze=False)
Expand All @@ -20,9 +22,10 @@ def run(self, plot_types: list[PlotType]):
env_axs = axs[i]
for algo, color in zip(self.algorithms, algo_colors):
env = deepcopy(env)
env = StatsWrapper(env)
_ = algo.run(5000, env)
env = StatsWrapper(env, self.live_plotting)
_ = algo.run(100, env)
env.plot(types=plot_types, ax=env_axs, color=color)
for ax in env_axs:
ax.legend([self.get_label(a) for a in self.algorithms])
plt.ioff()
plt.show()
1 change: 1 addition & 0 deletions src/sandbox/wrappers/discrete_env_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gym
from gym.core import ObsType, ActType


class DiscreteEnvironment(gym.Wrapper[ObsType, ActType]):

@property
Expand Down
40 changes: 28 additions & 12 deletions src/sandbox/wrappers/stats_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,46 @@
import gym
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import style
import numpy as np
import matplotlib
matplotlib.use("TkAgg")

@dataclass
class Statistic:
    """Running totals for a single episode.

    Attributes:
        episode_reward: Sum of the step rewards accumulated so far.
        steps_count: Number of steps accumulated so far.
    """

    episode_reward: float
    steps_count: int

    def increment(self, step_reward, length_increment):
        """Return a new ``Statistic`` advanced by one update.

        The receiver is not modified.

        Args:
            step_reward: Reward to add to ``episode_reward``.
            length_increment: Amount to add to ``steps_count``.

        Returns:
            A fresh ``Statistic`` holding the updated totals.
        """
        # Exactly one argument per field: passing the reward sum twice would
        # hand three positionals to a two-field dataclass and raise TypeError.
        return Statistic(
            self.episode_reward + step_reward,
            self.steps_count + length_increment,
        )


# Key name for per-episode statistics. Its consumers are not visible in this
# excerpt (presumably it tags an entry in the env `info` dict — TODO confirm).
STATS_KEY = 'episode_stats'


class PlotType(IntEnum):
    """Kinds of plot a caller can request."""

    # Values are written out explicitly; they are the same integers that
    # auto() assigns to an IntEnum (1, 2, 3, ... in declaration order).
    RewardsVsEpNumber = 1
    EpisodeLengthvsTime = 2
    EpisodeLengthHist = 3
    CumulatedReward = 4


class StatsWrapper(gym.Wrapper):

def __init__(self, env: gym.Env, real_time: bool = True):
    """Wrap *env* and start with an empty episode history.

    Args:
        env: Environment to wrap.
        real_time: When true, switch matplotlib to interactive mode and
            create a dedicated figure whose axes ``_animate`` draws on.
            When false, no figure is created and ``self.ax`` is not set.
    """
    super().__init__(env)
    self.stats: list[Statistic] = []
    self._current_statistic = Statistic(0, 0)
    self._real_time = real_time
    if real_time:
        plt.ion()
        # Only the axes are used later, so the figure handle is not kept.
        _, self.ax = plt.subplots(figsize=(12, 7), constrained_layout=True)


def step(self, action):
observation, reward, done, info = super().step(action)
Expand All @@ -44,19 +57,26 @@ def step(self, action):

def reset(self, **kwargs):
    """Archive the current episode's statistic, then reset the wrapped env.

    The statistic is appended on every call, so a reset issued before any
    step records the initial ``Statistic(0, 0)`` as well.

    Args:
        **kwargs: Forwarded unchanged to the parent ``reset``.

    Returns:
        Whatever the parent ``reset`` returns.
    """
    finished = self._current_statistic
    self.stats.append(finished)
    # Refresh the live plot only when real-time plotting was requested.
    if self._real_time:
        self._animate()
    logging.info(f'Episode stats: {finished}')
    # Start the next episode from zeroed totals.
    self._current_statistic = Statistic(0, 0)
    return super().reset(**kwargs)

def plot(self, types: PlotType = None, ax: list[Axes]=None, color=None):

def _animate(self):
    """Redraw the live plot of the episode rewards recorded so far."""
    rewards = [s.episode_reward for s in self.stats]
    # Remove the curve from the previous call first. Without this, every
    # call stacks another full-history line on the axes, so the number of
    # artists (and the redraw cost) grows with each episode.
    self.ax.clear()
    self.ax.plot(rewards, color='orange')
    # Give the GUI event loop a moment to render the update.
    plt.pause(0.05)


def plot(self, types: PlotType = None, ax: list[Axes] = None, color=None):
types = types or list(PlotType)
episode_rewards = [s.episode_reward for s in self.stats]
steps_count = [s.steps_count for s in self.stats]
if ax is None:
ax = plt.subplots(figsize=(10, 10),
nrows=len(types),
ncols=1,
constrained_layout=True, squeeze=False)[1]
nrows=len(types),
ncols=1,
constrained_layout=True, squeeze=False)[1]

for i, type in enumerate(types):
match type:
Expand Down Expand Up @@ -94,7 +114,3 @@ def plot(self, types: PlotType = None, ax: list[Axes]=None, color=None):
'-',
c=color or 'orange',
linewidth=2)




1 change: 1 addition & 0 deletions src/tasks/1-bandits.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from random import random

from sandbox.action_selection_rules.greedy import GreedyActionSelection
from sandbox.enviroments.multi_armed_bandit.env import NormalDistribution
from sandbox.utils.comparator import Comparator
Expand Down