-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
81 lines (70 loc) · 3.11 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Evaluate a trained agent in an interactive MineRL environment.

Builds the CLI argument parser, constructs the discrete action space,
and runs the DQFD agent's evaluation loop.
"""
import argparse

import gym
import minerl  # noqa: F401 -- importing registers the MineRL envs with gym

from network.DQN import DQN, DoubleDQN, DQFD
from preprocess import create_actionspace

# NOTE(review): removed `from train import launch_params` -- it was
# immediately shadowed by the local `launch_params` defined below, so it
# contributed nothing and only risked running train.py's module-level code.

# Module-level parser shared by launch_params() and the __main__ block.
parser = argparse.ArgumentParser()
def launch_params(arg_parser=None):
    """Register every CLI option on an argument parser.

    Args:
        arg_parser: the ``argparse.ArgumentParser`` to populate. Defaults to
            the module-level ``parser`` so existing zero-argument callers
            keep working unchanged.
    """
    p = parser if arg_parser is None else arg_parser
    ######################### preprocess ############################
    p.add_argument('--ROOT',
                   help='root directory of the project',
                   default='./')
    p.add_argument('--DATASET_LOC',
                   help='location of the dataset',
                   default='./data/MineRLTreechopVectorObf-v0')
    p.add_argument('--MODEL_SAVE',
                   # fixed: original help text was copy-pasted from --DATASET_LOC
                   help='directory where trained networks are saved/loaded',
                   default='./saved_network')
    p.add_argument('--ACTIONSPACE_TYPE', choices=['manually', 'k_means'],
                   help='way to define the actionspace',
                   default='k_means')
    p.add_argument('--actionNum', type=int,
                   help='the number of discrete action combinations',
                   default=32)
    ##### prepare dataset
    # NOTE(review): a string passed on the CLI (e.g. "--PREPARE_DATASET False")
    # would be truthy; only the default False is reliable here -- consider
    # action='store_true' if callers allow the interface change.
    p.add_argument('--PREPARE_DATASET',
                   help='if True, would automatically prepare dataset',
                   default=False)
    ######################### about RL training #####################
    p.add_argument('--env',
                   help='the environment for minerl to make',
                   default='MineRLTreechopVectorObf-v0')
    p.add_argument('--port', type=int,
                   # fixed: without type=int a CLI-supplied port was a string
                   help='the port to launch Minecraft',
                   default=5656)
    p.add_argument('--device',
                   help='running device for training model',
                   default='cuda:0')
    p.add_argument('--dim_DQN_Qnet', type=int,
                   help='parameters for DQN-Qnet architecture',
                   default=32)
    p.add_argument('--CONTINUOUS_FRAME', type=int,
                   help='number of continuous frames to be stacked together',
                   default=1)
    p.add_argument('--mode',
                   help='mode should be train or evaluate',
                   default='evaluate')
    p.add_argument('--agentname',
                   # fixed: original help text was copy-pasted from --mode
                   help='filename of the saved agent checkpoint to evaluate',
                   default='best_model.pt')
    p.add_argument('--ACTION_UPDATE_INTERVAL', type=int,
                   help='step interval between action updates',
                   default=3)
    p.add_argument('--EVALUATE_NUM', type=int,
                   # fixed: original help text was copy-pasted from the option above
                   help='number of evaluation episodes to run',
                   default=40)
    p.add_argument('--EPSILON', type=float,
                   help='epsilon at the end of explore',
                   default=0.1)
if __name__ == "__main__":
launch_params()
args = parser.parse_args()
## create action space
actionspace = create_actionspace(args)
## train network
env = gym.make(args.env)
env.make_interactive(port=args.port, realtime=True)
obs = env.reset()
net = DQFD(args, actionspace, env)
net.evaluate()