dqn.py
import math
import random

import torch

from network import QValue


class DQN:
    """Deep Q-Network agent: epsilon-greedy acting plus training/testing loops."""

    def __init__(self, config):
        self.init_net(config)

    def init_net(self, config):
        # Online (predict) and target networks start from identical weights.
        self.predict_net = QValue(config).to(config.Env.device)
        self.target_net = QValue(config).to(config.Env.device)
        self.target_net.load_state_dict(self.predict_net.state_dict())
        self.opt_lossf = self.init_opt_lossf(config)
        self.init_memory(config)
        self.step_num = 0
        self.config = config
        self.device = self.config.Env.device
    def init_opt_lossf(self, config):
        optimizer = config.Network.optimizer(
            self.predict_net.parameters(), lr=config.Network.LR, amsgrad=True
        )
        criterion = config.Network.loss_function
        return optimizer, criterion
    def init_memory(self, config):
        # The buffer is constructed with both networks and the (optimizer, loss)
        # pair, so it can run the learning update itself when .update() is called.
        self.buffer = config.Buffer.types(
            config, self.predict_net, self.target_net, self.opt_lossf
        )
    def action(self, state, phase="train"):
        if phase == "train":
            # Epsilon-greedy: explore with probability eps_threshold, which decays
            # exponentially from EPS_START to EPS_END over EPS_DECAY steps.
            sample = random.random()
            eps_threshold = self.config.Network.EPS_END + (
                self.config.Network.EPS_START - self.config.Network.EPS_END
            ) * math.exp(-1.0 * self.step_num / self.config.Network.EPS_DECAY)
            self.step_num += 1
            if sample > eps_threshold:
                with torch.no_grad():
                    # Greedy action: index of the largest predicted Q-value.
                    return self.predict_net(state).max(1)[1].view(1, 1)
            else:
                return torch.tensor(
                    [[self.config.Env.action_space.sample()]],
                    device=self.device,
                    dtype=torch.long,
                )
        else:
            # Evaluation: always act greedily.
            with torch.no_grad():
                return self.predict_net(state).max(1)[1].view(1, 1)
    def training(self):
        env = self.config.Env.make_env
        for i_episode in range(self.config.Env.num_episodes):
            state = env.reset()
            state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            is_done = False
            num_step = 0
            # Run until the episode terminates or the step cap is reached.
            while num_step <= self.config.Env.max_step and not is_done:
                action = self.action(state)
                # Old gym API: step() returns (state, reward, done, info).
                next_state, reward, done, _ = env.step(action.item())
                reward = torch.tensor([reward], device=self.device)
                next_state = torch.tensor(
                    next_state, dtype=torch.float32, device=self.device
                ).unsqueeze(0)
                # Store the transition; the buffer also performs the learning update.
                self.buffer.update(state, action, next_state, reward, done)
                state = next_state
                num_step += 1
                if self.config.Env.training_render:
                    env.render()
                if done:
                    is_done = True
                    break
            print(f"[i_episode]:{i_episode}")
    def testing(self):
        env = self.config.Env.make_env
        # Restore the trained weights saved by the buffer.
        self.buffer.load()
        for i_episode in range(self.config.Env.te_num_episodes):
            state = env.reset()
            state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            is_done = False
            num_step = 0
            # Greedy rollout until the episode terminates or the step cap is reached.
            while num_step <= self.config.Env.te_max_step and not is_done:
                action = self.action(state, phase="test")
                next_state, reward, done, _ = env.step(action.item())
                next_state = torch.tensor(
                    next_state, dtype=torch.float32, device=self.device
                ).unsqueeze(0)
                state = next_state
                num_step += 1
                if self.config.Env.testing_render:
                    env.render()
                if done:
                    is_done = True
                    break
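
# Usage sketch (hypothetical, not part of the original file): DQN expects a config
# object exposing the Env, Network, and Buffer namespaces referenced above
# (config.Env.make_env, config.Env.device, config.Network.LR, config.Buffer.types, ...).
# Assuming such a config module lives next to this file, training and evaluation
# would look roughly like:
#
#     from config import config  # hypothetical import
#
#     agent = DQN(config)
#     agent.training()  # collect transitions; the buffer optimizes predict_net
#     agent.testing()   # reload saved weights via buffer.load() and act greedily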