Skip to content

Commit e94bea6

Browse files
committed
Merge WeightedUnrealisedProfit
2 parents 40c7559 + be6d144 commit e94bea6

7 files changed

+56
-3
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ data/postgres/*
88
data/log/*
99
data/reports/*
1010
*.pkl
11+
venv/*

lib/RLTrader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from stable_baselines import PPO2
1515

1616
from lib.env.TradingEnv import TradingEnv
17-
from lib.env.reward import BaseRewardStrategy, IncrementalProfit
17+
from lib.env.reward import BaseRewardStrategy, IncrementalProfit, WeightedUnrealisedProfit
1818
from lib.data.providers.dates import ProviderDateFormat
1919
from lib.data.providers import BaseDataProvider, StaticDataProvider, ExchangeDataProvider
2020
from lib.util.logger import init_logger

lib/env/TradingEnv.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import List, Dict
88

99
from lib.env.render import TradingChart
10-
from lib.env.reward import BaseRewardStrategy, IncrementalProfit
10+
from lib.env.reward import BaseRewardStrategy, IncrementalProfit, WeightedUnrealisedProfit
1111
from lib.env.trade import BaseTradeStrategy, SimulatedTradeStrategy
1212
from lib.data.providers import BaseDataProvider
1313
from lib.data.features.transform import max_min_normalize, mean_normalize, log_and_difference, difference
@@ -109,6 +109,7 @@ def _take_action(self, action: int):
109109
elif asset_sold:
110110
self.asset_held -= asset_sold
111111
self.balance += sale_revenue
112+
self.reward_strategy.reset_reward()
112113

113114
self.trades.append({'step': self.current_step, 'amount': asset_sold,
114115
'total': sale_revenue, 'type': 'sell'})
@@ -191,6 +192,8 @@ def reset(self):
191192
self.asset_held = 0
192193
self.current_step = 0
193194

195+
self.reward_strategy.reset_reward()
196+
194197
self.account_history = pd.DataFrame([{
195198
'balance': self.balance,
196199
'asset_bought': 0,

lib/env/reward/BaseRewardStrategy.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ class BaseRewardStrategy(object, metaclass=ABCMeta):
99
def __init__(self):
1010
pass
1111

12+
@abstractmethod
13+
def reset_reward(self):
14+
raise NotImplementedError()
15+
1216
@abstractmethod
1317
def get_reward(self,
1418
current_step: int,

lib/env/reward/IncrementalProfit.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ class IncrementalProfit(BaseRewardStrategy):
1212
def __init__(self):
1313
pass
1414

15+
def reset_reward(self):
16+
pass
17+
1518
def get_reward(self,
1619
current_step: int,
1720
current_price: Callable[[str], float],
1821
observations: pd.DataFrame,
1922
account_history: pd.DataFrame,
20-
net_worths: List[float]):
23+
net_worths: List[float]) -> float:
2124
reward = 0
2225

2326
curr_balance = account_history['balance'].values[-1]
+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from collections import deque
2+
3+
import pandas as pd
4+
import numpy as np
5+
from typing import List, Callable
6+
7+
from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy
8+
9+
10+
class WeightedUnrealisedProfit(BaseRewardStrategy):
11+
def __init__(self, **kwargs):
12+
self.decay_rate = kwargs.get('decay_rate', 1e-2)
13+
self.decay_denominator = np.exp(-1 * self.decay_rate)
14+
15+
self.reset_reward()
16+
17+
def reset_reward(self):
18+
self.rewards = deque(np.zeros(1, dtype=float))
19+
self.sum = 0.0
20+
21+
def calc_reward(self, reward):
22+
self.sum = self.sum - self.decay_denominator * self.rewards.popleft()
23+
self.sum = self.sum * self.decay_denominator
24+
self.sum = self.sum + reward
25+
26+
self.rewards.append(reward)
27+
28+
return self.sum / self.decay_denominator
29+
30+
def get_reward(self,
31+
current_step: int,
32+
current_price: Callable[[str], float],
33+
observations: pd.DataFrame,
34+
account_history: pd.DataFrame,
35+
net_worths: List[float]) -> float:
36+
if account_history['asset_sold'].values[-1] > 0:
37+
reward = self.calc_reward(account_history['sale_revenue'].values[-1])
38+
else:
39+
reward = self.calc_reward(account_history['asset_held'].values[-1] * current_price)
40+
41+
return reward

lib/env/reward/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from lib.env.reward.IncrementalProfit import IncrementalProfit
2+
from lib.env.reward.WeightedUnrealisedProfit import WeightedUnrealisedProfit
23
from lib.env.reward.BaseRewardStrategy import BaseRewardStrategy

0 commit comments

Comments
 (0)