Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for hanging on quit (fixes issues #4 and #21) #23

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion ga3c/ProcessAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys
if sys.version_info >= (3,0):
from queue import Empty
else:
from Queue import Empty

from datetime import datetime
from multiprocessing import Process, Queue, Value

Expand Down Expand Up @@ -72,7 +78,12 @@ def predict(self, state):
# put the state in the prediction q
self.prediction_q.put((self.id, state))
# wait for the prediction to come back
p, v = self.wait_q.get()
try:
p, v = self.wait_q.get(True, 10)
except Empty as e:
# Couldn't receive prediction in long time
# Predictors removed?
return None, None
return p, v

def select_action(self, prediction):
Expand All @@ -97,6 +108,10 @@ def run_episode(self):
continue

prediction, value = self.predict(self.env.current_state)
if prediction is None and value is None:
# Fatal error, couldn't get prediction from predictors
break

action = self.select_action(prediction)
reward, done = self.env.step(action)
reward_sum += reward
Expand Down
16 changes: 12 additions & 4 deletions ga3c/ProcessStats.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

import sys
if sys.version_info >= (3,0):
from queue import Queue as queueQueue
from queue import Empty, Queue as queueQueue
else:
from Queue import Queue as queueQueue
from Queue import Empty, Queue as queueQueue

from datetime import datetime
from multiprocessing import Process, Queue, Value
Expand All @@ -51,6 +51,8 @@ def __init__(self):
self.agent_count = Value('i', 0)
self.total_frame_count = 0

self.exit_flag = Value('i', 0)

def FPS(self):
    """Return the average frames-per-second over the whole run so far.

    This is cumulative from training start, not an instantaneous rate.
    """
    elapsed_seconds = time.time() - self.start_time
    return np.ceil(self.total_frame_count / elapsed_seconds)
Expand All @@ -67,8 +69,14 @@ def run(self):

self.start_time = time.time()
first_time = datetime.now()
while True:
episode_time, reward, length = self.episode_log_q.get()
episode_time = None
reward = None
length = None
while self.exit_flag.value == 0:
try:
episode_time, reward, length = self.episode_log_q.get(True, 0.001)
except Empty as e:
continue
results_logger.write('%s, %d, %d\n' % (episode_time.strftime("%Y-%m-%d %H:%M:%S"), reward, length))
results_logger.flush()

Expand Down
19 changes: 18 additions & 1 deletion ga3c/Server.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,16 @@ def remove_trainer(self):
self.trainers[-1].exit_flag = True
self.trainers[-1].join()
self.trainers.pop()

def remove_stats(self):
    """Shut down the statistics process: raise its exit flag, then block until it terminates."""
    stats_proc = self.stats
    stats_proc.exit_flag.value = True
    stats_proc.join()

def remove_adjustment(self):
    """Shut down the dynamic-adjustment thread: set its exit flag, then wait for it to finish."""
    adjuster = self.dynamic_adjustment
    adjuster.exit_flag = True
    adjuster.join()


def train_model(self, x_, r_, a_, trainer_id):
self.model.train(x_, r_, a_, trainer_id)
self.training_step += 1
Expand Down Expand Up @@ -123,10 +132,18 @@ def main(self):

time.sleep(0.01)

self.dynamic_adjustment.exit_flag = True
# Remove dynamic adjustment first to avoid it creating new agents
# and whatnot
self.remove_adjustment()
# Set exit_flags already so each agent will stop when episode ends
# Without this removing all agents will take some time
for agent in self.agents:
agent.exit_flag.value = True
while self.agents:
self.remove_agent()
while self.predictors:
self.remove_predictor()
while self.trainers:
self.remove_trainer()
self.remove_stats()

24 changes: 21 additions & 3 deletions ga3c/ThreadPredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys
if sys.version_info >= (3,0):
from queue import Empty
else:
from Queue import Empty

from threading import Thread

import numpy as np
Expand All @@ -47,12 +53,24 @@ def run(self):
dtype=np.float32)

while not self.exit_flag:
ids[0], states[0] = self.server.prediction_q.get()
try:
ids[0], states[0] = self.server.prediction_q.get(True, 0.001)
except Empty as e:
continue

size = 1
while size < Config.PREDICTION_BATCH_SIZE and not self.server.prediction_q.empty():
ids[size], states[size] = self.server.prediction_q.get()
size += 1
try:
ids[size], states[size] = self.server.prediction_q.get(True, 0.001)
size += 1
except Empty as e:
if self.exit_flag:
break

# Make sure we are not supposed to exit
# exit_flag could change at any point during above lines
if self.exit_flag:
break

batch = states[:size]
p, v = self.server.model.predict_p_and_v(batch)
Expand Down
15 changes: 14 additions & 1 deletion ga3c/ThreadTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys
if sys.version_info >= (3,0):
from queue import Empty
else:
from Queue import Empty

from threading import Thread
import numpy as np

Expand All @@ -43,7 +49,14 @@ def run(self):
while not self.exit_flag:
batch_size = 0
while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
x_, r_, a_ = self.server.training_q.get()
try:
x_, r_, a_ = self.server.training_q.get(True, 0.001)
except Empty as e:
# Check if trainer should quit
if self.exit_flag:
break
continue

if batch_size == 0:
x__ = x_; r__ = r_; a__ = a_
else:
Expand Down