Skip to content

Commit 40c7559

Browse files
committed
Merge branch 'rshtirmer-trade-stats'
2 parents a351a7d + 80fa433 commit 40c7559

9 files changed

+106
-47
lines changed

.dockerignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ tensorboard
66
agents
77
data/tensorboard
88
data/agents
9-
data/postgres
9+
data/postgres
10+
data/reports

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ data/tensorboard/*
66
data/agents/*
77
data/postgres/*
88
data/log/*
9+
data/reports/*
910
*.pkl

cli.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def run_optimize(args, logger):
1515
from lib.RLTrader import RLTrader
1616

1717
trader = RLTrader(**vars(args), logger=logger)
18-
trader.optimize(args.trials)
18+
trader.optimize(n_trials=args.trials, n_prune_evals_per_trial=args.prune_evals, n_tests_per_eval=args.eval_tests)
1919

2020

2121
if __name__ == '__main__':
@@ -39,8 +39,16 @@ def run_optimize(args, logger):
3939
trader = RLTrader(**vars(args), logger=logger)
4040

4141
if args.command == 'train':
42-
trader.train(n_epochs=args.epochs)
42+
trader.train(n_epochs=args.epochs,
43+
save_every=args.save_every,
44+
test_trained_model=args.test_trained,
45+
render_test_env=args.render_test,
46+
render_report=args.render_report,
47+
save_report=args.save_report)
4348
elif args.command == 'test':
44-
trader.test(model_epoch=args.model_epoch, should_render=args.no_render)
49+
trader.test(model_epoch=args.model_epoch,
50+
render_env=args.render_env,
51+
render_report=args.render_report,
52+
save_report=args.save_report)
4553
elif args.command == 'update-static-data':
4654
download_data_async()

data/.DS_Store

6 KB
Binary file not shown.

lib/RLTrader.py

+47-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
import optuna
33
import numpy as np
4+
import pandas as pd
5+
import quantstats as qs
46

57
from os import path
68
from typing import Dict
@@ -136,8 +138,12 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e
136138
validation_env = SubprocVecEnv([make_env(validation_provider, i) for i in range(1)])
137139

138140
model_params = self.optimize_agent_params(trial)
139-
model = self.Model(self.Policy, train_env, verbose=self.model_verbose, nminibatches=1,
140-
tensorboard_log=self.tensorboard_path, **model_params)
141+
model = self.Model(self.Policy,
142+
train_env,
143+
verbose=self.model_verbose,
144+
nminibatches=1,
145+
tensorboard_log=self.tensorboard_path,
146+
**model_params)
141147

142148
last_reward = -np.finfo(np.float16).max
143149
n_steps_per_eval = int(len(train_provider.data_frame) / n_prune_evals_per_trial)
@@ -154,7 +160,7 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e
154160
trades = train_env.get_attr('trades')
155161

156162
if len(trades[0]) < 1:
157-
self.logger.info('Pruning trial for not making any trades: ', eval_idx)
163+
self.logger.info(f'Pruning trial for not making any trades: {eval_idx}')
158164
raise optuna.structs.TrialPruned()
159165

160166
state = None
@@ -179,9 +185,9 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e
179185

180186
return -1 * last_reward
181187

182-
def optimize(self, n_trials: int = 20, *optimize_params):
188+
def optimize(self, n_trials: int = 20, **optimize_params):
183189
try:
184-
self.optuna_study.optimize(self.optimize_params, n_trials=n_trials, n_jobs=1, *optimize_params)
190+
self.optuna_study.optimize(self.optimize_params, n_trials=n_trials, n_jobs=1, **optimize_params)
185191
except KeyboardInterrupt:
186192
pass
187193

@@ -195,7 +201,13 @@ def optimize(self, n_trials: int = 20, *optimize_params):
195201

196202
return self.optuna_study.trials_dataframe()
197203

198-
def train(self, n_epochs: int = 10, save_every: int = 1, test_trained_model: bool = False, render_trained_model: bool = False):
204+
def train(self,
205+
n_epochs: int = 10,
206+
save_every: int = 1,
207+
test_trained_model: bool = True,
208+
render_test_env: bool = False,
209+
render_report: bool = True,
210+
save_report: bool = False):
199211
train_provider, test_provider = self.data_provider.split_data_train_test(self.train_split_percentage)
200212

201213
del test_provider
@@ -204,8 +216,12 @@ def train(self, n_epochs: int = 10, save_every: int = 1, test_trained_model: boo
204216

205217
model_params = self.get_model_params()
206218

207-
model = self.Model(self.Policy, train_env, verbose=self.model_verbose, nminibatches=self.n_minibatches,
208-
tensorboard_log=self.tensorboard_path, **model_params)
219+
model = self.Model(self.Policy,
220+
train_env,
221+
verbose=self.model_verbose,
222+
nminibatches=self.n_minibatches,
223+
tensorboard_log=self.tensorboard_path,
224+
**model_params)
209225

210226
self.logger.info(f'Training for {n_epochs} epochs')
211227

@@ -221,11 +237,14 @@ def train(self, n_epochs: int = 10, save_every: int = 1, test_trained_model: boo
221237
model.save(model_path)
222238

223239
if test_trained_model:
224-
self.test(model_epoch, should_render=render_trained_model)
240+
self.test(model_epoch,
241+
render_env=render_test_env,
242+
render_report=render_report,
243+
save_report=save_report)
225244

226245
self.logger.info(f'Trained {n_epochs} models')
227246

228-
def test(self, model_epoch: int = 0, should_render: bool = True):
247+
def test(self, model_epoch: int = 0, render_env: bool = True, render_report: bool = True, save_report: bool = False):
229248
train_provider, test_provider = self.data_provider.split_data_train_test(self.train_split_percentage)
230249

231250
del train_provider
@@ -247,14 +266,30 @@ def test(self, model_epoch: int = 0, should_render: bool = True):
247266

248267
for _ in range(len(test_provider.data_frame)):
249268
action, state = model.predict(zero_completed_obs, state=state)
250-
obs, reward, _, __ = test_env.step([action[0]])
269+
obs, reward, done, info = test_env.step([action[0]])
251270

252271
zero_completed_obs[0, :] = obs
253272

254273
rewards.append(reward)
255274

256-
if should_render:
275+
if render_env:
257276
test_env.render(mode='human')
258277

278+
if done:
279+
net_worths = pd.DataFrame({
280+
'Date': info[0]['timestamps'],
281+
'Balance': info[0]['networths'],
282+
})
283+
284+
net_worths.set_index('Date', drop=True, inplace=True)
285+
returns = net_worths.pct_change()[1:]
286+
287+
if render_report:
288+
qs.plots.snapshot(returns.Balance, title='RL Trader Performance')
289+
290+
if save_report:
291+
reports_path = path.join('data', 'reports', f'{self.study_name}__{model_epoch}.html')
292+
qs.reports.html(returns.Balance, file=reports_path)
293+
259294
self.logger.info(
260295
f'Finished testing model ({self.study_name}__{model_epoch}): ${"{:.2f}".format(np.sum(rewards))}')

lib/cli/RLTraderCLI.py

+37-24
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class RLTraderCLI:
88
def __init__(self):
99
config_parser = argparse.ArgumentParser(add_help=False)
1010
config_parser.add_argument("-f", "--from-config", help="Specify config file", metavar="FILE")
11+
1112
args, _ = config_parser.parse_known_args()
1213
defaults = {}
1314

@@ -17,44 +18,56 @@ def __init__(self):
1718
defaults = dict(config.items("Defaults"))
1819

1920
formatter = argparse.ArgumentDefaultsHelpFormatter
20-
self.parser = argparse.ArgumentParser(
21-
formatter_class=formatter,
22-
parents=[config_parser],
23-
description=__doc__
24-
)
21+
self.parser = argparse.ArgumentParser(formatter_class=formatter,
22+
parents=[config_parser],
23+
description=__doc__)
2524

26-
self.parser.add_argument("--data-provider", "-o", type=str, default="static")
27-
self.parser.add_argument("--input-data-path", "-t", type=str, default="data/input/coinbase-1h-btc-usd.csv")
25+
self.parser.add_argument("--data-provider", "-d", type=str, default="static")
26+
self.parser.add_argument("--input-data-path", "-n", type=str, default="data/input/coinbase-1h-btc-usd.csv")
2827
self.parser.add_argument("--pair", "-p", type=str, default="BTC/USD")
29-
self.parser.add_argument("--debug", "-n", action='store_false')
28+
self.parser.add_argument("--debug", "-D", action='store_false')
3029
self.parser.add_argument('--mini-batches', type=int, default=1, help='Mini batches', dest='n_minibatches')
3130
self.parser.add_argument('--train-split-percentage', type=float, default=0.8, help='Train set percentage')
32-
self.parser.add_argument('--verbose-model', type=int, default=1, help='Verbose model')
33-
self.parser.add_argument('--params-db-path', type=str, default='sqlite:///data/params.db',
34-
help='Params path')
35-
self.parser.add_argument(
36-
'--tensor-board-path',
37-
type=str,
38-
default=os.path.join('data', 'tensorboard'),
39-
help='Tensorboard path',
40-
dest='tensorboard_path'
41-
)
42-
self.parser.add_argument('--parallel-jobs', type=int, default=multiprocessing.cpu_count(),
31+
self.parser.add_argument('--verbose-model', type=int, default=1, help='Verbose model', dest='model_verbose')
32+
self.parser.add_argument('--params-db-path', type=str, default='sqlite:///data/params.db', help='Params path')
33+
self.parser.add_argument('--tensorboard-path',
34+
type=str,
35+
default=os.path.join('data', 'tensorboard'),
36+
help='Tensorboard path')
37+
self.parser.add_argument('--parallel-jobs',
38+
type=int,
39+
default=multiprocessing.cpu_count(),
4340
help='How many processes in parallel')
4441

4542
subparsers = self.parser.add_subparsers(help='Command', dest="command")
4643

4744
optimize_parser = subparsers.add_parser('optimize', description='Optimize model parameters')
4845
optimize_parser.add_argument('--trials', type=int, default=1, help='Number of trials')
49-
50-
optimize_parser.add_argument('--verbose-model', type=int, default=1, help='Verbose model', dest='model_verbose')
46+
optimize_parser.add_argument('--prune-evals',
47+
type=int,
48+
default=2,
49+
help='Number of pruning evaluations per trial')
50+
optimize_parser.add_argument('--eval-tests', type=int, default=1, help='Number of tests per pruning evaluation')
5151

5252
train_parser = subparsers.add_parser('train', description='Train model')
53-
train_parser.add_argument('--epochs', type=int, default=1, help='Number of epochs to train')
53+
train_parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
54+
train_parser.add_argument('--save-every', type=int, default=1, help='Save the trained model every n epochs')
55+
train_parser.add_argument('--no-test', dest="test_trained", action="store_false", help='Test each saved model')
56+
train_parser.add_argument('--render-test', dest="render_test",
57+
action="store_true", help='Render the test environment')
58+
train_parser.add_argument('--no-report', dest="render_report", action="store_false",
59+
help='Render the performance report')
60+
train_parser.add_argument('--save-report', dest="save_report", action="store_true",
61+
help='Save the performance report as .html')
5462

5563
test_parser = subparsers.add_parser('test', description='Test model')
56-
test_parser.add_argument('--model-epoch', type=int, default=1, help='Model epoch index')
57-
test_parser.add_argument('--no-render', action='store_false', help='Do not render test')
64+
test_parser.add_argument('--model-epoch', type=int, default=0, help='Model epoch index')
65+
test_parser.add_argument('--no-render', dest="render_env", action="store_false",
66+
help='Render the test environment')
67+
test_parser.add_argument('--no-report', dest="render_report", action="store_false",
68+
help='Render the performance report')
69+
test_parser.add_argument('--save-report', dest="save_report", action="store_true",
70+
help='Save the performance report as .html')
5871

5972
subparsers.add_parser('update-static-data', description='Update static data')
6073

lib/data/features/indicators.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535
('KCLI', ta.keltner_channel_lband_indicator, ['High', 'Low', 'Close']),
3636
('DCHI', ta.donchian_channel_hband_indicator, ['Close']),
3737
('DCLI', ta.donchian_channel_lband_indicator, ['Close']),
38-
('ADI', ta.acc_dist_index, ['High', 'Low', 'Close', 'volume']),
39-
('OBV', ta.on_balance_volume, ['close', 'volume']),
38+
('ADI', ta.acc_dist_index, ['High', 'Low', 'Close', 'Volume BTC']),
39+
('OBV', ta.on_balance_volume, ['Close', 'Volume BTC']),
4040
('CMF', ta.chaikin_money_flow, ['High', 'Low', 'Close', 'Volume BTC']),
41-
('FI', ta.force_index, ['Close', 'Volume']),
41+
('FI', ta.force_index, ['Close', 'Volume BTC']),
4242
('EM', ta.ease_of_movement, ['High', 'Low', 'Close', 'Volume BTC']),
4343
('VPT', ta.volume_price_trend, ['Close', 'Volume BTC']),
4444
('NVI', ta.negative_volume_index, ['Close', 'Volume BTC']),

lib/env/TradingEnv.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def _take_action(self, action: int):
115115

116116
current_net_worth = round(self.balance + self.asset_held * self._current_price(), self.base_precision)
117117
self.net_worths.append(current_net_worth)
118-
119118
self.account_history = self.account_history.append({
120119
'balance': self.balance,
121120
'asset_bought': asset_bought,
@@ -155,6 +154,7 @@ def _reward(self):
155154

156155
def _next_observation(self):
157156
self.current_ohlcv = self.data_provider.next_ohlcv()
157+
self.timestamps.append(pd.to_datetime(self.current_ohlcv.Date.item(), unit='s'))
158158
self.observations = self.observations.append(self.current_ohlcv, ignore_index=True)
159159

160160
if self.stationarize_obs:
@@ -187,6 +187,7 @@ def reset(self):
187187

188188
self.balance = self.initial_balance
189189
self.net_worths = [self.initial_balance]
190+
self.timestamps = []
190191
self.asset_held = 0
191192
self.current_step = 0
192193

@@ -210,8 +211,7 @@ def step(self, action):
210211
obs = self._next_observation()
211212
reward = self._reward()
212213
done = self._done()
213-
214-
return obs, reward, done, {}
214+
return obs, reward, done, {'networths': self.net_worths, 'timestamps': self.timestamps}
215215

216216
def render(self, mode='human'):
217217
if mode == 'system':

requirements.base.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ statsmodels==0.10.0rc2
1010
empyrical
1111
ccxt
1212
psycopg2
13-
configparser
13+
configparser
14+
quantstats

0 commit comments

Comments
 (0)