-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fe1c5f9
commit 1274942
Showing
3 changed files
with
256 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
from enum import Enum | ||
|
||
import gym | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from gym import spaces | ||
from gym.utils import seeding | ||
|
||
|
||
class Actions(Enum): | ||
Hold = 0 | ||
Buy = 1 | ||
Sell = 2 | ||
|
||
|
||
class Positions(Enum): | ||
Short = 0 | ||
Long = 1 | ||
|
||
def opposite(self): | ||
return Positions.Short if self == Positions.Long else Positions.Long | ||
|
||
|
||
class GymAnytrading(gym.Env): | ||
""" | ||
Based on https://github.com/AminHP/gym-anytrading | ||
""" | ||
|
||
metadata = {'render.modes': ['human']} | ||
|
||
def __init__(self, signal_features, prices, window_size, fee=0.0): | ||
assert signal_features.ndim == 2 | ||
|
||
self.seed() | ||
self.signal_features = signal_features | ||
self.prices = prices | ||
self.window_size = window_size | ||
self.fee = fee | ||
self.shape = (window_size, self.signal_features.shape[1]) | ||
|
||
# spaces | ||
self.action_space = spaces.Discrete(len(Actions)) | ||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32) | ||
|
||
# episode | ||
self._start_tick = self.window_size | ||
self._end_tick = len(self.prices) - 1 | ||
self._done = None | ||
self._current_tick = None | ||
self._last_trade_tick = None | ||
self._position = None | ||
self._position_history = None | ||
self._total_reward = None | ||
self._total_profit = None | ||
self._first_rendering = None | ||
self.history = None | ||
|
||
|
||
def seed(self, seed=None): | ||
self.np_random, seed = seeding.np_random(seed) | ||
return [seed] | ||
|
||
|
||
def reset(self): | ||
self._done = False | ||
self._current_tick = self._start_tick | ||
self._last_trade_tick = self._current_tick - 1 | ||
self._position = Positions.Short | ||
self._position_history = (self.window_size * [None]) + [self._position] | ||
self._total_reward = 0. | ||
self._total_profit = 1. # unit | ||
self._first_rendering = True | ||
self.history = {} | ||
return self._get_observation() | ||
|
||
|
||
def step(self, action): | ||
self._done = False | ||
self._current_tick += 1 | ||
|
||
if self._current_tick == self._end_tick: | ||
self._done = True | ||
|
||
step_reward = self._calculate_reward(action) | ||
self._total_reward += step_reward | ||
|
||
self._update_profit(action) | ||
|
||
trade = False | ||
if ((action == Actions.Buy.value and self._position == Positions.Short) or | ||
(action == Actions.Sell.value and self._position == Positions.Long)): | ||
trade = True | ||
|
||
if trade: | ||
self._position = self._position.opposite() | ||
self._last_trade_tick = self._current_tick | ||
|
||
self._position_history.append(self._position) | ||
observation = self._get_observation() | ||
info = dict( | ||
total_reward = self._total_reward, | ||
total_profit = self._total_profit, | ||
position = self._position.value | ||
) | ||
self._update_history(info) | ||
|
||
return observation, step_reward, self._done, info | ||
|
||
|
||
def _get_observation(self): | ||
return self.signal_features[(self._current_tick-self.window_size):self._current_tick] | ||
|
||
|
||
def _update_history(self, info): | ||
if not self.history: | ||
self.history = {key: [] for key in info.keys()} | ||
|
||
for key, value in info.items(): | ||
self.history[key].append(value) | ||
|
||
|
||
def render(self, mode='human'): | ||
def _plot_position(position, tick): | ||
color = None | ||
if position == Positions.Short: | ||
color = 'red' | ||
elif position == Positions.Long: | ||
color = 'green' | ||
if color: | ||
plt.scatter(tick, self.prices[tick], color=color) | ||
|
||
if self._first_rendering: | ||
self._first_rendering = False | ||
plt.cla() | ||
plt.plot(self.prices) | ||
start_position = self._position_history[self._start_tick] | ||
_plot_position(start_position, self._start_tick) | ||
|
||
_plot_position(self._position, self._current_tick) | ||
|
||
plt.suptitle( | ||
"Total Reward: %.6f" % self._total_reward + ' ~ ' + | ||
"Total Profit: %.6f" % self._total_profit | ||
) | ||
|
||
plt.pause(0.01) | ||
|
||
|
||
def render_all(self, mode='human'): | ||
window_ticks = np.arange(len(self._position_history)) | ||
plt.plot(self.prices) | ||
|
||
short_ticks = [] | ||
long_ticks = [] | ||
for i, tick in enumerate(window_ticks): | ||
if self._position_history[i] == Positions.Short: | ||
short_ticks.append(tick) | ||
elif self._position_history[i] == Positions.Long: | ||
long_ticks.append(tick) | ||
|
||
plt.plot(short_ticks, self.prices[short_ticks], 'ro') | ||
plt.plot(long_ticks, self.prices[long_ticks], 'go') | ||
|
||
plt.suptitle( | ||
"Total Reward: %.6f" % self._total_reward + ' ~ ' + | ||
"Total Profit: %.6f" % self._total_profit | ||
) | ||
|
||
|
||
def close(self): | ||
plt.close() | ||
|
||
def save_rendering(self, filepath): | ||
plt.savefig(filepath) | ||
|
||
def pause_rendering(self): | ||
plt.show() | ||
|
||
def _calculate_reward(self, action): | ||
step_reward = 0 | ||
|
||
trade = False | ||
if ((action == Actions.Buy.value and self._position == Positions.Short) or | ||
(action == Actions.Sell.value and self._position == Positions.Long)): | ||
trade = True | ||
|
||
if trade: | ||
current_price = self.prices[self._current_tick] | ||
last_trade_price = self.prices[self._last_trade_tick] | ||
price_diff = current_price - last_trade_price | ||
|
||
if self._position == Positions.Long: | ||
step_reward += price_diff | ||
|
||
return step_reward | ||
|
||
def _update_profit(self, action): | ||
trade = False | ||
if ((action == Actions.Buy.value and self._position == Positions.Short) or | ||
(action == Actions.Sell.value and self._position == Positions.Long)): | ||
trade = True | ||
|
||
if trade or self._done: | ||
current_price = self.prices[self._current_tick] | ||
last_trade_price = self.prices[self._last_trade_tick] | ||
|
||
if self._position == Positions.Long: | ||
shares = (self._total_profit * (1 - self.fee)) / last_trade_price | ||
self._total_profit = (shares * (1 - self.fee)) * current_price | ||
|
||
def max_possible_profit(self): | ||
current_tick = self._start_tick | ||
last_trade_tick = current_tick - 1 | ||
profit = 1. | ||
|
||
while current_tick <= self._end_tick: | ||
position = None | ||
if self.prices[current_tick] < self.prices[current_tick - 1]: | ||
while (current_tick <= self._end_tick and | ||
self.prices[current_tick] < self.prices[current_tick - 1]): | ||
current_tick += 1 | ||
position = Positions.Short | ||
else: | ||
while (current_tick <= self._end_tick and | ||
self.prices[current_tick] >= self.prices[current_tick - 1]): | ||
current_tick += 1 | ||
position = Positions.Long | ||
|
||
if position == Positions.Long: | ||
current_price = self.prices[current_tick - 1] | ||
last_trade_price = self.prices[last_trade_tick] | ||
shares = profit / last_trade_price | ||
profit = shares * current_price | ||
last_trade_tick = current_tick - 1 | ||
|
||
return profit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .SimpleROIEnv import SimpleROIEnv | ||
from .FreqtradeEnv import FreqtradeEnv | ||
from .FreqtradeEnv import FreqtradeEnv | ||
from .GymAnytrading import GymAnytrading |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters