Skip to content

Commit

Permalink
Tensortrade (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexBabescu authored Apr 18, 2022
1 parent 6820362 commit 59ec49f
Show file tree
Hide file tree
Showing 6 changed files with 340 additions and 130 deletions.
3 changes: 2 additions & 1 deletion environment_cpuonly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ dependencies:
- ta
- freqtrade
- stable-baselines3
- sb3-contrib
- sb3-contrib
- tensortrade
3 changes: 2 additions & 1 deletion environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,5 @@ dependencies:
- ta
- freqtrade
- stable-baselines3
- sb3-contrib
- sb3-contrib
- tensortrade
14 changes: 14 additions & 0 deletions tb_callbacks/SaveOnStepCallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(self, check_freq: int, save_name: str, save_dir: str, log_dir: str,
self.log_dir = Path(log_dir)
self.save_path = Path(save_dir) / save_name
self.best_mean_reward = -np.inf
self.best_net_worth = -np.inf

assert self.log_dir.exists()
assert Path(save_dir).exists()
Expand All @@ -27,6 +28,9 @@ def _on_step(self) -> bool:
# total_reward = self.training_env.buf_infos[0]['total_reward']
# self.logger.record('total_reward', total_reward)

net_worth = self.training_env.buf_infos[0]['net_worth']
self.logger.record('net_worth', net_worth)

if self.n_calls % self.check_freq == 0:
# Retrieve training reward
x, y = ts2xy(load_results(str(self.log_dir)), 'timesteps')
Expand All @@ -45,3 +49,13 @@ def _on_step(self) -> bool:
self.model.save(self.save_path) # type: ignore

return True

def on_rollout_end(self):
    """Save the model whenever the environment reaches a new net-worth high.

    Called by stable-baselines3 at the end of each rollout collection
    phase.  Reads ``net_worth`` from the first vectorized environment's
    info dict and, on a new record, persists the model under
    ``<save_path>_net_worth``.
    """
    net_worth = self.training_env.buf_infos[0]["net_worth"]
    if net_worth > self.best_net_worth:
        # Capture the old record BEFORE overwriting it, otherwise the
        # verbose log below would print the same value twice.
        previous_best = self.best_net_worth
        self.best_net_worth = net_worth
        save_name = f"{self.save_path}_net_worth"
        if self.verbose:
            print("New best net worth: {:.2f} (previous best: {:.2f})".format(net_worth, previous_best))
            print(f"Saving new best net worth model to {save_name}")
        self.model.save(save_name)  # type: ignore
199 changes: 142 additions & 57 deletions train_freqtrade.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,119 @@
import datetime
from os import access
from pathlib import Path

import mpu
import tensortrade.env.default as default
import torch as th
from freqtrade.configuration import Configuration, TimeRange
from freqtrade.data import history
from freqtrade.data.dataprovider import DataProvider
from freqtrade.exchange import Exchange as FreqtradeExchange
from freqtrade.resolvers import StrategyResolver
from gym.spaces import Discrete, Space
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.ppo.ppo import PPO
from stable_baselines3.a2c.a2c import A2C
from tensortrade.env.default.actions import BSH, TensorTradeActionScheme
from tensortrade.env.default.rewards import PBR, RiskAdjustedReturns, SimpleProfit, TensorTradeRewardScheme
from tensortrade.env.generic import ActionScheme, TradingEnv
from tensortrade.feed.core import DataFeed, NameSpace, Stream
from tensortrade.oms.exchanges import Exchange, ExchangeOptions
from tensortrade.oms.instruments import BTC, ETH, LTC, USD, Instrument
from tensortrade.oms.orders import proportion_order
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.wallets import Portfolio, Wallet

from tb_callbacks import SaveOnStepCallback
from trading_environments import FreqtradeEnv, SimpleROIEnv, GymAnytrading
from trading_environments import FreqtradeEnv, GymAnytrading, SimpleROIEnv

"""Settings"""
PAIR = "BTC/USDT"
TRAINING_RANGE = "20210601-20210901"
PAIR = "ADA/USDT"
TRAINING_RANGE = "20210901-20211231"
WINDOW_SIZE = 10
LOAD_PREPROCESSED_DATA = False # useful if you have to calculate a lot of features
SAVE_PREPROCESSED_DATA = True
LEARNING_TIME_STEPS = int(1e+6)
LEARNING_TIME_STEPS = int(1e9)
LOG_DIR = "./logs/"
TENSORBOARD_LOG = "./tensorboard/"
MODEL_DIR = "./models/"
USER_DATA = Path(__file__).parent / "user_data"
"""End of settings"""

freqtrade_config = Configuration.from_files(['user_data/config.json'])
freqtrade_config = Configuration.from_files([str(USER_DATA / "config.json")])
_preprocessed_data_file = "preprocessed_data.pickle"
from gym.spaces import Discrete, Space


class BuySellHold(TensorTradeActionScheme):
    """A simple discrete action scheme with three actions: buy, sell, or hold.

    Action mapping:
        0 -- move the entire cash balance into the asset (buy)
        1 -- move the entire asset balance back into cash (sell)
        2 -- do nothing (hold)

    Parameters
    ----------
    cash : `Wallet`
        The wallet to hold funds in the base instrument.
    asset : `Wallet`
        The wallet to hold funds in the quote instrument.
    """

    registered_name = "bsh"

    def __init__(self, cash: "Wallet", asset: "Wallet"):
        super().__init__()
        self.cash = cash
        self.asset = asset

        # Observers (e.g. a PBR reward scheme) that are told which action
        # was taken on every step via ``on_action``.
        self.listeners = []

    @property
    def action_space(self):
        # Three discrete choices: 0 = buy, 1 = sell, 2 = hold.
        return Discrete(3)

    def attach(self, listener):
        """Register *listener* for per-step action notifications; returns self."""
        self.listeners += [listener]
        return self

    def get_orders(self, action: int, portfolio: "Portfolio"):
        """Translate a discrete *action* into a list of orders.

        Fix: listeners are now notified for *every* action, including hold
        and no-op buys/sells.  The original only notified them when an
        order was actually created, which would desynchronize listeners
        (such as tensortrade's PBR reward scheme) that track position
        state from these notifications — upstream BSH notifies on every
        action.
        """
        for listener in self.listeners:
            listener.on_action(action)

        if action == 0:  # Buy with the full cash balance
            if self.cash.balance == 0:
                return []
            return [proportion_order(portfolio, self.cash, self.asset, 1.0)]

        if action == 1:  # Sell the full asset balance
            if self.asset.balance == 0:
                return []
            return [proportion_order(portfolio, self.asset, self.cash, 1.0)]

        # action == 2: hold — nothing to do.
        return []

    def reset(self):
        """Reset parent state; this scheme itself keeps no episode state."""
        super().reset()


def main():

strategy = StrategyResolver.load_strategy(freqtrade_config)
strategy.dp = DataProvider(freqtrade_config, FreqtradeExchange(freqtrade_config), None)
required_startup = strategy.startup_candle_count
timeframe = freqtrade_config.get('timeframe')
timeframe = freqtrade_config.get("timeframe")
data = dict()

if LOAD_PREPROCESSED_DATA:
assert Path(_preprocessed_data_file).exists(), "Unable to load preprocessed data!"
data = mpu.io.read(_preprocessed_data_file)
assert PAIR in data, f"Loaded preprocessed data does not contain pair {PAIR}!"
else:
data = _load_data(freqtrade_config, PAIR, timeframe, TRAINING_RANGE, WINDOW_SIZE)
data = strategy.advise_all_indicators(data)
data = _load_data(freqtrade_config, timeframe, TRAINING_RANGE)
data = strategy.advise_all_indicators({PAIR: data[PAIR]})
if SAVE_PREPROCESSED_DATA:
mpu.io.write(_preprocessed_data_file, data)

Expand All @@ -51,89 +122,103 @@ def main():

del data

price_data = pair_data[['date', 'open', 'close', 'high', 'low', 'volume']].copy()
price_data = pair_data[["date", "open", "close", "high", "low", "volume"]].copy()

pair_data.drop(columns=['date', 'open', 'close', 'high', 'low', 'volume'], inplace=True)
pair_data.drop(columns=["date", "open", "close", "high", "low", "volume"], inplace=True)
pair_data.fillna(0, inplace=True)

ADA = Instrument("ADA", 3, "Cardano")

price = Stream.source(list(price_data["close"]), dtype="float").rename("USD-ADA")

exchange_options = ExchangeOptions(commission=0.0035)
binance = Exchange("binance", service=execute_order, options=exchange_options)(price)

cash = Wallet(binance, 1000 * USD)
asset = Wallet(binance, 0 * ADA)

portfolio = Portfolio(USD, [cash, asset])

features = [Stream.source(list(pair_data[c]), dtype="float").rename(c) for c in pair_data.columns]

feed = DataFeed(features)
feed.compile()

renderer_feed = DataFeed(
[
Stream.source(list(price_data["date"])).rename("date"),
Stream.source(list(price_data["open"]), dtype="float").rename("open"),
Stream.source(list(price_data["high"]), dtype="float").rename("high"),
Stream.source(list(price_data["low"]), dtype="float").rename("low"),
Stream.source(list(price_data["close"]), dtype="float").rename("close"),
Stream.source(list(price_data["volume"]), dtype="float").rename("volume"),
]
)

action_scheme = BuySellHold(cash=cash, asset=asset)

# reward_scheme = PBR(price)
reward_scheme = SimpleProfit(window_size=8)

trading_env = default.create(
portfolio=portfolio,
action_scheme=action_scheme,
reward_scheme=reward_scheme,
feed=feed,
renderer_feed=renderer_feed,
window_size=WINDOW_SIZE,
max_allowed_loss=0.50,
)

# trading_env = FreqtradeEnv(
# data=pair_data,
# prices=price_data,
# window_size=WINDOW_SIZE, # how many past candles should it use as features
# pair=PAIR,
# stake_amount=freqtrade_config['stake_amount'],
# punish_holding_amount=0,
# )

# trading_env = SimpleROIEnv(
# data=pair_data,
# prices=price_data,
# window_size=WINDOW_SIZE, # how many past candles should it use as features
# required_startup=required_startup,
# minimum_roi=0.02, # 2% target ROI
# roi_candles=24, # 24 candles * 5m = 120 minutes
# punish_holding_amount=0,
# punish_missed_buy=True
# )

trading_env = GymAnytrading(
signal_features=pair_data,
prices=price_data.close,
window_size=WINDOW_SIZE, # how many past candles should it use as features
)

trading_env = Monitor(trading_env, LOG_DIR)
trading_env = Monitor(trading_env, LOG_DIR, info_keywords=("net_worth",))

# Optional policy_kwargs
# see https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html?highlight=policy_kwargs#custom-network-architecture
# policy_kwargs = dict(activation_fn=th.nn.ReLU,
# net_arch=[dict(pi=[32, 32], vf=[32, 32])])
# policy_kwargs = dict(activation_fn=th.nn.Tanh, net_arch=[32, dict(pi=[64, 64], vf=[64, 64])])
policy_kwargs = dict(net_arch=[32, dict(pi=[64, 64], vf=[64, 64])])
policy_kwargs = dict(net_arch=[128, dict(pi=[128, 128], vf=[128, 128])])

start_date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

model = PPO( # See https://stable-baselines3.readthedocs.io/en/master/guide/algos.html for other algos with discrete action space
"MlpPolicy",
"MlpPolicy", # MlpPolicy MultiInputPolicy
trading_env,
verbose=0,
device='auto',
device="cuda",
tensorboard_log=TENSORBOARD_LOG,
# policy_kwargs=policy_kwargs
# n_steps = len(pair_data),
# batch_size = 1000,
# n_epochs = 20,
policy_kwargs=policy_kwargs,
)

base_name = f"{strategy.get_strategy_name()}_{trading_env.env.__class__.__name__}_{model.__class__.__name__}_{start_date}"
base_name = f"{strategy.get_strategy_name()}_TensorTrade_{model.__class__.__name__}_{start_date}"

tb_callback = SaveOnStepCallback(
check_freq=5000,
save_name=f"best_model_{base_name}",
save_dir=MODEL_DIR,
log_dir=LOG_DIR,
verbose=1)
check_freq=10000, save_name=f"best_model_{base_name}", save_dir=MODEL_DIR, log_dir=LOG_DIR, verbose=1
)

print(f"You can run tensorboard with: 'tensorboard --logdir {Path(TENSORBOARD_LOG).absolute()}'")
print("Learning started.")

model.learn(
total_timesteps=LEARNING_TIME_STEPS,
callback=tb_callback
)
model.learn(total_timesteps=LEARNING_TIME_STEPS, callback=tb_callback)
model.save(f"{MODEL_DIR}final_model_{base_name}")


def _load_data(config, pair, timeframe, timerange, window_size):
def _load_data(config, timeframe, timerange):
    """Load candle history for all configured pairs via freqtrade.

    Parameters
    ----------
    config : dict
        Freqtrade configuration; ``datadir``, ``pairs`` and
        ``startup_candle_count`` keys are read, plus the optional
        ``dataformat_ohlcv`` key (defaults to ``"json"``).
    timeframe : str
        Candle timeframe, e.g. ``"15m"``.
    timerange : str
        Time range string understood by ``TimeRange.parse_timerange``.
    """
    return history.load_data(
        datadir=config["datadir"],
        pairs=config["pairs"],
        timeframe=timeframe,
        timerange=TimeRange.parse_timerange(timerange),
        startup_candles=config["startup_candle_count"],
        fail_without_data=True,
        data_format=config.get("dataformat_ohlcv", "json"),
    )


if __name__ == "__main__":
main()
11 changes: 8 additions & 3 deletions user_data/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"stake_amount": 200,
"tradable_balance_ratio": 0.99,
"fiat_display_currency": "USD",
"timeframe": "5m",
"timeframe": "15m",
"dry_run": true,
"dry_run_wallet": 1000,
"cancel_open_orders_on_exit": false,
Expand Down Expand Up @@ -40,7 +40,12 @@
"rateLimit": 200
},
"pair_whitelist": [
"BTC/USDT"
"ETH/USDT",
"ADA/USDT",
"EOS/USDT",
"XLM/USDT",
"LTC/USDT",
"TRX/USDT"
],
"pair_blacklist": [
"BNB/.*"
Expand Down Expand Up @@ -79,7 +84,7 @@
"jwt_secret_key": "595e9cb58c29748270a739db7aa127fc30a9ecfbf0d11fc9d977de4d8a86037f",
"CORS_origins": [],
"username": "freqtrader",
"password": ""
"password": "freqtrader"
},
"bot_name": "freqtrade",
"initial_state": "running",
Expand Down
Loading

0 comments on commit 59ec49f

Please sign in to comment.