Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

WIP: Refactor log #1130

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion blocks/log/log.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""The event-based main loop of Blocks."""
from abc import ABCMeta
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from numbers import Integral
from uuid import uuid4
Expand Down Expand Up @@ -101,6 +101,40 @@ def previous_row(self):
def last_epoch_row(self):
    """Return the log row recorded when the most recent epoch ended."""
    last_epoch_end = self.status['_epoch_ends'][-1]
    return self[last_epoch_end]

@abstractmethod
def writer(self):
    """Create a log opened for both reading and writing.

    Returns
    -------
    manager : context manager
        Context manager that yields the log in a read+write regime.

    """

@abstractmethod
def reader(self):
    """Create a log opened for reading only.

    Returns
    -------
    manager : context manager
        Context manager that yields the log in a read-only regime.

    """

def new_iteration(self):
    """Signal that a new iteration begins; a no-op by default."""

def __enter__(self):
# Do nothing by default
return self

def __exit__(self, exc_type, exc_val, exc_tb):
# Do nothing by default
pass


class TrainingLog(defaultdict, TrainingLogBase):
"""Training log using a `defaultdict` as backend.
Expand Down Expand Up @@ -133,3 +167,9 @@ def __getitem__(self, time):
def __setitem__(self, time, value):
    """Store *value* as the log row for *time* after validating *time*."""
    # _check_time raises for invalid time indices before anything is written.
    self._check_time(time)
    return super(TrainingLog, self).__setitem__(time, value)

def writer(self):
    """Return the log itself: the dict backend is always writable."""
    return self

def reader(self):
    """Return the log itself: the dict backend needs no read-only wrapper."""
    return self
16 changes: 14 additions & 2 deletions blocks/log/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def __init__(self, database=None, **kwargs):
if database is None:
database = config.sqlite_database
self.database = database
self.conn = sqlite3.connect(database)
self.status = SQLiteStatus(self)
self.conn = sqlite3.connect(self.database)
sqlite3.register_adapter(numpy.ndarray, adapt_ndarray)
with self.conn:
self.conn.execute("""CREATE TABLE IF NOT EXISTS entries (
Expand All @@ -133,7 +134,6 @@ def __init__(self, database=None, **kwargs):
value,
PRIMARY KEY(uuid, "key")
);""")
self.status = SQLiteStatus(self)
super(SQLiteLog, self).__init__(**kwargs)

@property
Expand Down Expand Up @@ -176,6 +176,18 @@ def __len__(self):
"WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
).fetchone()[0]

def writer(self):
    """Return the log itself, acting as its own read+write context."""
    return self

def reader(self):
    """Return the log itself, acting as its own read-only context."""
    return self

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pass


class SQLiteStatus(MutableMapping):
def __init__(self, log):
Expand Down
80 changes: 43 additions & 37 deletions blocks/main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ def run(self):
set(self.algorithm.parameters)):
logger.warning("different parameters for model and algorithm")

with change_recursion_limit(config.recursion_limit):
with change_recursion_limit(config.recursion_limit), \
self.log.writer() as log:
self.original_sigint_handler = signal.signal(
signal.SIGINT, self._handle_epoch_interrupt)
signal.SIGINT, self._handle_epoch_interrupt(log))
self.original_sigterm_handler = signal.signal(
signal.SIGTERM, self._handle_batch_interrupt)
signal.SIGTERM, self._handle_batch_interrupt(log))
try:
logger.info("Entered the main loop")
if not self.status['training_started']:
Expand All @@ -174,30 +175,30 @@ def run(self):
# We can not write "else:" here because extensions
# called "before_training" could have changed the status
# of the main loop.
if self.log.status['iterations_done'] > 0:
self.log.resume()
if log.status['iterations_done'] > 0:
log.resume()
self._run_extensions('on_resumption')
self.status['epoch_interrupt_received'] = False
self.status['batch_interrupt_received'] = False
with Timer('training', self.profile):
while self._run_epoch():
while self._run_epoch(log):
pass
except TrainingFinish:
self.log.current_row['training_finished'] = True
log.current_row['training_finished'] = True
except Exception as e:
self._restore_signal_handlers()
self.log.current_row['got_exception'] = traceback.format_exc()
logger.error("Error occured during training." + error_message)
log.current_row['got_exception'] = traceback.format_exc()
logger.error("Error occurred during training." + error_message)
try:
self._run_extensions('on_error')
except Exception:
logger.error(traceback.format_exc())
logger.error("Error occured when running extensions." +
logger.error("Error occurred when running extensions." +
error_in_error_handling_message)
reraise_as(e)
finally:
self._restore_signal_handlers()
if self.log.current_row.get('training_finished', False):
if log.current_row.get('training_finished', False):
self._run_extensions('after_training')
if config.profile:
self.profile.report()
Expand All @@ -218,7 +219,7 @@ def find_extension(self, name):
return unpack([extension for extension in self.extensions
if extension.name == name], singleton=True)

def _run_epoch(self):
def _run_epoch(self, log):
if not self.status.get('epoch_started', False):
try:
self.log.status['received_first_batch'] = False
Expand All @@ -229,17 +230,18 @@ def _run_epoch(self):
self.status['epoch_started'] = True
self._run_extensions('before_epoch')
with Timer('epoch', self.profile):
while self._run_iteration():
while self._run_iteration(log):
pass
self.status['epoch_started'] = False
self.status['epochs_done'] += 1
# Log might not allow mutating objects, so use += instead of append
self.status['_epoch_ends'] += [self.status['iterations_done']]
self._run_extensions('after_epoch')
self._check_finish_training('epoch')
self._check_finish_training('epoch', log)
return True

def _run_iteration(self):
def _run_iteration(self, log):
log.new_iteration()
try:
with Timer('read_data', self.profile):
batch = next(self.epoch_iterator)
Expand All @@ -253,7 +255,7 @@ def _run_iteration(self):
self.algorithm.process_batch(batch)
self.status['iterations_done'] += 1
self._run_extensions('after_batch', batch)
self._check_finish_training('batch')
self._check_finish_training('batch', log)
return True

def _run_extensions(self, method_name, *args):
Expand All @@ -262,7 +264,7 @@ def _run_extensions(self, method_name, *args):
with Timer(type(extension).__name__, self.profile):
extension.dispatch(CallbackName(method_name), *args)

def _check_finish_training(self, level):
def _check_finish_training(self, level, log):
"""Checks whether the current training should be terminated.

Parameters
Expand All @@ -275,32 +277,36 @@ def _check_finish_training(self, level):
# In case when keyboard interrupt is handled right at the end of
# the iteration the corresponding log record can be found only in
# the previous row.
if (self.log.current_row.get('training_finish_requested', False) or
if (log.current_row.get('training_finish_requested', False) or
self.status.get('batch_interrupt_received', False)):
raise TrainingFinish
if (level == 'epoch' and
self.status.get('epoch_interrupt_received', False)):
raise TrainingFinish

def _handle_epoch_interrupt(self, signal_number, frame):
    """SIGINT handler: finish training after the current epoch.

    A second CTRL + C is routed to the batch-level handler instead.
    """
    logger.warning('Received epoch interrupt signal.' +
                   epoch_interrupt_message)
    # Escalate: the next SIGINT should stop after the current batch.
    signal.signal(signal.SIGINT, self._handle_batch_interrupt)
    self.log.current_row['epoch_interrupt_received'] = True
    # Mirror the event in the status, which — unlike the log record —
    # stays easy to reach on later iterations.
    self.status['epoch_interrupt_received'] = True

def _handle_batch_interrupt(self, signal_number, frame):
    """SIGINT/SIGTERM handler: finish training after the current batch."""
    # Reinstall the original handlers so a further signal is fatal.
    self._restore_signal_handlers()
    logger.warning('Received batch interrupt signal.' +
                   batch_interrupt_message)
    self.log.current_row['batch_interrupt_received'] = True
    # Mirror the event in the status, which — unlike the log record —
    # stays easy to reach on later iterations.
    self.status['batch_interrupt_received'] = True
def _handle_epoch_interrupt(self, log):
    """Build a SIGINT handler that finishes training after this epoch.

    The returned handler records the interrupt in *log* (the log object
    opened by ``run``) and in the status.
    """
    def handler(signal_number, frame):
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        # A second CTRL + C escalates to the batch-level handler.
        signal.signal(signal.SIGINT, self._handle_batch_interrupt(log))
        log.current_row['epoch_interrupt_received'] = True
        # The status entry — unlike the log record — stays easy to
        # reach on later iterations.
        self.status['epoch_interrupt_received'] = True
    return handler

def _handle_batch_interrupt(self, log):
    """Build a handler that finishes training after the current batch.

    Used for a second CTRL + C, or for SIGTERM (e.g. from a cluster
    scheduler); the interrupt is recorded in *log* and in the status.
    """
    def handler(signal_number, frame):
        # Reinstall the original handlers so a further signal is fatal.
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        log.current_row['batch_interrupt_received'] = True
        # The status entry — unlike the log record — stays easy to
        # reach on later iterations.
        self.status['batch_interrupt_received'] = True
    return handler

def _restore_signal_handlers(self):
signal.signal(signal.SIGINT, self.original_sigint_handler)
Expand Down
35 changes: 26 additions & 9 deletions tests/test_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from six.moves import cPickle

from blocks.main_loop import MainLoop
from blocks.log import TrainingLog
from blocks.extensions import TrainingExtension, FinishAfter, Printing
from blocks.utils import unpack
from blocks.config import config
Expand All @@ -33,11 +34,12 @@ def test_main_loop():
main_loop.run()
assert_raises(AttributeError, getattr, main_loop, 'model')

assert main_loop.log.status['iterations_done'] == 20
assert main_loop.log.status['_epoch_ends'] == [10, 20]
assert len(main_loop.log) == 20
for i in range(20):
assert main_loop.log[i + 1]['batch'] == {'data': i % 10}
with main_loop.log.reader() as log:
assert log.status['iterations_done'] == 20
assert log.status['_epoch_ends'] == [10, 20]
assert len(log) == 20
for i in range(20):
assert log[i + 1]['batch'] == {'data': i % 10}

config.profile = old_config_profile_value

Expand All @@ -50,7 +52,8 @@ def do_test(with_serialization):
extensions=[WriteBatchExtension(),
FinishAfter(after_n_batches=14)])
main_loop.run()
assert main_loop.log.status['iterations_done'] == 14
with main_loop.log.reader() as log:
assert log.status['iterations_done'] == 14

if with_serialization:
main_loop = cPickle.loads(cPickle.dumps(main_loop))
Expand All @@ -62,15 +65,29 @@ def do_test(with_serialization):
["after_batch"],
predicate=lambda log: log.status['iterations_done'] == 27)
main_loop.run()
assert main_loop.log.status['iterations_done'] == 27
assert main_loop.log.status['epochs_done'] == 2
with main_loop.log.reader() as log:
assert log.status['iterations_done'] == 27
assert log.status['epochs_done'] == 2
for i in range(27):
assert main_loop.log[i + 1]['batch'] == {"data": i % 10}
assert log[i + 1]['batch'] == {"data": i % 10}

do_test(False)
do_test(True)


def test_new_iteration_call():
    """Check that the main loop calls ``new_iteration`` on every iteration."""
    stream = IterableDataset(range(10)).get_example_stream()
    log = TrainingLog()
    log.new_iteration = MagicMock()

    loop = MainLoop(
        MockAlgorithm(), stream, log=log,
        extensions=[FinishAfter(after_n_epochs=1)])
    loop.run()

    # Eleven argument-less calls — presumably ten batches plus the final
    # iteration that terminates the epoch; matches the original expectation.
    expected_calls = [()] * 11
    assert log.new_iteration.call_args_list == expected_calls


def test_training_interrupt():
def process_batch(batch):
time.sleep(0.1)
Expand Down