Skip to content

Commit

Permalink
Added GymAnytrading
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexBabescu committed Oct 30, 2021
1 parent fe1c5f9 commit 1274942
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 12 deletions.
237 changes: 237 additions & 0 deletions trading_environments/GymAnytrading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from enum import Enum

import gym
import matplotlib.pyplot as plt
import numpy as np
from gym import spaces
from gym.utils import seeding


class Actions(Enum):
Hold = 0
Buy = 1
Sell = 2


class Positions(Enum):
Short = 0
Long = 1

def opposite(self):
return Positions.Short if self == Positions.Long else Positions.Long


class GymAnytrading(gym.Env):
"""
Based on https://github.com/AminHP/gym-anytrading
"""

metadata = {'render.modes': ['human']}

def __init__(self, signal_features, prices, window_size, fee=0.0):
assert signal_features.ndim == 2

self.seed()
self.signal_features = signal_features
self.prices = prices
self.window_size = window_size
self.fee = fee
self.shape = (window_size, self.signal_features.shape[1])

# spaces
self.action_space = spaces.Discrete(len(Actions))
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

# episode
self._start_tick = self.window_size
self._end_tick = len(self.prices) - 1
self._done = None
self._current_tick = None
self._last_trade_tick = None
self._position = None
self._position_history = None
self._total_reward = None
self._total_profit = None
self._first_rendering = None
self.history = None


def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]


def reset(self):
self._done = False
self._current_tick = self._start_tick
self._last_trade_tick = self._current_tick - 1
self._position = Positions.Short
self._position_history = (self.window_size * [None]) + [self._position]
self._total_reward = 0.
self._total_profit = 1. # unit
self._first_rendering = True
self.history = {}
return self._get_observation()


def step(self, action):
self._done = False
self._current_tick += 1

if self._current_tick == self._end_tick:
self._done = True

step_reward = self._calculate_reward(action)
self._total_reward += step_reward

self._update_profit(action)

trade = False
if ((action == Actions.Buy.value and self._position == Positions.Short) or
(action == Actions.Sell.value and self._position == Positions.Long)):
trade = True

if trade:
self._position = self._position.opposite()
self._last_trade_tick = self._current_tick

self._position_history.append(self._position)
observation = self._get_observation()
info = dict(
total_reward = self._total_reward,
total_profit = self._total_profit,
position = self._position.value
)
self._update_history(info)

return observation, step_reward, self._done, info


def _get_observation(self):
return self.signal_features[(self._current_tick-self.window_size):self._current_tick]


def _update_history(self, info):
if not self.history:
self.history = {key: [] for key in info.keys()}

for key, value in info.items():
self.history[key].append(value)


def render(self, mode='human'):
def _plot_position(position, tick):
color = None
if position == Positions.Short:
color = 'red'
elif position == Positions.Long:
color = 'green'
if color:
plt.scatter(tick, self.prices[tick], color=color)

if self._first_rendering:
self._first_rendering = False
plt.cla()
plt.plot(self.prices)
start_position = self._position_history[self._start_tick]
_plot_position(start_position, self._start_tick)

_plot_position(self._position, self._current_tick)

plt.suptitle(
"Total Reward: %.6f" % self._total_reward + ' ~ ' +
"Total Profit: %.6f" % self._total_profit
)

plt.pause(0.01)


def render_all(self, mode='human'):
window_ticks = np.arange(len(self._position_history))
plt.plot(self.prices)

short_ticks = []
long_ticks = []
for i, tick in enumerate(window_ticks):
if self._position_history[i] == Positions.Short:
short_ticks.append(tick)
elif self._position_history[i] == Positions.Long:
long_ticks.append(tick)

plt.plot(short_ticks, self.prices[short_ticks], 'ro')
plt.plot(long_ticks, self.prices[long_ticks], 'go')

plt.suptitle(
"Total Reward: %.6f" % self._total_reward + ' ~ ' +
"Total Profit: %.6f" % self._total_profit
)


def close(self):
plt.close()

def save_rendering(self, filepath):
plt.savefig(filepath)

def pause_rendering(self):
plt.show()

def _calculate_reward(self, action):
step_reward = 0

trade = False
if ((action == Actions.Buy.value and self._position == Positions.Short) or
(action == Actions.Sell.value and self._position == Positions.Long)):
trade = True

if trade:
current_price = self.prices[self._current_tick]
last_trade_price = self.prices[self._last_trade_tick]
price_diff = current_price - last_trade_price

if self._position == Positions.Long:
step_reward += price_diff

return step_reward

def _update_profit(self, action):
trade = False
if ((action == Actions.Buy.value and self._position == Positions.Short) or
(action == Actions.Sell.value and self._position == Positions.Long)):
trade = True

if trade or self._done:
current_price = self.prices[self._current_tick]
last_trade_price = self.prices[self._last_trade_tick]

if self._position == Positions.Long:
shares = (self._total_profit * (1 - self.fee)) / last_trade_price
self._total_profit = (shares * (1 - self.fee)) * current_price

def max_possible_profit(self):
current_tick = self._start_tick
last_trade_tick = current_tick - 1
profit = 1.

while current_tick <= self._end_tick:
position = None
if self.prices[current_tick] < self.prices[current_tick - 1]:
while (current_tick <= self._end_tick and
self.prices[current_tick] < self.prices[current_tick - 1]):
current_tick += 1
position = Positions.Short
else:
while (current_tick <= self._end_tick and
self.prices[current_tick] >= self.prices[current_tick - 1]):
current_tick += 1
position = Positions.Long

if position == Positions.Long:
current_price = self.prices[current_tick - 1]
last_trade_price = self.prices[last_trade_tick]
shares = profit / last_trade_price
profit = shares * current_price
last_trade_tick = current_tick - 1

return profit
3 changes: 2 additions & 1 deletion trading_environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .SimpleROIEnv import SimpleROIEnv
from .FreqtradeEnv import FreqtradeEnv
from .FreqtradeEnv import FreqtradeEnv
from .GymAnytrading import GymAnytrading
28 changes: 17 additions & 11 deletions train_freqtrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from stable_baselines3.a2c.a2c import A2C

from tb_callbacks import SaveOnStepCallback
from trading_environments import FreqtradeEnv, SimpleROIEnv
from trading_environments import FreqtradeEnv, SimpleROIEnv, GymAnytrading

"""Settings"""
PAIR = "BTC/USDT"
Expand Down Expand Up @@ -57,26 +57,32 @@ def main():
pair_data.fillna(0, inplace=True)


trading_env = FreqtradeEnv(
data=pair_data,
prices=price_data,
window_size=WINDOW_SIZE, # how many past candles should it use as features
pair=PAIR,
stake_amount=freqtrade_config['stake_amount'],
punish_holding_amount=0,
)
# trading_env = FreqtradeEnv(
# data=pair_data,
# prices=price_data,
# window_size=WINDOW_SIZE, # how many past candles should it use as features
# pair=PAIR,
# stake_amount=freqtrade_config['stake_amount'],
# punish_holding_amount=0,
# )

# trading_env = SimpleROIEnv(
# data=pair_data,
# prices=price_data,
# window_size=WINDOW_SIZE, # how many past candles should it use as features
# required_startup=required_startup,
# minimum_roi=0.03, # 2% target ROI
# roi_candles=48,
# minimum_roi=0.02, # 2% target ROI
# roi_candles=24, # 24 candles * 5m = 120 minutes
# punish_holding_amount=0,
# punish_missed_buy=True
# )

trading_env = GymAnytrading(
signal_features=pair_data,
prices=price_data.close,
window_size=WINDOW_SIZE, # how many past candles should it use as features
)

trading_env = Monitor(trading_env, LOG_DIR)

# Optional policy_kwargs
Expand Down

0 comments on commit 1274942

Please sign in to comment.