forked from yaozhaoyz/Pair-Trading-Reinforcement-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMachineLearning.py
191 lines (151 loc) · 6.61 KB
/
MachineLearning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import random
random.seed(0)
import numpy as np
np.random.seed(0)
from MAIN.Basics import Processor, Space
from operator import itemgetter
class StateSpace(Processor, Space):
    """Builds the network-facing and engine-facing state for the agent.

    ``process`` writes its two results into ``agent.data['NETWORK_STATE']``
    and ``agent.data['ENGINE_STATE']``.
    """
    def __init__(self, agent):
        self.agent = agent
        super().__init__(agent.config['StateSpaceState'])
    def process(self, method, env_state):
        """Populate the agent's network and engine state entries.

        method:    'stateless' or 'binary' (see ``_get_network_input``).
        env_state: raw environment state; only used when method == 'binary'.
        """
        self.agent.data['NETWORK_STATE'] = self._get_network_input(method, env_state)
        self.agent.data['ENGINE_STATE' ] = self._get_engine_input()
    def _get_network_input(self, method, env_state):
        """Return the state fed to the network.

        'stateless' -> constant 0 (the network input carries no information);
        'binary'    -> the environment state passed through unchanged.

        Fix: the old code drew a random sample in both branches and then
        immediately discarded it (overwritten by 0 / env_state), burning RNG
        draws and a config lookup for nothing; the dead draws are removed.
        """
        if method == 'stateless':
            return 0
        elif method == 'binary':
            return env_state
        else:
            raise ValueError('method should be stateless/binary')
    def _get_engine_input(self):
        # The engine state is currently a fixed transaction-cost setting; the
        # old NETWORK_STATE-conversion path was retired (commented-out code
        # removed).
        return {'transaction_cost': [0.001]}
class ActionSpace(Processor, Space):
    """Selects the agent's action and exposes it in network and engine form.

    ``process`` writes ``agent.data['NETWORK_ACTION']`` (an action index) and
    ``agent.data['ENGINE_ACTION']`` (that index converted for the engine).
    """
    def __init__(self, agent):
        self.agent = agent
        super().__init__(agent.config['ActionSpaceAction'])
    def process(self):
        # NETWORK_ACTION must be written first: _get_engine_input converts it.
        self.agent.data['NETWORK_ACTION'] = self._get_network_input()
        self.agent.data['ENGINE_ACTION' ] = self._get_engine_input()
    def _get_network_input(self):
        """Return an action index, either from the exploration policy or from
        a random sample of the configured type."""
        method = self.agent.config['ActionSpaceNetworkSampleType']
        if method == 'exploration':
            self.agent.exploration.process()
            action = self.agent.data['EXPLORATION_ACTION']
        else:
            action = self.get_random_sample(method)
        return action
    def _get_engine_input(self):
        """Convert the chosen action index into the engine's keyword form.

        Fix: previously read ``EXPLORATION_ACTION``, which is absent or stale
        whenever the sample type is not 'exploration'. ``NETWORK_ACTION``
        always holds the index chosen by ``_get_network_input`` (and equals
        EXPLORATION_ACTION in the exploration case).
        """
        method = self.agent.config['ActionSpaceEngineSampleConversion']
        index = self.agent.data['NETWORK_ACTION']
        action = self.convert(index, method)
        return action
class RewardEngine(Processor):
    """Runs the trading engine one step and records its reward and record.

    Commented-out legacy (index-less) signatures removed; ``index`` is the
    position in the data at which the engine is evaluated.
    """
    def __init__(self, agent, engine):
        self.engine = engine
        self.agent = agent
    def process(self, index):
        """Compute the reward at ``index`` and stash it on the agent."""
        reward, record = self._get_reward(index)
        self.agent.data['ENGINE_REWARD'] = reward
        self.agent.data['ENGINE_RECORD'] = record
    def _get_reward(self, index):
        """Feed the current engine state/action (as keyword arguments) into
        the engine and return its (reward, record) pair."""
        state = self.agent.data['ENGINE_STATE']
        action = self.agent.data['ENGINE_ACTION']
        self.engine.process(index=index, **state, **action)
        return self.engine.reward, self.engine.record
class Exploration(Processor):
    """Chooses action indices via the configured exploration strategy
    (random / greedy / e_greedy / boltzmann), writing the result into
    ``agent.data['EXPLORATION_ACTION']``."""
    def __init__(self, agent):
        self.agent = agent
        self.method = agent.config['ExplorationMethod']
        self.counter = agent.counters[agent.config['ExplorationCounter']]
        self.func = self.get_func(self.method)
        if self.method == 'boltzmann':
            # Graph node holding the softmax action probabilities.
            self.target_attr = getattr(self.agent, self.agent.config['ExplorationBoltzmannProbAttribute'])
    def process(self):
        # Delegate to the strategy resolved at construction time.
        self.agent.data['EXPLORATION_ACTION'] = self.func()
    def get_func(self, method):
        # Strategy methods follow the '_<method>' naming convention.
        return getattr(self, '_' + method)
    def _random(self):
        # Uniform draw over all action combinations.
        return random.randrange(self.agent.action_space.n_combination)
    def _greedy(self):
        # Ask the network for Q-values of the current state; pick the best.
        feed = self.agent.feed_dict
        feed[self.agent.input_layer] = [self.agent.data['NETWORK_STATE']]
        q_values = self.agent.session.run(self.agent.output_layer, feed_dict=feed)
        return np.argmax(q_values.reshape(-1,))
    def _e_greedy(self):
        # Explore with probability e (the current counter value), else exploit.
        eps = self.counter.value
        chosen = self._random() if random.random() < eps else self._greedy()
        self.counter.step()
        return chosen
    def _boltzmann(self):
        # Sample from the softmax distribution at the current temperature.
        self.agent.data['BOLTZMANN_TEMP'] = self.counter.value
        feed = self.agent.feed_dict
        feed[self.agent.input_layer] = [self.agent.data['NETWORK_STATE']]
        feed[self.agent.temp        ] = [self.agent.data['BOLTZMANN_TEMP']]
        prob = self.agent.session.run(self.target_attr, feed_dict=feed)
        chosen = np.random.choice(self.agent.action_space.n_combination, p=prob)
        self.counter.step()
        return chosen
class ExperienceBuffer(Processor):
    """Fixed-capacity FIFO replay buffer of training samples."""
    def __init__(self, agent):
        self.agent = agent
        self.buffer = []  # oldest samples first
        self.buffer_size = int(agent.config['ExperienceBufferBufferSize'])
    def process(self, method):
        """'add' appends agent.data['SAMPLE'] to the buffer; 'get' draws a
        random batch into agent.data['EXPERIENCE_BUFFER_SAMPLE']."""
        if method == 'add':
            self._add_sample(self.agent.data['SAMPLE'])
        elif method == 'get':
            self.agent.data['EXPERIENCE_BUFFER_SAMPLE'] = self._get_sample()
        else:
            raise ValueError("Error: method name should be add/get.")
    def _add_sample(self, sample):
        """Append a sequence of samples, evicting the oldest entries so the
        buffer never exceeds ``buffer_size``.

        Fix: the old code never consulted ``buffer_size`` — it compared the
        combined length against the *current* buffer length, so it always
        evicted len(sample) old items and the buffer could never grow past
        its initial contents. Extending first and then trimming to the last
        ``buffer_size`` entries also handles a sample larger than capacity.
        """
        self.buffer.extend(sample)
        if len(self.buffer) > self.buffer_size:
            self.buffer = self.buffer[-self.buffer_size:]
    def _get_sample(self):
        """Return a random batch (sampled with replacement) of the configured
        size.

        NOTE(review): itemgetter returns a single item, not a tuple, when the
        sampling size is 1 — presumably callers use size > 1; confirm.
        """
        size = int(self.agent.config['ExperienceBufferSamplingSize'])
        idx = np.random.randint(len(self.buffer), size=size)
        return itemgetter(*idx)(self.buffer)
class Recorder(Processor):
    """Periodically snapshots selected ``agent.data`` fields into per-field
    history lists (``self.record``)."""
    def __init__(self, agent):
        self.data_field = agent.config['RecorderDataField']
        self.record_freq = agent.config['RecorderRecordFreq']
        self.agent = agent
        if self.data_field is not None:
            # One history list per tracked field.
            self.record = {field: [] for field in self.data_field}
    def process(self):
        """Append the current value of every tracked field, once every
        ``record_freq`` epoch steps; no-op when no fields are configured."""
        if self.data_field is None:
            return
        if self.agent.epoch_counter.n_step % self.record_freq != 0:
            return
        for field in self.record:
            self.record[field].append(self.agent.data[field])