diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/TradingGym.iml b/.idea/TradingGym.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/TradingGym.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..236312a
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/koshish.py b/koshish.py
new file mode 100644
index 0000000..adfcc9e
--- /dev/null
+++ b/koshish.py
@@ -0,0 +1,33 @@
+import random
+import numpy as np
+import pandas as pd
+import trading_env
+
+df = pd.read_hdf('dataset/SGXTWsample.h5', 'STW')
+
+env = trading_env.make(env_id='training_v1', obs_data_len=256, step_len=128,
+ df=df, fee=0.1, max_position=5, deal_col_name='Price',
+ feature_names=['Price', 'Volume',
+ 'Ask_price','Bid_price',
+ 'Ask_deal_vol','Bid_deal_vol',
+ 'Bid/Ask_deal', 'Updown'])
+env = trading_env.make(env_id='backtest_v1', obs_data_len=4096, step_len=5,
+ df=df, fee=0.1, max_position=50, deal_col_name='Price',
+ feature_names=['Price', 'Volume',
+ 'Ask_price','Bid_price',
+ 'Ask_deal_vol','Bid_deal_vol',
+ 'Bid/Ask_deal', 'Updown'])
+env.reset()
+env.render()
+
+state, reward, done, info = env.step(random.randrange(3))
+
+### randow choice action and show the transaction detail
+for i in range(500):
+ print(i)
+ state, reward, done, info = env.step(random.randrange(3))
+ print(state, reward)
+ env.render()
+ if done:
+ break
+env.transaction_details
\ No newline at end of file
diff --git a/trading_env/__init__.py b/trading_env/__init__.py
index 75a205c..5948bad 100644
--- a/trading_env/__init__.py
+++ b/trading_env/__init__.py
@@ -1,4 +1,5 @@
-from .envs import available_envs_module
+from trading_env.envs import available_envs_module
+
def available_envs():
available_envs = [env_module.__name__.split('.')[-1] for env_module in available_envs_module]
diff --git a/trading_env/envs/backtest_v1.py b/trading_env/envs/backtest_v1.py
index 5306319..61d6c97 100644
--- a/trading_env/envs/backtest_v1.py
+++ b/trading_env/envs/backtest_v1.py
@@ -88,9 +88,11 @@ def reset(self):
self.df_sample = self._choice_section()
self.step_st = 0
# define the price to calculate the reward
- self.price = self.df_sample[self.price_name].as_matrix()
+ self.price = self.df_sample[self.price_name].to_numpy()
+
+ # self.price = self.df_sample[self.price_name].as_matrix()
# define the observation feature
- self.obs_features = self.df_sample[self.using_feature].as_matrix()
+ self.obs_features = self.df_sample[self.using_feature].to_numpy()
#maybe make market position feature in final feature, set as option
self.posi_arr = np.zeros_like(self.price)
# position variation
@@ -370,8 +372,9 @@ def render(self, save=False):
self.fig.savefig('fig/%s.png' % str(self.t_index))
elif self.render_on == 1:
- self.ax.lines.remove(self.price_plot[0])
- [self.ax3.lines.remove(plot) for plot in self.features_plot]
+ self.price_plot[0].remove()
+ # self.ax.lines.remove(self.price_plot[0])
+ [plot.remove() for plot in self.features_plot]
self.fluc_reward_plot_p.remove()
self.fluc_reward_plot_n.remove()
self.target_box.remove()
diff --git a/trading_env/envs/training_v1.py b/trading_env/envs/training_v1.py
index 9128732..a87ef79 100644
--- a/trading_env/envs/training_v1.py
+++ b/trading_env/envs/training_v1.py
@@ -361,8 +361,14 @@ def render(self, save=False):
self.fig.savefig('fig/%s.png' % str(self.t_index))
elif self.render_on == 1:
- self.ax.lines.remove(self.price_plot[0])
- [self.ax3.lines.remove(plot) for plot in self.features_plot]
+ line_to_remove = self.price_plot[0] # Assuming the line you want to remove is at index 0
+ print(line_to_remove, type(line_to_remove))
+ line_to_remove.remove()
+ # self.ax.lines.remove(line_to_remove)
+
+ # self.ax.lines.remove(self.price_plot[0])
+ # [self.ax3.lines.remove(plot) for plot in self.features_plot]
+ [plot.remove() for plot in self.features_plot]
self.fluc_reward_plot_p.remove()
self.fluc_reward_plot_n.remove()
self.target_box.remove()