-
Notifications
You must be signed in to change notification settings - Fork 11
/
model_vs_model.py
82 lines (59 loc) · 2.71 KB
/
model_vs_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Pit two models together on NHL 94
"""
import retro
import sys
import argparse
import logging
import numpy as np
import pygame
from common import com_print, init_logger
from envs import init_env, init_play_env
from models import init_model, get_model_probabilities, get_num_parameters
def parse_cmdline(argv):
    """Parse command-line options for the model-vs-model match.

    Args:
        argv: list of argument strings (program name already stripped).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--p1_alg', type=str, default='ppo2')
    parser.add_argument('--p2_alg', type=str, default='ppo2')
    parser.add_argument('--nn', type=str, default='CnnPolicy')
    parser.add_argument('--model1_desc', type=str, default='CNN')
    parser.add_argument('--model2_desc', type=str, default='MLP')
    parser.add_argument('--env', type=str, default='NHL941on1-Genesis')
    parser.add_argument('--state', type=str, default=None)
    # BUG FIX: argparse does not run `type=` conversion on `default=` values,
    # so string defaults like '2' used to leak through as str whenever the
    # flag was omitted. Defaults must be real ints.
    parser.add_argument('--num_players', type=int, default=2)
    parser.add_argument('--num_env', type=int, default=1)
    parser.add_argument('--num_timesteps', type=int, default=0)
    parser.add_argument('--output_basedir', type=str, default='~/OUTPUT')
    parser.add_argument('--load_p1_model', type=str, default='')
    parser.add_argument('--load_p2_model', type=str, default='')
    parser.add_argument('--display_width', type=int, default=1440)
    parser.add_argument('--display_height', type=int, default=810)
    # NOTE(review): with default=True, action='store_true' makes this flag a
    # no-op — deterministic is always True whether or not the flag is passed.
    # Kept as-is for backward compatibility; consider default=False or
    # BooleanOptionalAction if opting out should be possible.
    parser.add_argument('--deterministic', default=True, action='store_true')
    parser.add_argument('--rf', type=str, default='')
    args = parser.parse_args(argv)
    return args
def main(argv):
    """Run an endless NHL 94 match between two trained models.

    Builds the shared 2-player display environment plus one single-player
    env per model (used only to construct/load each model), then loops
    forever: each model predicts its own action from the shared observation,
    the two action vectors are concatenated and stepped through the play
    env, and the env is reset whenever an episode ends.

    Args:
        argv: full process argument list (argv[0] is the program name).
    """
    args = parse_cmdline(argv[1:])
    logger = init_logger(args)
    com_print('========= Init =============')
    # Shared 2-player environment that renders the match.
    play_env = init_play_env(args, 2, True)
    # One single-instance env per player, without sticky actions; these are
    # only needed so each model can be constructed/loaded.
    p1_env = init_env(None, 1, None, 1, args, use_sticky_action=False)
    p2_env = init_env(None, 1, None, 1, args, use_sticky_action=False)
    p1_model = init_model(None, args.load_p1_model, args.p1_alg, args, p1_env, logger)
    p2_model = init_model(None, args.load_p2_model, args.p2_alg, args, p2_env, logger)
    # Expose parameter counts on the play env (presumably shown on the
    # display overlay — defined in the project's play-env wrapper).
    play_env.model1_params = get_num_parameters(p1_model)
    play_env.model2_params = get_num_parameters(p2_model)
    com_print('========= Start Play Loop ==========')
    state = play_env.reset()
    # Play forever; episodes are restarted in place when `done` is reported.
    while True:
        # predict() returns a (actions, state) tuple; index 0 below picks
        # the action array.
        p1_actions = p1_model.predict(state)
        p2_actions = p2_model.predict(state)
        # Publish per-model action probabilities for the display overlay.
        play_env.p1_action_probabilities = get_model_probabilities(p1_model, state)[0]
        play_env.p2_action_probabilities = get_model_probabilities(p2_model, state)[0]
        # Concatenate both players' button vectors into the single combined
        # action the 2-player env expects.
        actions2 = np.append(p1_actions[0], p2_actions[0])
        state, reward, done, info = play_env.step([actions2])
        if done:
            state = play_env.reset()


if __name__ == '__main__':
    main(sys.argv)