env_wrappers.py
import time
import numpy as np
import torch
from gym import spaces
# inspect.getargspec was removed in Python 3.11; fall back to getfullargspec.
try:
    from inspect import getargspec
except ImportError:
    from inspect import getfullargspec as getargspec


class GymWrapper(object):
    '''
    Wraps a (possibly multi-agent) gym environment and exposes per-agent
    observation/action dimensions plus flattened torch observations.
    '''

    def __init__(self, env):
        self.env = env

    @property
    def observation_dim(self):
        '''
        For multi-agent envs, this is the observation size per agent.
        '''
        # A Tuple observation space exposes its sub-spaces via `.spaces`.
        if hasattr(self.env.observation_space, 'spaces'):
            total_obs_dim = 0
            for space in self.env.observation_space.spaces:
                if hasattr(self.env.action_space, 'shape'):
                    # np.prod of an empty shape is 1, so scalar (Discrete) sub-spaces
                    # contribute 1; in practice this works out to 1 + 1 + 59 = 61.
                    total_obs_dim += int(np.prod(space.shape))
                else:
                    total_obs_dim += 1
            return total_obs_dim
        else:
            return int(np.prod(self.env.observation_space.shape))

    @property
    def num_actions(self):
        if hasattr(self.env.action_space, 'nvec'):
            # MultiDiscrete: assume every dimension allows the same number of
            # actions and return the count for the first one.
            return int(self.env.action_space.nvec[0])
        elif hasattr(self.env.action_space, 'n'):
            # Discrete: the traffic junction env falls into this case, with only two actions.
            return self.env.action_space.n

    @property
    def dim_actions(self):
        # For multi-agent envs, this is the number of actions per agent.
        if hasattr(self.env.action_space, 'nvec'):
            # MultiDiscrete
            return self.env.action_space.shape[0]
            # return len(self.env.action_space.shape)
        elif hasattr(self.env.action_space, 'n'):
            # Discrete => only one action is taken at a time.
            return 1

    @property
    def action_space(self):
        return self.env.action_space

    def reset(self, epoch):
        # Pass the epoch through only when the underlying env's reset() accepts it.
        reset_args = getargspec(self.env.reset).args
        if 'epoch' in reset_args:
            obs = self.env.reset(epoch)
        else:
            obs = self.env.reset()
        obs = self._flatten_obs(obs)
        return obs

    def display(self):
        self.env.render()
        time.sleep(0.5)

    def end_display(self):
        self.env.exit_render()

    def step(self, action):
        # TODO: Modify all environments to take a list of actions
        # instead of doing this.
        if self.dim_actions == 1:
            action = action[0]
        obs, r, done, info = self.env.step(action)
        obs = self._flatten_obs(obs)
        return (obs, r, done, info)

    def reward_terminal(self):
        if hasattr(self.env, 'reward_terminal'):
            return self.env.reward_terminal()
        else:
            return np.zeros(1)

    def _flatten_obs(self, obs):
        if isinstance(obs, tuple):
            # Tuple of per-agent observations; each agent's observation may itself
            # be a tuple of components, so flatten and concatenate them.
            _obs = []
            for agent in obs:
                ag_obs = []
                for obs_kind in agent:
                    ag_obs.append(np.array(obs_kind).flatten())
                _obs.append(np.concatenate(ag_obs))
            obs = np.stack(_obs)

        # Final shape: (1, n_agents, observation_dim).
        obs = obs.reshape(1, -1, self.observation_dim)
        obs = torch.from_numpy(obs).double()
        return obs
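
    # Shape illustration (a sketch with hypothetical numbers, not taken from a
    # specific env): for a tuple of 5 per-agent observations that each flatten
    # to observation_dim = 61, _flatten_obs returns a torch.double tensor of
    # shape (1, 5, 61).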

    def get_stat(self):
        if hasattr(self.env, 'stat'):
            # pop(key, default) removes the key and its value from the dict if present.
            self.env.stat.pop('steps_taken', None)
            # Remaining stats include e.g. 'success' and 'add_rate'.
            return self.env.stat
        else:
            return dict()
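

if __name__ == '__main__':
    # Usage sketch added for illustration, not part of the original module.
    # It assumes a hypothetical env id 'TrafficJunction-v0' registered by a
    # companion envs package, and a Discrete per-agent action space
    # (i.e. dim_actions == 1); adjust both to your setup.
    import gym

    env = GymWrapper(gym.make('TrafficJunction-v0'))
    obs = env.reset(epoch=0)   # torch tensor, shape (1, n_agents, observation_dim)
    n_agents = obs.shape[1]

    for _ in range(10):
        # One random Discrete action per agent, wrapped in a single-element list
        # because step() unwraps action[0] when dim_actions == 1.
        actions = np.random.randint(env.num_actions, size=n_agents)
        obs, reward, done, info = env.step([actions])
        if np.all(done):
            break

    print(env.get_stat())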