forked from datawhalechina/joyrl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CartPole-v1_DQN_Test.py
43 lines (40 loc) · 1.93 KB
/
CartPole-v1_DQN_Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# To test a trained model, set `load_checkpoint = True` and point `load_path` at the correct saved run in the class GeneralConfig.
import joyrl
class GeneralConfig:
    """General (algorithm-independent) settings for the CartPole-v1 DQN test run.

    Every setting is stored as a plain instance attribute, assigned in
    ``__init__``, so the object can be handed directly to ``joyrl.run``.
    This script is configured for test mode: ``mode`` is ``"test"`` and
    ``load_checkpoint`` is ``True``, with ``load_path`` naming the saved
    training run to load.
    """

    def __init__(self) -> None:
        self.env_name = "CartPole-v1"  # name of environment
        self.algo_name = "DQN"  # name of algorithm
        self.mode = "test"  # train or test
        self.seed = 1  # random seed
        self.device = "cpu"  # device to use
        self.train_eps = 100  # number of episodes for training
        self.test_eps = 20  # number of episodes for testing
        self.eval_eps = 10  # number of episodes for evaluation
        self.eval_per_episode = 5  # evaluation per episode
        self.max_steps = 200  # max steps for each episode
        self.load_checkpoint = True  # load a saved model (required for testing)
        self.load_path = "Train_CartPole-v1_DQN_20221120-000359"  # path to load model
        self.show_fig = False  # show figure or not
        self.save_fig = True  # save figure or not
class AlgoConfig:
    """DQN hyperparameters and value-network layout for the CartPole-v1 run.

    All settings are plain instance attributes assigned in ``__init__`` so
    the object can be handed directly to ``joyrl.run``.
    """

    def __init__(self) -> None:
        # Setting epsilon_start equal to epsilon_end yields a fixed
        # epsilon (= epsilon_end).
        self.epsilon_start = 0.95  # epsilon start value
        self.epsilon_end = 0.01  # epsilon end value
        self.epsilon_decay = 500  # epsilon decay rate
        self.gamma = 0.95  # discount factor
        self.lr = 0.0001  # learning rate
        self.buffer_size = 100000  # size of replay buffer
        self.batch_size = 64  # batch size
        self.target_update = 4  # target network update frequency
        # Value network described as a list of layer specs: three linear
        # layers, the first two with ReLU and the last with no activation.
        # NOTE(review): 'n_states' / 'n_actions' are string placeholders,
        # presumably resolved by joyrl from the environment -- confirm.
        self.value_layers = [
            {
                'layer_type': 'linear',
                'layer_dim': ['n_states', 256],
                'activation': 'relu',
            },
            {
                'layer_type': 'linear',
                'layer_dim': [256, 256],
                'activation': 'relu',
            },
            {
                'layer_type': 'linear',
                'layer_dim': [256, 'n_actions'],
                'activation': 'none',
            },
        ]
if __name__ == "__main__":
    # Script entry point: build the general and algorithm configs and hand
    # both to joyrl.run. Guarded so importing this module has no side effects.
    general_cfg = GeneralConfig()
    algo_cfg = AlgoConfig()
    joyrl.run(general_cfg, algo_cfg)