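"""Train a DQN variant (dqn, double, dueling, prioritized, noisy) on CartPole, Breakout, or Pong.

Usage (inferred from the argparse arguments in parse_input below):
    python train.py <dqn_type> <game>
    e.g. python train.py dqn cartpole

Checkpoints are written under modelstats/ during training; the final model,
the reward/loss plot, and the raw arrays are saved under results/.
"""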
import argparse
import gym
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os, sys
from wrappers import make_atari, wrap_deepmind, wrap_pytorch
from hyperparameters import get_hyperparameters


def parse_input():
    parser = argparse.ArgumentParser()
    parser.add_argument('dqn_type', type=str, help='[dqn, double, dueling, prioritized, noisy]')
    parser.add_argument('game', type=str, help='[cartpole, breakout, pong]')
    args = parser.parse_args()
    return args.dqn_type, args.game


def train(env, model, tmodel, buffer, optimizer, hyperparameters):
    losses = []
    all_rewards = []
    episode_reward = 0
    state, _ = env.reset()
    for frame_idx in tqdm(range(1, hyperparameters["num_frames"] + 1)):
        # select an action: epsilon-greedy, unless the noisy network handles exploration itself
        if hyperparameters["dqn_type"] != "noisy":
            epsilon = hyperparameters["epsilon_func"](frame_idx)
            action = model.act(state, epsilon)
        else:
            action = model.act(state)
        # save transition into buffer
        next_state, reward, done, _, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward
        if done:
            # reset the environment, record total reward
            state, _ = env.reset()
            all_rewards.append(episode_reward)
            episode_reward = 0
        if len(buffer) >= hyperparameters["train_initial"]:
            # start training once enough samples have been collected
            if tmodel:
                loss = hyperparameters["loss_func"](hyperparameters["num_frames"], hyperparameters["batch_size"], hyperparameters["gamma"], model, tmodel, buffer, optimizer)
            else:
                loss = hyperparameters["loss_func"](hyperparameters["num_frames"], hyperparameters["batch_size"], hyperparameters["gamma"], model, buffer, optimizer)
            losses.append(loss.detach().item())
        if frame_idx % 1000 == 0 and tmodel:
            # sync the target network with the online network
            tmodel.load_state_dict(model.state_dict())
        if frame_idx % 50000 == 0:
            # periodically checkpoint the model
            dqn_type, game = hyperparameters["dqn_type"], hyperparameters["game"]
            os.makedirs(f"modelstats/{dqn_type}_{game}", exist_ok=True)
            torch.save(model.state_dict(), f'modelstats/{dqn_type}_{game}/{game}_{frame_idx}_frame_{dqn_type}.pt')
    return losses, all_rewards


def plot(losses, rewards, path):
    plt.figure(figsize=(20, 5))
    plt.subplot(131)
    plt.title('reward')
    plt.plot(rewards)
    plt.subplot(132)
    plt.title('loss')
    plt.plot(losses)
    plt.savefig(path)


def main():
    dqn_type, game = parse_input()
    hyperparameters = get_hyperparameters(dqn_type, game)
    if game == "cartpole":
        # demo
        env = gym.make("CartPole-v1")
        model = hyperparameters["Model"](env.observation_space.shape[0], env.action_space.n)
        tmodel = None
        # train() needs a replay buffer, so create one for the demo as well
        replay_buffer = hyperparameters["Buffer"](hyperparameters["buffer_size"])
    else:
        # select game
        if game == "pong":
            env = wrap_pytorch(wrap_deepmind(make_atari("PongNoFrameskip-v4")))
        elif game == "breakout":
            env = wrap_pytorch(wrap_deepmind(make_atari("BreakoutNoFrameskip-v4")))
        else:
            print(f"{game} not supported", file=sys.stderr)
            sys.exit(1)
        # select model
        if dqn_type == "dqn":
            model = hyperparameters["Model"](env.observation_space.shape, env.action_space.n)
            tmodel = None
            if torch.cuda.is_available():
                model = model.cuda()
            replay_buffer = hyperparameters["Buffer"](hyperparameters["buffer_size"])
        elif dqn_type == "double":
            model, tmodel = hyperparameters["Model"](env.observation_space.shape, env.action_space.n), hyperparameters["Model"](env.observation_space.shape, env.action_space.n)
            tmodel.load_state_dict(model.state_dict())
            if torch.cuda.is_available():
                model, tmodel = model.cuda(), tmodel.cuda()
            replay_buffer = hyperparameters["Buffer"](hyperparameters["buffer_size"])
        elif dqn_type == "dueling":
            # the dueling DQN is trained with the double-DQN TD loss
            model, tmodel = hyperparameters["Model"](env.observation_space.shape, env.action_space.n), hyperparameters["Model"](env.observation_space.shape, env.action_space.n)
            tmodel.load_state_dict(model.state_dict())
            if torch.cuda.is_available():
                model, tmodel = model.cuda(), tmodel.cuda()
            replay_buffer = hyperparameters["Buffer"](hyperparameters["buffer_size"])
        elif dqn_type == "prioritized":
            model, tmodel = hyperparameters["Model"](env.observation_space.shape, env.action_space.n), hyperparameters["Model"](env.observation_space.shape, env.action_space.n)
            tmodel.load_state_dict(model.state_dict())
            if torch.cuda.is_available():
                model, tmodel = model.cuda(), tmodel.cuda()
            replay_buffer = hyperparameters["Buffer"](hyperparameters["buffer_size"], hyperparameters["beta_func"])
        elif dqn_type == "noisy":
            # the noisy DQN is likewise trained with the double-DQN TD loss
            model, tmodel = hyperparameters["Model"](env.observation_space.shape, env.action_space.n), hyperparameters["Model"](env.observation_space.shape, env.action_space.n)
            tmodel.load_state_dict(model.state_dict())
            if torch.cuda.is_available():
                model, tmodel = model.cuda(), tmodel.cuda()
            replay_buffer = hyperparameters["Buffer"](hyperparameters["buffer_size"])
    optimizer = optim.Adam(model.parameters(), lr=hyperparameters["lr"])
    losses, rewards = train(env, model, tmodel, replay_buffer, optimizer, hyperparameters)
    losses, rewards = np.array(losses), np.array(rewards)
    # save results
    os.makedirs(f"results/{dqn_type}_{game}", exist_ok=True)
    plot(losses, rewards, f"results/{dqn_type}_{game}/reward_and_loss.png")
    torch.save(model.state_dict(), f'results/{dqn_type}_{game}/model.pt')
    np.savez(f"results/{dqn_type}_{game}/loss_and_reward.npz", loss=losses, reward=rewards)


if __name__ == "__main__":
    main()