-
Notifications
You must be signed in to change notification settings - Fork 0
/
mct.py
259 lines (206 loc) · 7.18 KB
/
mct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import numpy as np
from numpy import log as ln
import pandas as pd
import random
import copy
import pickle
from state import State
class MC_node():
    """Node of the Monte Carlo tree.

    Attributes:
        state -- state held by the node; equality and hashing are based on it
        N -- visit count of the node
        edges -- list of outgoing edges (MC_edge objects), in insertion order
        sorted -- flag meant to signal whether ``edges`` is sorted; nothing in
                  this class ever sets it to True (see ``reset_sorted``)
        depth -- depth in the tree (may be None when the creator omits it)
        parent_edge -- incoming edge, or None for a node created without one
        game_is_done -- flag supplied by the creator; presumably True when
                        ``state`` ends the game (TODO confirm with callers)
    """

    def __init__(self, state, game_is_done=False, parent_edge=None, depth=None):
        self.state = state
        self.N = 0
        self.edges = []
        # Signifies whether self.edges is sorted or not; only the roll-out in
        # the MCT is meant to set it to False.
        self.sorted = False
        self.depth = depth
        self.parent_edge = parent_edge
        self.game_is_done = game_is_done

    def add_edge(self, e):
        """Append a single outgoing edge."""
        self.edges.append(e)

    def add_edges(self, es):
        """Append several outgoing edges, keeping their order."""
        self.edges.extend(es)

    def is_leaf(self):
        """Return True when the node has no outgoing edges yet."""
        return not self.edges

    def get_actions_of_edges(self):
        """Return the action of every outgoing edge."""
        return [edge.action for edge in self.edges]

    def get_first_child(self):
        """Return the child reached by the first edge (IndexError on a leaf)."""
        return self.edges[0].get_out_node()

    def get_children(self):
        """Return the child node of every outgoing edge."""
        return [edge.get_out_node() for edge in self.edges]

    def get_edge_with_action(self, action):
        """Return the first edge whose action state equals ``action.state``.

        ``np.array_equal`` yields a single bool even for array-valued states
        (``np.equal`` returns an element-wise array whose truth value is
        ambiguous for more than one element) and returns False on a shape
        mismatch instead of broadcasting.

        Returns:
            The matching edge, or 0 when no edge matches.
        """
        for edge in self.edges:
            if np.array_equal(edge.action.state, action.state):
                return edge
        # 0 (not None) is kept as the "not found" sentinel for existing callers.
        return 0

    def get_infor_of_edges(self):
        """Tabulate per-edge statistics.

        Returns:
            A pandas DataFrame with one column per edge action and the rows
            'N', 'W', 'c', 'Win rate', 'Part2', 'Value', 'Original Value' and
            'Game is done'.
        """
        dat = pd.DataFrame(columns=self.get_actions_of_edges())
        dat.loc['N', :] = self.get_N_of_edges()
        dat.loc['W', :] = self.get_W_of_edges()
        dat.loc['c', :] = self.get_c_of_edges()
        dat.loc['Win rate', :] = self.get_winrate_of_edges()
        dat.loc['Part2', :] = self.get_part()
        dat.loc['Value', :] = self.get_value_of_edges()
        dat.loc['Original Value', :] = self.get_orig_value_of_edges()
        dat.loc['Game is done', :] = self.get_done_of_edges()
        return dat

    def get_num_edges(self):
        """Return the number of outgoing edges."""
        return len(self.edges)

    def get_N_of_edges(self):
        """Return the visit count N of every edge."""
        return [edge.N for edge in self.edges]

    def get_c_of_edges(self):
        """Return the exploration constant c of every edge."""
        return [edge.c for edge in self.edges]

    def get_part(self):
        """Return ``edge.get_part()`` for every edge."""
        return [edge.get_part() for edge in self.edges]

    def get_value_of_edges(self):
        """Return the selection score ``edge.get_value()`` of every edge."""
        return [edge.get_value() for edge in self.edges]

    def get_orig_value_of_edges(self):
        """Return the stored initial ``value`` of every edge."""
        return [edge.value for edge in self.edges]

    def get_winrate_of_edges(self):
        """Return ``edge.get_winrate()`` for every edge."""
        return [edge.get_winrate() for edge in self.edges]

    def get_distribution(self):
        """Return ``edge.get_distribution()`` for every edge."""
        return [edge.get_distribution() for edge in self.edges]

    def get_W_of_edges(self):
        """Return the accumulated reward W of every edge."""
        return [edge.W for edge in self.edges]

    def get_done_of_edges(self):
        """Return the ``game_is_done`` flag of every edge's child node."""
        return [edge.out_node.game_is_done for edge in self.edges]

    def get_N(self):
        """Return the node's visit count."""
        return self.N

    def get_state(self):
        """Return the node's state."""
        return self.state

    def reset_sorted(self):
        """Set ``self.sorted`` to False (marked "Aborted" by the author)."""
        ### Aborted
        self.sorted = False

    def sort_edges_by_value(self):
        """Return the edges ordered by ``get_value()``, highest first (stable)."""
        return sorted(self.edges, key=lambda edge: edge.get_value(), reverse=True)

    def sort_edges_by_winrate(self):
        """Return the edges ordered by ``get_winrate()``, highest first (stable)."""
        return sorted(self.edges, key=lambda edge: edge.get_winrate(), reverse=True)

    def sort_edges_by_N(self):
        """Return the edges ordered by (N, value), highest first (stable)."""
        return sorted(self.edges, key=lambda edge: (edge.N, edge.value), reverse=True)

    def get_sum_winrates(self):
        """Return the sum of the edges' win rates."""
        return sum(self.get_winrate_of_edges())

    def __eq__(self, other):
        """Nodes compare equal when their states compare equal."""
        try:
            other_state = other.state
        except AttributeError:
            # Let Python fall back (ultimately to identity) instead of raising.
            return NotImplemented
        return self.state == other_state

    def __hash__(self):
        # Consistent with __eq__: hashing is delegated to the state.
        return hash(self.state)
class MC_edge():
    """Directed edge of the Monte Carlo tree, from ``in_node`` to ``out_node``.

    Attributes:
        action -- action taken along this edge
        in_node -- parent node (its N feeds the UCT1 exploration term)
        out_node -- child node; its ``parent_edge`` is set to this edge
        value -- initial value estimate, used as the win rate while N == 0
        N -- visit count of the edge
        W -- accumulated reward
        c -- exploration constant of the UCT1 term
    """

    # Added to ``value`` while the edge is unvisited; presumably chosen so that
    # unvisited edges outrank visited ones in get_value (TODO confirm).
    UNVISITED_BONUS = 4

    def __init__(self, action, in_node, out_node, c, value):
        self.action = action
        self.in_node = in_node
        self.out_node = out_node
        # Link the child back to this edge so back-propagation can walk upward.
        self.out_node.parent_edge = self
        self.value = value
        self.N = 0
        self.W = 0
        self.c = c

    def get_in_node(self):
        """Return the parent node."""
        return self.in_node

    def get_out_node(self):
        """Return the child node."""
        return self.out_node

    def get_action(self):
        """Return the action of this edge."""
        return self.action

    def _exploration(self):
        """Exploration part of get_value.

        Returns UNVISITED_BONUS while N == 0, otherwise
        ``c * sqrt(ln(in_node.N) / N)``.
        """
        if self.N == 0:
            return self.UNVISITED_BONUS
        return self.c * np.sqrt(ln(self.in_node.N) / self.N)

    def get_state(self):
        """Return the edge statistics as ``(Q, U, W, N, P)``.

        Q is ``get_winrate()``, U is the exploration term (so ``Q + U`` equals
        ``get_value()``), W and N are the raw counters, and P is the initial
        ``value``.
        NOTE(review): this class never defines Q/U/P attributes; this mapping
        onto the tracked quantities is a reconstruction — confirm with callers.
        """
        return (self.get_winrate(), self._exploration(), self.W, self.N, self.value)

    def get_value(self):
        """Return the UCT1 selection score.

        While unvisited: ``UNVISITED_BONUS + value``.
        Otherwise: ``W / N + c * sqrt(ln(in_node.N) / N)``.
        """
        if self.N == 0:
            return self.UNVISITED_BONUS + self.value
        return self.get_winrate() + self._exploration()

    def get_part(self):
        """Return ``value`` while unvisited, else ``sqrt(ln(in_node.N) / N)``."""
        if self.N == 0:
            return self.value
        return np.sqrt(ln(self.in_node.N) / self.N)

    def get_winrate(self):
        """Return ``W / N``, or the initial ``value`` while unvisited."""
        if self.N == 0:
            return self.value
        return self.W / self.N

    def get_distribution(self):
        """Return ``W / N``, or 0.0 while unvisited."""
        if self.N == 0:
            return 0.0
        return self.W / self.N
class MCFE_tree:
    """Monte Carlo search tree with UCT1 selection.

    Attributes:
        root -- MC_node built from ``root_state`` at depth 0
        tree -- set of registered nodes (deduplicated by MC_node equality,
                i.e. by state)
        actions -- set of ``State``-wrapped actions added via add_actions
    """

    def __init__(self, root_state):
        # init attribute
        self.root = MC_node(root_state, depth=0)
        self.tree = set()
        self.actions = set()
        self.add_node_to_tree(self.root)

    def add_node_to_tree(self, node):
        """Add ``node`` to ``self.tree``.

        Returns:
            1 if the node was newly added, 0 if an equal node was already present.
        """
        if node in self.tree:
            return 0
        self.tree.add(node)
        return 1

    def add_actions(self, action_list):
        """Wrap each element of ``action_list`` in ``State`` and add it to ``self.actions``."""
        for a in action_list:
            self.actions.add(State(a))

    def back_fill(self, reward, node):
        """
        Performs the backpropagation.

        Increments N of ``node`` and of every ancestor up to the root. It also
        increments N and adds ``reward`` to W of every edge on that path.

        Keyword arguments:
        reward -- reward for the current node/state
        node -- node corresponding to the current state
        """
        node.N += 1
        current = node
        # Identity, not ==: MC_node.__eq__ compares states, so a deeper node
        # repeating the root's state would otherwise end the walk early and
        # leave the edges above it un-updated. Stopping at a node without a
        # parent edge avoids an AttributeError on a detached chain.
        while current is not self.root and current.parent_edge is not None:
            edge = current.parent_edge
            edge.N += 1
            edge.W += reward
            current = edge.in_node
            current.N += 1

    def expansion(self, leaf, edges, values):
        """
        Adds edges to tree and returns most valuable edge.

        All ``edges`` are attached to ``leaf``. Only the child of the returned
        edge is registered in ``self.tree``. Ties keep the earliest edge.

        Keyword arguments:
        leaf -- leaf that is supposed to be expanded
        edges -- new edges (must be non-empty)
        values -- related values of the edges/actions, parallel to ``edges``
        """
        leaf.add_edges(edges)
        best_edge = edges[0]
        best_score = values[0]
        for edge, value in zip(edges, values):
            if best_score < value:
                best_score = value
                best_edge = edge
        self.add_node_to_tree(best_edge.get_out_node())
        return best_edge

    @staticmethod
    def _descend(start, ranked_edges):
        """Follow the first edge of ``ranked_edges(node)`` from ``start`` until a leaf; return (leaf, path)."""
        path = []
        current_node = start
        while not current_node.is_leaf():
            edge = ranked_edges(current_node)[0]
            path.append(edge)
            current_node = edge.get_out_node()
        return current_node, path

    def selection(self, root=None):
        """
        Selection with UCT1.

        At each node, the edge ranked first by ``sort_edges_by_value()`` is
        followed until a leaf is reached.

        Keyword arguments:
        root -- node from where the selection starts (defaults to self.root)

        Returns:
        (leaf, path) -- the reached leaf and the list of traversed edges
        (empty when ``root`` is itself a leaf)
        """
        start = self.root if root is None else root
        return self._descend(start, lambda node: node.sort_edges_by_value())

    def selection_with_N(self):
        """
        Selection with highest N.

        Starting at self.root, the edge ranked first by ``sort_edges_by_N()``
        (ordered by visit count, then value) is followed until a leaf.

        Returns:
        (leaf, path) -- the reached leaf and the list of traversed edges
        """
        return self._descend(self.root, lambda node: node.sort_edges_by_N())