-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathenvironment.py
executable file
·145 lines (115 loc) · 3.23 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright (c) 2020 DeNA Co., Ltd.
# Licensed under The MIT License [see LICENSE for details]
# game environment
import importlib
# Registry of the built-in environments: short name -> importable module path.
# Names that are not keys here are used directly as module paths when an
# environment is loaded (see prepare_env() / make_env()).
ENVS = dict(
    TicTacToe='handyrl.envs.tictactoe',
    Geister='handyrl.envs.geister',
    ParallelTicTacToe='handyrl.envs.parallel_tictactoe',
    HungryGeese='handyrl.envs.kaggle.hungry_geese',
)
def _import_env_module(env_args):
    """Import and return the module that implements ``env_args['env']``.

    The name is looked up in ``ENVS`` first; a name that is not registered
    there is used directly as a dotted module path.

    Raises:
        KeyError: if ``env_args`` has no ``'env'`` entry.
        ModuleNotFoundError: with message ``"No environment <name>"`` when the
            environment module itself (or one of its parent packages) cannot
            be found. If the module exists but one of *its* imports is
            missing, that original error is re-raised unchanged so the real
            cause stays visible.
    """
    env_name = env_args['env']
    env_source = ENVS.get(env_name, env_name)
    try:
        return importlib.import_module(env_source)
    except ModuleNotFoundError as err:
        # import_module() raises rather than returning None, so "unknown
        # environment" must be detected here. err.name is the module that was
        # not found; translate only when it is env_source or a parent package.
        missing = err.name
        if missing and (env_source == missing or env_source.startswith(missing + '.')):
            raise ModuleNotFoundError("No environment %s" % env_name, name=missing) from err
        raise


def prepare_env(env_args):
    """Call the selected environment module's optional ``prepare()`` hook.

    Modules that do not define ``prepare`` are left alone. Returns ``None``.
    Raises ``ModuleNotFoundError("No environment <name>")`` for an unknown
    environment.
    """
    env_module = _import_env_module(env_args)
    if hasattr(env_module, 'prepare'):
        env_module.prepare()


def make_env(env_args):
    """Create an environment instance as ``<module>.Environment(env_args)``.

    Raises ``ModuleNotFoundError("No environment <name>")`` for an unknown
    environment instead of returning ``None``.
    """
    env_module = _import_env_module(env_args)
    return env_module.Environment(env_args)
# base class of Environment
class BaseEnvironment:
    """Base class for game environments.

    Methods that raise ``NotImplementedError`` here must be overridden by
    every game; the others return simple defaults that fit a single-player,
    sequential game and only need overriding when the game differs.
    """

    def __init__(self, args=None):
        # `args` defaults to None rather than a shared mutable `{}`; the base
        # class does not use it.
        pass

    def __str__(self):
        """Text rendering of the state; empty by default."""
        return ''

    def reset(self, args=None):
        """Reset to an initial state. Must be defined in all games."""
        raise NotImplementedError()

    def play(self, action, player):
        """Apply one action by `player`.

        Must be defined in all games unless you implement your own step().
        """
        raise NotImplementedError()

    def step(self, actions):
        """Apply a dict mapping player -> action.

        Override in games with simultaneous transitions. The default calls
        play() for each entry in dict order, skipping players whose action
        is None.
        """
        for player, action in actions.items():
            if action is not None:
                self.play(action, player)

    def turn(self):
        """Player to move in a multiplayer sequential game; 0 by default."""
        return 0

    def turns(self):
        """Players who act now in a multiplayer simultaneous game.

        Defaults to a one-element list holding turn().
        """
        return [self.turn()]

    def observers(self):
        """Players besides the acting ones who should observe the state.

        Mainly useful with RNN models. Empty by default.
        """
        return []

    def terminal(self):
        """Whether the game has ended. Must be defined in all games."""
        raise NotImplementedError()

    def reward(self):
        """Immediate rewards; an empty dict by default.

        Define this if you use immediate rewards (presumably keyed by
        player -- confirm against callers).
        """
        return {}

    def outcome(self):
        """Final result of the game. Must be defined in all games."""
        raise NotImplementedError()

    def legal_actions(self, player):
        """Actions `player` may take now. Must be defined in all games."""
        raise NotImplementedError()

    def players(self):
        """All player identifiers; [0] by default.

        Define this for multiplayer games or to give players names.
        """
        return [0]

    def observation(self, player=None):
        """Observation of the state for `player`. Must be defined in all games."""
        raise NotImplementedError()

    def action2str(self, a, player=None):
        """Encode an action as a string; str(a) unless overridden."""
        return str(a)

    def str2action(self, s, player=None):
        """Decode a string into an action; int(s) unless overridden."""
        return int(s)

    def diff_info(self, player=None):
        """State difference to send in network battle mode; '' by default."""
        return ''

    def update(self, info, reset):
        """Apply received `info` in network battle mode.

        Must be defined if you use network battle mode.
        """
        raise NotImplementedError()