diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..26d3352 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/TradingGym.iml b/.idea/TradingGym.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/TradingGym.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..236312a --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/koshish.py b/koshish.py new file mode 100644 index 0000000..adfcc9e --- /dev/null +++ b/koshish.py @@ -0,0 +1,33 @@ +import random +import numpy as np +import pandas as pd +import trading_env + +df = pd.read_hdf('dataset/SGXTWsample.h5', 'STW') + +env = trading_env.make(env_id='training_v1', obs_data_len=256, step_len=128, + df=df, fee=0.1, max_position=5, deal_col_name='Price', + feature_names=['Price', 'Volume', + 'Ask_price','Bid_price', + 'Ask_deal_vol','Bid_deal_vol', + 'Bid/Ask_deal', 'Updown']) +env = trading_env.make(env_id='backtest_v1', obs_data_len=4096, step_len=5, + df=df, fee=0.1, max_position=50, deal_col_name='Price', + feature_names=['Price', 'Volume', + 'Ask_price','Bid_price', + 'Ask_deal_vol','Bid_deal_vol', + 'Bid/Ask_deal', 'Updown']) +env.reset() +env.render() + +state, reward, done, info = env.step(random.randrange(3)) + +### randow choice action and show the transaction detail +for i in range(500): + print(i) + state, reward, done, info = env.step(random.randrange(3)) + print(state, reward) + env.render() + if done: + break +env.transaction_details \ No newline at end of file diff --git a/trading_env/__init__.py b/trading_env/__init__.py index 75a205c..5948bad 100644 --- a/trading_env/__init__.py +++ b/trading_env/__init__.py @@ -1,4 +1,5 @@ -from .envs import available_envs_module +from trading_env.envs import available_envs_module + def available_envs(): available_envs = [env_module.__name__.split('.')[-1] for env_module in available_envs_module] diff --git a/trading_env/envs/backtest_v1.py b/trading_env/envs/backtest_v1.py index 5306319..61d6c97 100644 --- a/trading_env/envs/backtest_v1.py +++ b/trading_env/envs/backtest_v1.py @@ -88,9 +88,11 @@ def reset(self): self.df_sample = self._choice_section() self.step_st = 0 # define the price to calculate the reward - self.price = self.df_sample[self.price_name].as_matrix() + self.price = self.df_sample[self.price_name].to_numpy() + + # self.price = self.df_sample[self.price_name].as_matrix() # define the observation feature - self.obs_features = self.df_sample[self.using_feature].as_matrix() + self.obs_features = self.df_sample[self.using_feature].to_numpy() #maybe make market position feature in final feature, set as option self.posi_arr = np.zeros_like(self.price) # position variation @@ -370,8 +372,9 @@ def render(self, save=False): self.fig.savefig('fig/%s.png' % str(self.t_index)) elif self.render_on == 1: - self.ax.lines.remove(self.price_plot[0]) - [self.ax3.lines.remove(plot) for plot in self.features_plot] + self.price_plot[0].remove() + # self.ax.lines.remove(self.price_plot[0]) + [plot.remove() for plot in self.features_plot] self.fluc_reward_plot_p.remove() self.fluc_reward_plot_n.remove() self.target_box.remove() diff --git a/trading_env/envs/training_v1.py b/trading_env/envs/training_v1.py index 9128732..a87ef79 100644 --- a/trading_env/envs/training_v1.py +++ b/trading_env/envs/training_v1.py @@ -361,8 +361,14 @@ def render(self, save=False): self.fig.savefig('fig/%s.png' % str(self.t_index)) elif self.render_on == 1: - self.ax.lines.remove(self.price_plot[0]) - [self.ax3.lines.remove(plot) for plot in self.features_plot] + line_to_remove = self.price_plot[0] # Assuming the line you want to remove is at index 0 + print(line_to_remove, type(line_to_remove)) + line_to_remove.remove() + # self.ax.lines.remove(line_to_remove) + + # self.ax.lines.remove(self.price_plot[0]) + # [self.ax3.lines.remove(plot) for plot in self.features_plot] + [plot.remove() for plot in self.features_plot] self.fluc_reward_plot_p.remove() self.fluc_reward_plot_n.remove() self.target_box.remove()