-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
49 lines (39 loc) · 1.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import collections
import collections.abc
import copy
import http.client
import urllib
import urllib.error
import urllib.request
# Both of these are copied from the adversarial policies codebase
def update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place.

    For each key in ``u``: if its value is a mapping, it is merged into
    ``d[key]`` recursively; otherwise it overwrites ``d[key]``.

    :param d: (dict) destination; mutated in place.
    :param u: (Mapping) source of new values; not modified.
    :return: ``d`` itself (the same object), for convenience.
    """
    for k, v in u.items():
        # `collections.Mapping` was removed in Python 3.10; the ABC lives in
        # `collections.abc`.
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            # A missing or non-mapping destination value cannot be merged into;
            # start from an empty dict so the nested mapping replaces it.
            if not isinstance(existing, collections.abc.Mapping):
                existing = {}
            d[k] = update(existing, v)
        else:
            d[k] = v
    return d
def sacred_copy(o):
    """Deep-copy ``o``, rebuilding nested dicts and lists as plain types.

    Any ``dict`` instance (including subclasses) becomes a plain ``dict``, and
    any ``list`` instance becomes a plain ``list``. Their elements are copied by
    recursing into this function. Every other object is copied with
    ``copy.deepcopy``. Subclasses are deliberately dropped. That is useful for
    Sacred read-only dicts, but it also turns e.g. an ``OrderedDict`` into a
    plain ``dict``.

    :param o: (object) value to copy.
    :return: a deep copy of ``o`` with dict/list subclasses flattened.
    """
    if isinstance(o, dict):
        return dict((key, sacred_copy(val)) for key, val in o.items())
    if isinstance(o, list):
        return list(map(sacred_copy, o))
    return copy.deepcopy(o)
def detect_ec2():
    """Auto-detect whether we are running on an AWS EC2 instance.

    Requests the instance identity document from the metadata endpoint at
    169.254.169.254, with a 3-second timeout.

    :return: (bool) True if the endpoint returned a body containing
        ``availabilityZone``; False if the request failed at the network or
        HTTP layer.
    :raises ValueError: if the endpoint answered, but the body does not contain
        ``availabilityZone``.
    """
    EC2_ID_URL = 'http://169.254.169.254/latest/dynamic/instance-identity/document'
    # Requires `urllib.request`, `urllib.error` and `http.client` to be imported;
    # a bare `import urllib` does not load these submodules.
    try:
        with urllib.request.urlopen(EC2_ID_URL, timeout=3) as f:
            response = f.read().decode()
    except (OSError, http.client.HTTPException):
        # URLError (unreachable host, connect timeout, HTTP error status) is an
        # OSError subclass. A timeout or reset while reading the body surfaces
        # as a bare OSError (e.g. socket.timeout), and a malformed reply as
        # http.client.HTTPException. Treat all of these as "not on EC2".
        return False
    if 'availabilityZone' in response:
        return True
    raise ValueError(f"Received unexpected response from '{EC2_ID_URL}'")