-
Notifications
You must be signed in to change notification settings - Fork 5
/
agent.py
695 lines (592 loc) · 28.5 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
"""
Generic autonomous agent classes for the automatic statistician
Created August 2014
@authors: James Robert Lloyd ([email protected])
"""
import copy
from multiprocessing import Process, Array, Value, get_logger, Queue
from Queue import Empty
import subprocess
import constants
import time
import cPickle as pickle
import numpy as np
import os
from collections import defaultdict
import psutil
import signal
import sys
import util
import tempfile
import logging
import sys
import global_data
# Set up logging for the multiprocessing library: let its records propagate to the
# handlers of ancestor loggers and emit everything from DEBUG level upwards
mlogger = get_logger()
mlogger.propagate = True
mlogger.setLevel(logging.DEBUG)
# Set up the module-level logger for the agent module, also at DEBUG level
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
if constants.NUMPY_SAVE:  # pickle numpy arrays via a side file instead of inline bytes
    import copy_reg

    def np_pickler(array):
        """Reduce hook for pickle: write `array` to a fresh .npy file in constants.SAVE_DIR.

        Returns the (callable, args) pair pickle stores in place of the array, i.e.
        `np_unpickler` together with the name of the file that was just written.
        """
        npy_file = tempfile.NamedTemporaryFile(mode='wb', suffix='_pickle.npy',
                                               dir=constants.SAVE_DIR, delete=False)
        with npy_file:
            np.save(npy_file, array)
        return np_unpickler, (npy_file.name,)

    def np_unpickler(np_file):
        """Rebuild hook for pickle: load the array stored at `np_file`, then delete that file."""
        loaded = np.load(np_file)
        os.remove(np_file)
        return loaded

    # Tell pickle to route every numpy ndarray through the two hooks above
    copy_reg.pickle(np.ndarray, np_pickler, np_unpickler)
def start_communication(agent=None, state_filename=None, cgroup=None, password=''):
    """Run `agent` inside the current process; used as the target of child processes.

    :param agent: object providing `communicate()` and `init_after_serialization()`
    :param state_filename: if not None, the agent first restores itself from this file
    :param cgroup: if not None, the current pid is moved into this memory cgroup via
                   `sudo -S cgclassify` before the agent starts
    :param password: sudo password, fed to sudo on its standard input

    A SystemExit raised by the agent ends the process immediately via `os._exit` with the
    same exit code; any other exception is handed to `sys.excepthook`.
    """
    signal.signal(signal.SIGTERM, signal.SIG_DFL)  # if termination signal received, die
    try:
        if cgroup is not None:
            # Pass an argument list and feed the password on stdin rather than building a
            # shell string: a quote in the password can no longer break (or inject into)
            # the command, and the password does not appear on any command line.
            cgclassify = ['sudo', '-S', 'cgclassify', '-g', 'memory:{}'.format(cgroup),
                          '--sticky', str(os.getpid())]
            try:
                proc = subprocess.Popen(cgclassify, stdin=subprocess.PIPE,
                                        universal_newlines=True)
                proc.communicate('{}\n'.format(password))
            except (OSError, IOError) as e:
                # As with the former shell invocation, a failed classification must not
                # stop the agent from running
                logging.getLogger(__name__).warning("Could not run cgclassify: %s", e)
        if state_filename is not None:
            agent.init_after_serialization(state_filename=state_filename)
        agent.communicate()
    except SystemExit as e:
        # os._exit only accepts an int; mirror the interpreter's convention for other codes
        code = e.code
        if code is None:
            code = 0
        elif not isinstance(code, int):
            code = 1
        os._exit(code)
    except:  # noqa: E722 - deliberate catch-all at the process boundary
        sys.excepthook(*sys.exc_info())  # use logging for exceptions
class DummyQueue(object):
    """Queue stand-in: `put` discards its arguments and both getters always raise Empty."""

    def put(self, *args, **kwargs):
        """Accept anything and drop it."""
        return None

    def get(self, *args, **kwargs):
        """Behave like a queue that never holds an item."""
        raise Empty()

    def get_nowait(self, *args, **kwargs):
        """Behave like a queue that never holds an item."""
        raise Empty()
class DummyProcess(object):
    """Inert placeholder exposing the pid/exitcode/is_alive/start/join surface of a process."""

    def __init__(self, pid=65536, exitcode=-15):
        self.pid, self.exitcode = pid, exitcode

    def is_alive(self):
        """A dummy process is never running."""
        return False

    def start(self):
        """Starting a dummy process is always an error."""
        raise AssertionError()

    def join(self):
        """Nothing to wait for."""
        return None
class TerminationEx(Exception):
    """Raised to break out of the communication loop when the agent should terminate."""
class SaveEx(Exception):
    """Raised to break out of the communication loop when the agent should save and exit."""
class Agent(object):
    """An agent that runs a communication loop and manages child agents in subprocesses.

    Each child shares an integer flag (a multiprocessing `Value`) with its parent. As used
    in this class the flag values mean: 0 - nothing requested, 1 - told to save,
    2 - saving now, 3 - saved but not yet terminated.
    """

    def __init__(self,
                 inbox_conn=None, outbox_conn=None,
                 communication_sleep=1, child_timeout=10, name='', exp=None):
        """
        Implements a basic communication and action loop
         - Get incoming messages from parent
         - Perform next action
         - Send outgoing messages to parent
         - Check to see if terminated

        :param inbox_conn: queue of messages from the parent (a DummyQueue if None)
        :param outbox_conn: queue of messages to the parent (a DummyQueue if None)
        :param communication_sleep: seconds slept at the end of each communication cycle
        :param child_timeout: stored on the instance; not used within this class
        :param name: name of this agent, also the prefix of generated child names
        :param exp: stored on the instance; not used within this class
        """
        self.saving_children = []
        self.start_time = time.time()
        if inbox_conn is None:
            inbox_conn = DummyQueue()
        if outbox_conn is None:
            # This used to overwrite inbox_conn, leaving outbox_conn as None
            outbox_conn = DummyQueue()
        self.inbox_conn = inbox_conn
        self.outbox_conn = outbox_conn
        self.inbox = []

        self.child_processes = dict()  # contains all created processes
        self.conns_to_children = dict()  # dead processes removed periodically by child_tidying
        self.conns_from_children = dict()  # ditto
        self.child_inboxes = defaultdict(list)
        self.child_serialization_filenames = dict()  # contains all created processes
        self.child_classes = dict()  # has entry for every created process
        self.child_kwargs = defaultdict(dict)
        self.child_states = dict()  # this is primarily for debugging purposes
        # if you need to know the state of a process try to check directly
        # 'terminated', 'finished', 'saved', 'saved unterminated',
        # 'unknown', 'sleeping', 'stopped', 'running', 'unstarted'
        # 'unknown' - dead with unrecognised exit code
        # 'saved' means has also terminated.
        # 'saved unterminated' - manager received save, but process not terminated
        self.children_told_to_save = dict()  # contains save start times
        self.child_flags = dict()  # for communication about save
        self.last_child_started = None
        self.last_child_start_time = None

        self.flag = Value('i', 0)  # value shared between child and parent to communicate about save

        self.communication_sleep = communication_sleep
        self.child_timeout = child_timeout
        self.name = name
        self.exp = exp
        if hasattr(global_data, 'data'):
            self.data = global_data.data
        else:
            self.data = None

        self._namegen_count = -1  # used to save value of self.namegen - do not use this directly!
        self.save_file = None
        self.fa_completed = False  # whether first action has been completed
        self.state = ''
        self.cgroup = None
        self.password = ''
        self.startable_children = set()

        # Attributes dropped by __getstate__ (queues, shared values and shared data)
        self.attributes_not_to_save = ['inbox_conn', 'outbox_conn',  # 'child_processes',
                                       'conns_to_children', 'conns_from_children', 'data',
                                       'flag', 'child_flags']

    def namegen(self):
        """Return the next unused child name, of the form '<self.name>.<counter>'."""
        self._namegen_count += 1
        return self.name + '.' + str(self._namegen_count)

    def load_file(self, datafile, array_name):
        """Read a text file of floats as a 2D array and store it under `array_name`."""
        data = np.loadtxt(datafile, dtype=float, ndmin=2)
        self.load_array(data, array_name)

    def load_array(self, data, array_name):
        """Copy `data` into a multiprocessing Array and store a numpy view of it in self.data."""
        # Array 'double' is the same as python 'float' (the default for numpy arrays), illogically
        self.data[array_name] = np.frombuffer(Array('d', data.ravel()).get_obj())
        self.data[array_name].shape = data.shape  # will raise an error if array cannot be reshaped without copying
        logger.info("%s: Loaded array %s", self.name, array_name)

    def init_after_serialization(self, state_filename):
        """Load state from file and create children"""
        # NOTE(review): pickle.load executes arbitrary code from the file; the state file
        # must only ever be one written by save() - confirm nothing else can write there.
        with open(state_filename, 'rb') as pickle_file:
            data = pickle.load(pickle_file)  # data is an instance of the same class as self
        self.__dict__.update(data.__dict__)
        # Delete state file
        os.remove(state_filename)
        # Create child processes, pass on shared memory
        self.start_children()
        logger.info("%s: Initialised after serialization", self.name)

    @property
    def children(self):
        """Names of every child process created so far."""
        return list(self.child_processes)

    def create_children(self, names=None, new_names=None, classes=None, start=False, use_cgroup=True):
        """If classes, create children from list of classes.
        Classes can be a list, or a list of (class, kwargs) tuples
        If names, revive named pickles or rerun child if pickle doesn't exist (i.e. save failed).
        If neither, revive all pickled children.

        :param new_names: names for the children built from `classes` (generated if None)
        :param start: start each process straight after creating it
        :param use_cgroup: pass self.cgroup on to the children
        :return: list of created child names
        """
        names_and_classes = dict()
        if classes is None and names is None:
            names = list(self.child_serialization_filenames)

        cgroup = self.cgroup if use_cgroup else None

        if classes is not None:
            # if there's just classes (no kwargs) then convert; an empty list stays empty
            if classes and not isinstance(classes[0], tuple):
                classes = [(x, {}) for x in classes]
            if new_names is None:
                new_names = [self.namegen() for _ in classes]
            for name, a_class in zip(new_names, classes):
                names_and_classes[name] = a_class

        if names is not None:
            for name in names:
                exit_code = self.child_processes[name].exitcode
                if exit_code is None:
                    if self.child_processes[name].is_alive():
                        logger.error("%s: Child %s has not terminated",
                                     self.name, name)
                    else:
                        logger.error("%s: Child %s has been created but not started",
                                     self.name, name)
                    # Either way there is no exit code to act on. Previously the unstarted
                    # case fell through to the numeric comparisons below with None.
                    continue
                if exit_code < 0:
                    logger.warning("%s: Child %s exited with code %d and I won't restart it",
                                   self.name, name, exit_code)
                elif exit_code == 1:
                    logger.info("%s: Child %s has a saved state", self.name, name)
                    names_and_classes[name] = (self.child_classes[name], self.child_kwargs[name])
                elif exit_code == 0:
                    logger.warning("%s: Child %s terminated of its own accord and I won't restart it",
                                   self.name, name)

        # Loop over names and classes, creating children
        for name, (cl, kwargs) in names_and_classes.items():
            child = cl(**kwargs)
            child.name = name
            self.child_classes[name] = cl
            self.child_kwargs[name] = kwargs
            # Create conns
            child.inbox_conn = self.conns_to_children[child.name] = Queue()
            self.conns_from_children[child.name] = child.outbox_conn = Queue()
            # Set save file
            if name in self.child_processes and self.child_processes[name].exitcode == 1:
                # child has been started before
                pickle_file = child.save_file = self.child_serialization_filenames[name]
            else:
                root, ext = os.path.splitext(self.save_file)
                child.save_file = self.child_serialization_filenames[name] = root + '_' + name + ext
                pickle_file = None
            # share communication value
            child.flag = self.child_flags[name] = Value('i', 0)
            # Create process
            p = Process(target=start_communication,
                        kwargs=dict(agent=child, state_filename=pickle_file, cgroup=cgroup,
                                    password=self.password))
            p.name = name
            logger.info("%s: Created child %s", self.name, name)
            self.child_processes[child.name] = p
            self.child_states[name] = 'unstarted'
            del child
            if start:
                p.start()

        self.startable_children.update(names_and_classes)
        return list(names_and_classes)  # returns list of created child names

    @staticmethod
    def _apply_to_process_tree(process, action, verb, name):
        """Apply `action` to every descendant of the psutil `process`, then to `process` itself.

        A descendant that fails is skipped; a failure to list descendants is logged using
        `verb` and the child `name`. Errors from acting on `process` itself propagate.
        """
        gone = (psutil.NoSuchProcess, psutil.AccessDenied, IOError)
        try:
            for descendant in process.children(recursive=True):
                try:
                    action(descendant)
                except gone:
                    pass
        except gone as e:
            # psutil exceptions have no `strerror`, so log the exception itself
            logger.warning("Error %s getting children for %s for child %s", e, verb, name)
        action(process)

    @staticmethod
    def _resume(process):
        """Resume a psutil process."""
        process.resume()

    @staticmethod
    def _pause(process):
        """Send SIGTSTP to a psutil process."""
        process.send_signal(signal.SIGTSTP)

    @staticmethod
    def _resume_and_terminate(process):
        """Resume a psutil process and then ask it to terminate."""
        process.resume()
        process.terminate()

    def start_children(self, names=None):
        """Start or resume or recreate the child processes.

        :param names: children to act on (all the children ever created if None)
        :return: list of names that were started, resumed, un-flagged or recreated
        """
        if names is None:
            names = list(self.child_states)  # all the children ever
        logger.debug("Starting %s", str(names))
        successes = []
        deadkids = []
        started = []
        for name in names:
            assert name in self.child_states  # check if it's a real child
            dead = False
            proc = self.child_processes[name]
            if proc.pid is None:  # has it been started?
                proc.start()
                logger.info("%s: Started %s", self.name, name)
                successes.append(name)
                started.append(name)
            elif proc.is_alive():  # is it alive?
                try:
                    # NB Can't test for aliveness with NoSuchProcess error, as pid might be reused
                    process = psutil.Process(pid=proc.pid)
                    with self.child_flags[name].get_lock():
                        if self.child_flags[name].value == 1:  # FIXME - need informative flag values
                            self.child_flags[name].value = 0
                            logger.info("Child %s had been told to save. I have cancelled this instruction.", name)
                            successes.append(name)
                            continue
                        elif self.child_flags[name].value == 2:
                            logger.warning("Child %s is in the middle of saving and I won't start it", name)
                            continue
                    # We could only send resume to stopped children, but might introduce race conditions
                    self._apply_to_process_tree(process, self._resume, 'resume', name)
                    logger.info("Resumed %s", name)
                    successes.append(name)
                except (psutil.NoSuchProcess, psutil.AccessDenied, IOError):  # not alive
                    dead = True
            else:
                dead = True
            if dead:
                deadkids.append(name)
        if deadkids:
            newkids = self.create_children(names=deadkids, start=True)
            successes.extend(newkids)  # was append, which nested a list inside the result
        if len(started) > 0:
            self.last_child_started = started[-1]
            self.last_child_start_time = time.time()
        return successes

    def pause_children(self, names=None):
        """Send pause signal (SIGTSTP) to alive children and all of their descendants."""
        if names is None:
            names = list(self.conns_from_children)
        logger.debug("Pausing %s", str(names))
        for name in names:
            assert name in self.child_states  # check it's a real child
            proc = self.child_processes[name]
            if proc.is_alive():
                try:
                    process = psutil.Process(pid=proc.pid)
                    self._apply_to_process_tree(process, self._pause, 'pause', name)
                    logger.info("Paused %s", name)
                except (psutil.NoSuchProcess, psutil.AccessDenied, IOError):  # child may have terminated
                    pass

    def resume_children(self, names=None):
        """Resumes paused children. Shouldn't do anything to children that are not paused"""
        if names is None:
            names = list(self.conns_from_children)
        logger.debug("Resuming %s", str(names))
        for name in names:
            assert name in self.child_states  # check this is a child's name
            proc = self.child_processes[name]
            if proc.is_alive():
                try:
                    process = psutil.Process(pid=proc.pid)
                    self._apply_to_process_tree(process, self._resume, 'resume', name)
                    logger.info("Resumed %s", name)
                except (psutil.NoSuchProcess, psutil.AccessDenied, IOError):
                    pass

    def terminate_children(self, names=None, kill_unstarted=False):
        """Terminate named children if they are alive or unstarted

        :param names: children to terminate (every child with an open connection if None)
        :param kill_unstarted: also retire children that were created but never started
        """
        if names is None:
            # names not in this are definitely dead; copy because entries may be popped below
            names = list(self.conns_from_children)
        logger.debug("Terminating %s", str(names))
        for name in names:
            assert name in self.child_states  # check this is a child's name
            if self.child_processes[name].pid is None:
                if kill_unstarted:
                    self.child_processes[name] = DummyProcess(exitcode=-15)
                    self.conns_to_children.pop(name).close()
                    self.conns_from_children.pop(name).close()
                    self.child_states[name] = 'terminated'
                    logger.info("Terminated unstarted child %s", name)
                else:
                    logger.info("Child %s is unstarted - pid %s", name, str(self.child_processes[name].pid))
            else:
                if self.child_processes[name].exitcode is not None:
                    logger.info("Child %s already dead - exitcode %s", name, str(self.child_processes[name].exitcode))
                    continue
                try:
                    process = psutil.Process(pid=self.child_processes[name].pid)
                    # Processes are resumed before being terminated
                    self._apply_to_process_tree(process, self._resume_and_terminate, 'terminate', name)
                    logger.info("Terminated %s", name)
                except (psutil.NoSuchProcess, psutil.AccessDenied, IOError):
                    logger.info("Child %s already dead", name)
                self.child_processes[name].join()  # make sure it has time to respond to terminate signal

    def signal_ignore(self, signum, frame):
        """Signal handler that only logs the signal it received."""
        logger.warning("Ignored signal %d", signum)

    def send_to_children(self, message, names=None):
        """Send a message to all children (or the named ones) that are currently alive"""
        message = copy.deepcopy(message)  # in case message changes before it's put on the pipe
        if names is None:
            names = list(self.conns_to_children)
        logger.debug('Sending to %s', str(names))
        for name in names:
            # check that process is not unstarted or terminated. Shouldn't send to dead processes
            if self.child_processes[name].is_alive():
                self.conns_to_children[name].put(message)
                logger.debug("Sent to %s, message subject '%s'", name, message['subject'])

    def send_to_parent(self, message):
        """Put a deep copy of `message` on the outbox queue."""
        message = copy.deepcopy(message)
        self.outbox_conn.put(message)
        logger.debug("Sent to parent, message subject '%s'", message['subject'])

    def standard_responses(self, message):
        """Some standard actions for messages from parent.

        :raises SaveEx: on 'save and terminate' (after recording the target filename)
        :raises TerminationEx: on 'terminate'
        """
        subject = message['subject']  # check for special subjects
        if subject == 'pause':
            self.pause()
        elif subject == 'save and terminate':
            self.save_file = message['filename']
            raise SaveEx
        elif subject == 'terminate':
            raise TerminationEx

    def get_parent_inbox(self):
        """Transfer items from inbox queue into local inbox"""
        while True:
            try:
                message = self.inbox_conn.get_nowait()
                self.inbox.append(message)
                logger.debug("Message from parent, subject '%s'", message['subject'])
            except Empty:
                break

    def child_tidying(self):
        """Refresh child state and clean up after children that have died.

        - Fill in child state
        - Join dead process to prevent zombies
        - Read last messages from those that have died a good death
        - Close connections with children and remove from child processes
        """
        dead_children = set()
        logger.debug("Tidying children")
        names = list(self.conns_to_children)
        unstartable_children = []
        self.saving_children = []
        for name in names:
            get_last_messages = False
            dead = False
            proc = self.child_processes[name]
            if proc.pid is None:  # hasn't been started
                self.child_states[name] = 'unstarted'
            elif proc.is_alive():
                try:  # we don't really need to record this, but it's useful for debugging
                    p = psutil.Process(pid=proc.pid)  # more accurate check for life
                    self.child_states[name] = p.status()  # running, sleeping, stopped, disk sleep, zombie
                except (psutil.NoSuchProcess, psutil.AccessDenied, IOError):  # maybe it just died
                    dead = True
                flag_value = self.child_flags[name].value
                if flag_value == 3:
                    self.child_states[name] = "saved unterminated"
                    logger.warning("%s still hasn't terminated", name)
                    unstartable_children.append(name)
                    self.saving_children.append(name)
                elif flag_value == 2:
                    self.child_states[name] = "saving now"
                    self.saving_children.append(name)
                    unstartable_children.append(name)
                elif flag_value == 1:
                    self.child_states[name] = "told to save"
            else:
                dead = True
            if dead:
                dead_children.add(name)
                exitcode = proc.exitcode
                if exitcode == -15:
                    self.child_states[name] = 'terminated'
                    unstartable_children.append(name)
                elif exitcode == -9:
                    self.child_states[name] = 'killed'
                    unstartable_children.append(name)
                elif exitcode == 0:
                    self.child_states[name] = 'finished'  # i.e finished of its own accord
                    get_last_messages = True
                    unstartable_children.append(name)
                elif exitcode == 1:
                    self.child_states[name] = 'saved'
                    self.startable_children.add(name)
                    get_last_messages = True
                else:
                    self.child_states[name] = 'unknown %s' % str(exitcode)
                    logger.error("Child %s died with exitcode %s", name, str(exitcode))
                    unstartable_children.append(name)
            if get_last_messages:
                # get final messages
                inbox_conn = self.conns_from_children[name]
                while True:
                    try:
                        message = inbox_conn.get_nowait()
                        self.child_inboxes[name].append(message)
                        logger.debug("Message from %s, subject '%s'", name, message['subject'])
                    except Empty:
                        break
        self.startable_children.difference_update(unstartable_children)  # remove unstartable children
        if len(self.child_states) > 0:
            logger.info("Child states %s", str(self.child_states))
            logger.info("Startable children: %s", str(self.startable_children))
            logger.info("Children currently saving: %s", str(self.saving_children))
        if len(dead_children) > 0:
            logger.debug("Dead children %s", str(dead_children))
        for name in dead_children:
            self.child_processes[name].join()
            self.conns_to_children.pop(name).close()
            self.conns_from_children.pop(name).close()
            self.children_told_to_save.pop(name, None)

    def get_child_inboxes(self, names=None):
        """Get messages from children that are alive and file them in self.child_inboxes."""
        if names is None:
            names = list(self.conns_from_children)
        for name in names:
            assert name in self.child_states  # otherwise this child doesn't exist
            if self.child_processes[name].is_alive():
                inbox_conn = self.conns_from_children[name]
                while True:
                    try:
                        message = inbox_conn.get_nowait()
                        self.child_inboxes[name].append(message)
                        logger.debug("Message from %s, subject '%s'", name, message['subject'])
                    except Empty:
                        break

    def first_action(self):
        """Actions to be performed when first created"""
        pass

    def next_action(self):
        """Inspect messages and state and perform next action checking if process stopped or paused"""
        pass

    def tidy_up(self):
        """Run anything pertinent before termination"""
        self.terminate_children()

    def __getstate__(self):
        """Return the instance dict to pickle.

        Attributes listed in self.attributes_not_to_save are left out and each child process
        is represented by a DummyProcess carrying its pid and exit code.
        """
        odict = self.__dict__.copy()  # copy the dict since we change it
        for key in self.attributes_not_to_save:
            odict.pop(key, None)  # tolerate attributes that are already absent
        # Build the dummies in a fresh dict: the copy above is shallow, so rewriting
        # self.child_processes in place would throw away the live process handles.
        odict['child_processes'] = dict(
            (name, DummyProcess(pid=proc.pid, exitcode=proc.exitcode))
            for name, proc in self.child_processes.items())
        return odict

    def save(self):
        """Save important things to file"""
        self.flag.value = 2  # tell the parent that saving is in progress
        logger.info("Saving")
        # Ignore sigtstp messages from now on:
        signal.signal(signal.SIGTSTP, self.signal_ignore)
        # Saving only happens on the first CPU
        p = psutil.Process()
        current_cpus = p.cpu_affinity()
        if len(current_cpus) > 1:
            p.cpu_affinity([current_cpus[0]])
        # Save children
        self.save_children(save_timeout=300)
        # Save myself
        with open(self.save_file, 'wb') as pickle_file:
            pickle.dump(self, pickle_file, pickle.HIGHEST_PROTOCOL)
        logger.info("Completed saving")
        self.flag.value = 3  # tell the parent that saving has completed

    def save_children(self, names=None, save_timeout=300):
        """Tell children to save and block until each has died or been terminated.

        :param save_timeout: seconds a child is given before it is terminated
        """
        if names is None:
            names = list(self.conns_to_children)
        logger.info("Saving children %s", str(names))
        self.start_saving_children(names)
        while any(name in self.children_told_to_save for name in names):
            time.sleep(1)
            self.finish_saving_children(save_timeout=save_timeout)
        # children are guaranteed to be dead after this
        logger.info("Saved children %s", str(names))

    def start_saving_children(self, names=None):
        """Sends save message to children"""
        if names is None:
            names = list(self.conns_to_children)
        logger.info("Starting saving of %s", str(names))
        self.resume_children(names=names)
        for name in names:
            if self.child_processes[name].is_alive() is False:
                continue
            if self.child_flags[name].value == 0:
                self.child_flags[name].value = 1
                logger.info("Told %s to save", name)
                self.children_told_to_save[name] = time.time()

    def finish_saving_children(self, save_timeout=300):
        """Checks if children are still alive and terminates them if so"""
        # Iterate over a snapshot because entries are removed inside the loop
        for name, child_save_time in list(self.children_told_to_save.items()):
            if self.child_processes[name].is_alive() is False:
                self.children_told_to_save.pop(name)
            elif time.time() - child_save_time > save_timeout:
                logger.warning("Terminating %s - save took too long", name)
                self.terminate_children(names=[name])
                self.children_told_to_save.pop(name)

    def pause(self):
        """Pause the children, then block this process until a signal arrives."""
        self.pause_children()
        logger.info("Pausing myself")
        signal.pause()

    def perform_communication_cycle(self):
        """Run one cycle: tidy children, collect messages, act, sleep.

        :raises SaveEx: if the shared flag shows this agent has been told to save
        """
        time0 = time.time()
        self.child_tidying()
        self.get_child_inboxes()
        self.get_parent_inbox()
        self.next_action()
        time.sleep(self.communication_sleep)
        logger.debug("Communication cycle took %.1f seconds", time.time() - time0)
        if self.flag.value == 1:
            raise SaveEx

    def communicate(self):
        """Receive incoming messages, perform actions as appropriate and send outgoing messages"""
        try:
            if self.fa_completed is False:
                self.first_action()
                self.fa_completed = True
                logger.debug("Completed first action")
            while True:
                self.perform_communication_cycle()
        except SaveEx:
            self.save()
            sys.exit(1)  # exit with exit code 1
        except TerminationEx:
            logger.debug("Caught TerminationEx")
            self.tidy_up()