Commit c6ea136: FIXUP - Simplifications
Parent: 9966f72

12 files changed: +186 -269 lines

diagnose_model.py (+9, -17)

@@ -54,18 +54,18 @@ def get_virtual_trajectory_from_obs(
         virtual_to_play = self.config.players[0]

         # Generate new root
-        value, reward, policy_logits, hidden_state = self.model.recurrent_inference(
+        value, reward, policy_parameters, hidden_state = self.model.recurrent_inference(
            root.hidden_state,
            torch.tensor([[action]]).to(root.hidden_state.device),
        )
        value = support_to_scalar(value, self.config.support_size).item()
        reward = support_to_scalar(reward, self.config.support_size).item()
        root = Node(0)
+        sampled_actions = self.model.sample_actions(policy_parameters)
        root.expand(
-            self.config.action_space,
+            sampled_actions,
            virtual_to_play,
            reward,
-            policy_logits,
            hidden_state,
        )

@@ -208,10 +208,10 @@ def __init__(self, title, config):
        self.policies_after_planning = []
        # Not implemented, need to store them in every nodes of the mcts
        self.prior_values = []
-        self.values_after_planning = [[numpy.NaN] * len(self.config.action_space)]
+        self.values_after_planning = [[numpy.NaN] * sum(self.config.action_shape)]
        self.prior_root_value = []
        self.root_value_after_planning = []
-        self.prior_rewards = [[numpy.NaN] * len(self.config.action_space)]
+        self.prior_rewards = [[numpy.NaN] * sum(self.config.action_shape)]
        self.mcts_depth = []

    def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None):
@@ -222,25 +222,19 @@ def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None)
        self.prior_policies.append(
            [
                root.children[action].prior
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
            ]
        )
        self.policies_after_planning.append(
            [
                root.children[action].visit_count / self.config.num_simulations
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
            ]
        )
        self.values_after_planning.append(
            [
                root.children[action].value()
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
            ]
        )
        self.prior_root_value.append(
@@ -252,9 +246,7 @@ def store_info(self, root, mcts_info, action, reward, new_prior_root_value=None)
        self.prior_rewards.append(
            [
                root.children[action].reward
-                if action in root.children.keys()
-                else numpy.NaN
-                for action in self.config.action_space
+                for action in root.children.keys()
            ]
        )
        self.mcts_depth.append(mcts_info["max_tree_depth"])

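The calls to self.model.sample_actions(policy_parameters) introduced above are not defined anywhere in this diff. Purely as a sketch, assuming policy_parameters are categorical logits of shape (1, num_actions) and reusing the sample_size / policy_distribution fields added to the game configs below, such a helper could look like the following; apart from the method name, none of it is taken from the repository.

import torch
from torch.distributions import Categorical

def sample_actions(policy_parameters, sample_size=4, policy_distribution=Categorical):
    """Draw a subset of candidate actions from the policy head output.

    Sketch only: assumes `policy_parameters` are logits of shape (1, num_actions).
    """
    dist = policy_distribution(logits=policy_parameters)
    samples = dist.sample((sample_size,))            # shape: (sample_size, 1)
    # Return plain ints so they can serve as keys of node.children.
    return [int(a) for a in samples.squeeze(-1)]
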
game_history.py (+6, -12)

@@ -12,30 +12,24 @@ def __init__(self):
        self.to_play_history = []
        self.child_visits = []
        self.root_values = []
+        self.sampled_actions_history = []
        self.reanalysed_predicted_root_values = None
        # For PER
        self.priorities = None
        self.game_priority = None

-    def store_search_statistics(self, root, action_space):
+    def store_search_statistics(self, root):
        # Turn visit count from root into a policy
        if root is not None:
            sum_visits = sum(child.visit_count for child in root.children.values())
-            self.child_visits.append(
-                [
-                    root.children[a].visit_count / sum_visits
-                    if a in root.children
-                    else 0
-                    for a in action_space
-                ]
-            )
-
+            self.child_visits.append([root.children[a].visit_count / sum_visits for a in root.children.keys()])
+            self.sampled_actions_history.append(root.sampled_actions)
            self.root_values.append(root.value())
        else:
            self.root_values.append(None)

    def get_stacked_observations(
-        self, index, num_stacked_observations, action_space_size
+        self, index, num_stacked_observations
    ):
        """
        Generate a new observation with the observation at the index position
@@ -55,7 +49,7 @@ def get_stacked_observations(
                    [
                        numpy.ones_like(stacked_observations[0])
                        * self.action_history[past_observation_index + 1]
-                        / action_space_size
+                        / len(self.sampled_actions_history[past_observation_index + 1])
                    ],
                )
            )

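With the full-action-space padding gone, child_visits[i] no longer lines up with a fixed action ordering; the actions it refers to now have to come from sampled_actions_history[i]. A hedged illustration of how the two lists could be re-joined downstream (the replay-buffer side is not part of this commit, and this assumes root.sampled_actions holds the same actions, in the same order, as root.children):

def policy_target_at(game_history, index):
    # Pair each stored action with its normalised visit share from the search
    # at that position. Assumption: the two lists are parallel.
    actions = game_history.sampled_actions_history[index]
    visits = game_history.child_visits[index]
    return dict(zip(actions, visits))

# e.g. {1: 0.56, 0: 0.31, 3: 0.13} for three actions sampled at `index`
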
games/breakout.py (+6, -2)

@@ -4,6 +4,7 @@
 import gym
 import numpy
 import torch
+from torch.distributions import Categorical

 from .abstract_game import AbstractGame

@@ -54,7 +55,7 @@ def __init__(self):


        ### Network
-        self.network = "resnet" # "resnet" / "fullyconnected"
+        self.network = "sampled" # "resnet" / "fullyconnected" / "sampled"
        self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))

        # Residual Network
@@ -76,7 +77,10 @@ def __init__(self):
        self.fc_value_layers = [] # Define the hidden layers in the value network
        self.fc_policy_layers = [] # Define the hidden layers in the policy network

-
+        # Sampled
+        self.sample_size = 4
+        self.action_shape = [4]
+        self.policy_distribution = Categorical

        ### Training
        self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs

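Storing the distribution class itself (policy_distribution = Categorical) rather than hard-coding a softmax is what lets the same three config fields describe other action spaces. Purely as an illustration of that design choice, and not code from this commit, a continuous task could plug in a Normal distribution instead:

import torch
from torch.distributions import Normal

# Hypothetical continuous-control config values, for illustration only.
action_shape = [2]            # a 2-dimensional continuous action
sample_size = 8
policy_distribution = Normal

# Pretend the policy head emitted a mean and a log-std per action dimension.
mean = torch.zeros(1, sum(action_shape))
log_std = torch.zeros(1, sum(action_shape))
dist = policy_distribution(mean, log_std.exp())
candidate_actions = dist.sample((sample_size,))   # shape: (sample_size, 1, 2)
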
games/cartpole.py (+11, -7)

@@ -49,9 +49,9 @@ def __init__(self):


        ### Network
-        self.network = "fullyconnected" # "resnet" / "fullyconnected"
+        self.network = "sampled" # "resnet" / "fullyconnected"
        self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))
-
+
        # Residual Network
        self.downsample = False # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture)
        self.blocks = 1 # Number of blocks in the ResNet
@@ -66,18 +66,22 @@ def __init__(self):
        # Fully Connected Network
        self.encoding_size = 8
        self.fc_representation_layers = [] # Define the hidden layers in the representation network
-        self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
-        self.fc_reward_layers = [16] # Define the hidden layers in the reward network
-        self.fc_value_layers = [16] # Define the hidden layers in the value network
-        self.fc_policy_layers = [16] # Define the hidden layers in the policy network
+        self.fc_dynamics_layers = [32] # Define the hidden layers in the dynamics network
+        self.fc_reward_layers = [32] # Define the hidden layers in the reward network
+        self.fc_value_layers = [32] # Define the hidden layers in the value network
+        self.fc_policy_layers = [128, 128] # Define the hidden layers in the policy network


+        # Sampled
+        self.sample_size = 50
+        self.action_shape = [2]
+        self.policy_distribution = torch.distributions.Categorical

        ### Training
        self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs
        self.save_model = True # Save the checkpoint in results_path as model.checkpoint
        self.training_steps = 10000 # Total number of training steps (ie weights update according to a batch)
-        self.batch_size = 128 # Number of parts of games to train on at each training step
+        self.batch_size = 256 # Number of parts of games to train on at each training step
        self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
        self.value_loss_weight = 1 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
        self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available

games/connect4.py (+7, -3)

@@ -3,6 +3,7 @@

 import numpy
 import torch
+from torch.distributions import Categorical

 from .abstract_game import AbstractGame

@@ -48,7 +49,7 @@ def __init__(self):


        ### Network
-        self.network = "resnet" # "resnet" / "fullyconnected"
+        self.network = "sampled" # "resnet" / "fullyconnected"
        self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))

        # Residual Network
@@ -70,14 +71,17 @@ def __init__(self):
        self.fc_value_layers = [] # Define the hidden layers in the value network
        self.fc_policy_layers = [] # Define the hidden layers in the policy network

-
+        # Sampled
+        self.sample_size = 7
+        self.action_shape = [7]
+        self.policy_distribution = Categorical

        ### Training
        self.results_path = pathlib.Path(__file__).resolve().parents[1] / "results" / pathlib.Path(__file__).stem / datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S") # Path to store the model weights and TensorBoard logs
        self.save_model = True # Save the checkpoint in results_path as model.checkpoint
        self.training_steps = 100000 # Total number of training steps (ie weights update according to a batch)
        self.batch_size = 64 # Number of parts of games to train on at each training step
-        self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
+        self.checkpoint_interval = 200 # Number of training steps before using the model for self-playing
        self.value_loss_weight = 0.25 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
        self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available

mcts.py (+8, -21)

@@ -45,27 +45,19 @@ def run(
            .unsqueeze(0)
            .to(next(model.parameters()).device)
        )
-        (
-            root_predicted_value,
-            reward,
-            policy_logits,
-            hidden_state,
-        ) = model.initial_inference(observation)
+        root_predicted_value, reward, policy_parameters, hidden_state = model.initial_inference(observation)
        root_predicted_value = support_to_scalar(
            root_predicted_value, self.config.support_size
        ).item()
        reward = support_to_scalar(reward, self.config.support_size).item()
        assert (
            legal_actions
        ), f"Legal actions should not be an empty array. Got {legal_actions}."
-        assert set(legal_actions).issubset(
-            set(self.config.action_space)
-        ), "Legal actions should be a subset of the action space."
+        sampled_actions = model.sample_actions(policy_parameters)
        root.expand(
-            legal_actions,
+            sampled_actions,
            to_play,
            reward,
-            policy_logits,
            hidden_state,
        )

@@ -98,17 +90,17 @@ def run(
            # Inside the search tree we use the dynamics function to obtain the next hidden
            # state given an action and the previous hidden state
            parent = search_path[-2]
-            value, reward, policy_logits, hidden_state = model.recurrent_inference(
+            value, reward, policy_parameters, hidden_state = model.recurrent_inference(
                parent.hidden_state,
                torch.tensor([[action]]).to(parent.hidden_state.device),
            )
+            sampled_actions = model.sample_actions(policy_parameters)
            value = support_to_scalar(value, self.config.support_size).item()
            reward = support_to_scalar(reward, self.config.support_size).item()
            node.expand(
-                self.config.action_space,
+                sampled_actions,
                virtual_to_play,
                reward,
-                policy_logits,
                hidden_state,
            )

@@ -130,13 +122,8 @@ def select_child(self, node, min_max_stats):
            self.ucb_score(node, child, min_max_stats)
            for action, child in node.children.items()
        )
-        action = numpy.random.choice(
-            [
-                action
-                for action, child in node.children.items()
-                if self.ucb_score(node, child, min_max_stats) == max_ucb
-            ]
-        )
+        actions = [action for action, child in node.children.items() if self.ucb_score(node, child, min_max_stats) == max_ucb]
+        action = actions[numpy.random.choice(range(len(actions)))]
        return action, node.children[action]

    def ucb_score(self, parent, child, min_max_stats):

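Node.expand is now called with a list of sampled actions and without policy logits, and game_history.py above reads root.sampled_actions back. The Node class itself is not in this commit; the following is only a guess at the minimal shape such a node could take, with uniform priors over the distinct sampled actions as an explicit assumption:

class Node:
    def __init__(self, prior):
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0
        self.to_play = -1
        self.sampled_actions = []

    def expand(self, sampled_actions, to_play, reward, hidden_state):
        # Remember the sample so GameHistory.store_search_statistics can pair
        # it with the visit counts later on.
        self.sampled_actions = sampled_actions
        self.to_play = to_play
        self.reward = reward
        self.hidden_state = hidden_state
        # Assumption: each distinct sampled action gets an equal prior.
        unique_actions = list(dict.fromkeys(sampled_actions))
        for action in unique_actions:
            self.children[action] = Node(1 / len(unique_actions))

    def value(self):
        return self.value_sum / self.visit_count if self.visit_count else 0

One plausible reason for the rewritten tie-breaking in select_child is that numpy.random.choice cannot pick directly from a list of non-scalar actions (it requires a 1-D array), whereas drawing an index and then indexing the Python list works for any hashable action.
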
models/muzero_network.py (+1, -1)

@@ -42,7 +42,7 @@ def __new__(cls, config):
                config.action_shape,
                config.encoding_size,
                config.sample_size,
-                config.blocks,
+                config.policy_distribution,
                config.fc_reward_layers,
                config.fc_value_layers,
                config.fc_policy_layers,

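Only one constructor argument changes here (config.policy_distribution in place of config.blocks), so the surrounding factory is implied rather than shown. A rough, self-contained sketch of how such a dispatch might look; the class name MuZeroSampledNetwork, the stub body, and the truncated argument list are assumptions, not the repository's actual code:

class MuZeroSampledNetwork:
    # Stub standing in for the real network class, which is not shown in this diff.
    def __init__(self, action_shape, encoding_size, sample_size, policy_distribution,
                 fc_reward_layers, fc_value_layers, fc_policy_layers):
        self.action_shape = action_shape
        self.sample_size = sample_size
        self.policy_distribution = policy_distribution


class MuZeroNetwork:
    def __new__(cls, config):
        # Dispatch on the network type selected in the game config.
        if config.network == "sampled":
            return MuZeroSampledNetwork(
                config.action_shape,
                config.encoding_size,
                config.sample_size,
                config.policy_distribution,  # replaces config.blocks in this commit
                config.fc_reward_layers,
                config.fc_value_layers,
                config.fc_policy_layers,
            )
        raise NotImplementedError(f"Unsupported network type: {config.network}")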