diff --git a/pyulog/core.py b/pyulog/core.py index 6069325..0ad9896 100644 --- a/pyulog/core.py +++ b/pyulog/core.py @@ -482,6 +482,19 @@ def _make_changed_param_items(self): return changed_param_items + def __eq__(self, other): + """ + If the other object has all the same data as we have, we want to + consider them equal, even if the other object has extra fields, because + the user cares about the ULog contents. + """ + if not isinstance(other, ULog): + return NotImplemented + return all( + self_value == getattr(other, field) + for field, self_value in self.__dict__.items() + ) + class Data(object): """ contains the final topic data for a single topic and instance """ diff --git a/pyulog/db.py b/pyulog/db.py index 0c40b34..dc0434b 100644 --- a/pyulog/db.py +++ b/pyulog/db.py @@ -177,6 +177,16 @@ def __init__(self, db_handle, primary_key=None, log_file=None, lazy=True, **kwar if primary_key is not None: self.load(lazy=lazy) + def __eq__(self, other): + """ + If the other object is a normal ULog, then we just want to compare ULog + data, not DatabaseULog specific fields, because we want to compare + theULog file contents. + """ + if type(other) is ULog: # pylint: disable=unidiomatic-typecheck + return other.__eq__(self) + return super().__eq__(other) + @property def primary_key(self): '''The primary key of the ulog, pointing to the correct "ULog" row in the database.''' diff --git a/test/test_db.py b/test/test_db.py index 53c2429..9da0599 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -54,8 +54,7 @@ def test_parsing(self, test_case): dbulog_saved.save() primary_key = dbulog_saved.primary_key dbulog_loaded = DatabaseULog(self.db_handle, primary_key=primary_key, lazy=False) - for ulog_key, ulog_value in ulog.__dict__.items(): - self.assertEqual(ulog_value, getattr(dbulog_loaded, ulog_key)) + self.assertEqual(ulog, dbulog_loaded) def test_lazy(self): ''' diff --git a/test/test_ulog.py b/test/test_ulog.py index 85f95c3..26d5ecf 100644 --- a/test/test_ulog.py +++ b/test/test_ulog.py @@ -20,6 +20,22 @@ class TestULog(unittest.TestCase): Tests the ULog class ''' + @data('sample') + def test_comparison(self, base_name): + ''' + Test that the custom comparison method works as expected. + ''' + ulog_file_name = os.path.join(TEST_PATH, base_name + '.ulg') + ulog1 = pyulog.ULog(ulog_file_name) + ulog2 = pyulog.ULog(ulog_file_name) + assert ulog1 == ulog2 + assert ulog1 is not ulog2 + + # make them different in arbitrary field + ulog1.data_list[0].data['timestamp'][0] += 1 + assert ulog1 != ulog2 + + @data('sample', 'sample_appended', 'sample_appended_multiple', @@ -36,21 +52,11 @@ def test_write_ulog(self, base_name): original.write_ulog(written_ulog_file_name) copied = pyulog.ULog(written_ulog_file_name) - for original_key, original_value in original.__dict__.items(): - copied_value = getattr(copied, original_key) - if original_key == '_sync_seq_cnt': - # Sync messages are counted on parse, but otherwise dropped, so - # we don't rewrite them - assert copied_value == 0 - elif original_key == '_appended_offsets': - # Abruptly ended messages just before offsets are dropped, so - # we don't rewrite appended offsets - assert copied_value == [] - elif original_key == '_incompat_flags': - # Same reasoning on incompat_flags[0] as for '_appended_offsets' - assert copied_value[0] == original_value[0] & 0xFE # pylint: disable=unsubscriptable-object - assert copied_value[1:] == original_value[1:] # pylint: disable=unsubscriptable-object - else: - assert copied_value == original_value + # Some fields are not copied but dropped, so we cheat by modifying the original + original._sync_seq_cnt = 0 # pylint: disable=protected-access + original._appended_offsets = [] # pylint: disable=protected-access + original._incompat_flags[0] &= 0xFE # pylint: disable=protected-access + + assert copied == original # vim: set et fenc=utf-8 ft=python ff=unix sts=4 sw=4 ts=4