diff --git a/ga3c/ProcessAgent.py b/ga3c/ProcessAgent.py index a9b2701..14a6ed5 100644 --- a/ga3c/ProcessAgent.py +++ b/ga3c/ProcessAgent.py @@ -24,6 +24,12 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import sys +if sys.version_info >= (3,0): + from queue import Empty +else: + from Queue import Empty + from datetime import datetime from multiprocessing import Process, Queue, Value @@ -72,7 +78,12 @@ def predict(self, state): # put the state in the prediction q self.prediction_q.put((self.id, state)) # wait for the prediction to come back - p, v = self.wait_q.get() + try: + p, v = self.wait_q.get(True, 10) + except Empty as e: + # Couldn't receive prediction in long time + # Predictors removed? + return None, None return p, v def select_action(self, prediction): @@ -97,6 +108,10 @@ def run_episode(self): continue prediction, value = self.predict(self.env.current_state) + if prediction is None and value is None: + # Fatal error, couldn't get prediction from predictors + break + action = self.select_action(prediction) reward, done = self.env.step(action) reward_sum += reward diff --git a/ga3c/ProcessStats.py b/ga3c/ProcessStats.py index 937e0fb..92bf2b4 100644 --- a/ga3c/ProcessStats.py +++ b/ga3c/ProcessStats.py @@ -26,9 +26,9 @@ import sys if sys.version_info >= (3,0): - from queue import Queue as queueQueue + from queue import Empty, Queue as queueQueue else: - from Queue import Queue as queueQueue + from Queue import Empty, Queue as queueQueue from datetime import datetime from multiprocessing import Process, Queue, Value @@ -51,6 +51,8 @@ def __init__(self): self.agent_count = Value('i', 0) self.total_frame_count = 0 + self.exit_flag = Value('i', 0) + def FPS(self): # average FPS from the beginning of the training (not current FPS) return np.ceil(self.total_frame_count / (time.time() - self.start_time)) @@ -67,8 +69,14 @@ def run(self): self.start_time = time.time() first_time = datetime.now() - while True: - episode_time, reward, length = self.episode_log_q.get() + episode_time = None + reward = None + length = None + while self.exit_flag.value == 0: + try: + episode_time, reward, length = self.episode_log_q.get(True, 0.001) + except Empty as e: + continue results_logger.write('%s, %d, %d\n' % (episode_time.strftime("%Y-%m-%d %H:%M:%S"), reward, length)) results_logger.flush() diff --git a/ga3c/Server.py b/ga3c/Server.py index 28d8a46..50c7272 100644 --- a/ga3c/Server.py +++ b/ga3c/Server.py @@ -84,7 +84,16 @@ def remove_trainer(self): self.trainers[-1].exit_flag = True self.trainers[-1].join() self.trainers.pop() + + def remove_stats(self): + self.stats.exit_flag.value = True + self.stats.join() + + def remove_adjustment(self): + self.dynamic_adjustment.exit_flag = True + self.dynamic_adjustment.join() + def train_model(self, x_, r_, a_, trainer_id): self.model.train(x_, r_, a_, trainer_id) self.training_step += 1 @@ -123,10 +132,18 @@ def main(self): time.sleep(0.01) - self.dynamic_adjustment.exit_flag = True + # Remove dynamic adjustment first to avoid it creating new agents + # and whatnot + self.remove_adjustment() + # Set exit_flags already so each agent will stop when episode ends + # Without this removing all agents will take some time + for agent in self.agents: + agent.exit_flag.value = True while self.agents: self.remove_agent() while self.predictors: self.remove_predictor() while self.trainers: self.remove_trainer() + self.remove_stats() + diff --git a/ga3c/ThreadPredictor.py b/ga3c/ThreadPredictor.py index 38c9ed1..3bbb14c 100644 --- a/ga3c/ThreadPredictor.py +++ b/ga3c/ThreadPredictor.py @@ -24,6 +24,12 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import sys +if sys.version_info >= (3,0): + from queue import Empty +else: + from Queue import Empty + from threading import Thread import numpy as np @@ -47,12 +53,24 @@ def run(self): dtype=np.float32) while not self.exit_flag: - ids[0], states[0] = self.server.prediction_q.get() + try: + ids[0], states[0] = self.server.prediction_q.get(True, 0.001) + except Empty as e: + continue size = 1 while size < Config.PREDICTION_BATCH_SIZE and not self.server.prediction_q.empty(): - ids[size], states[size] = self.server.prediction_q.get() - size += 1 + try: + ids[size], states[size] = self.server.prediction_q.get(True, 0.001) + size += 1 + except Empty as e: + if self.exit_flag: + break + + # Make sure we are not supposed to exit + # exit_flag could change at any point during above lines + if self.exit_flag: + break batch = states[:size] p, v = self.server.model.predict_p_and_v(batch) diff --git a/ga3c/ThreadTrainer.py b/ga3c/ThreadTrainer.py index 4e364ad..26d2694 100644 --- a/ga3c/ThreadTrainer.py +++ b/ga3c/ThreadTrainer.py @@ -24,6 +24,12 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import sys +if sys.version_info >= (3,0): + from queue import Empty +else: + from Queue import Empty + from threading import Thread import numpy as np @@ -43,7 +49,14 @@ def run(self): while not self.exit_flag: batch_size = 0 while batch_size <= Config.TRAINING_MIN_BATCH_SIZE: - x_, r_, a_ = self.server.training_q.get() + try: + x_, r_, a_ = self.server.training_q.get(True, 0.001) + except Empty as e: + # Check if trainer should quit + if self.exit_flag: + break + continue + if batch_size == 0: x__ = x_; r__ = r_; a__ = a_ else: