diff --git a/node/transport_layer_test.py b/test/test_transport.py similarity index 55% rename from node/transport_layer_test.py rename to test/test_transport.py index d2a75b939..66c085d39 100644 --- a/node/transport_layer_test.py +++ b/test/test_transport.py @@ -1,89 +1,82 @@ import json import unittest -import mock -from transport import TransportLayer -import protocol - -# Test the callback features of the TransportLayer class -class TestTransportLayerCallbacks(unittest.TestCase): - one_called = False - two_called = False - three_called = False +import mock - def _callback_one(self, arg): - self.assertFalse(self.one_called) - self.one_called = True +from node import protocol, transport - def _callback_two(self, arg): - self.assertFalse(self.two_called) - self.two_called = True - def _callback_three(self, arg): - self.assertFalse(self.three_called) - self.three_called = True +class TestTransportLayerCallbacks(unittest.TestCase): + """Test the callback features of the TransportLayer class.""" def setUp(self): - self.tl = TransportLayer(1, 'localhost', None, 1) - self.tl.add_callback('section_one', self._callback_one) - self.tl.add_callback('section_one', self._callback_two) - self.tl.add_callback('all', self._callback_three) + self.callback1 = mock.Mock() + self.callback2 = mock.Mock() + self.callback3 = mock.Mock() + + self.tl = transport.TransportLayer(1, 'localhost', None, 1) + self.tl.add_callback('section_one', self.callback1) + self.tl.add_callback('section_one', self.callback2) + self.tl.add_callback('all', self.callback3) def _assert_called(self, one, two, three): - self.assertEqual(self.one_called, one) - self.assertEqual(self.two_called, two) - self.assertEqual(self.three_called, three) + self.assertEqual(self.callback1.call_count, one) + self.assertEqual(self.callback2.call_count, two) + self.assertEqual(self.callback3.call_count, three) def test_fixture(self): - self._assert_called(False, False, False) + self._assert_called(0, 0, 0) def test_callbacks(self): self.tl.trigger_callbacks('section_one', None) - self._assert_called(True, True, True) + self._assert_called(1, 1, 1) def test_all_callback(self): self.tl.trigger_callbacks('section_with_no_register', None) - self._assert_called(False, False, True) + self._assert_called(0, 0, 1) def test_explicit_all_section(self): self.tl.trigger_callbacks('all', None) - self._assert_called(False, False, True) + self._assert_called(0, 0, 1) class TestTransportLayerMessageHandling(unittest.TestCase): + def setUp(self): - self.tl = TransportLayer(1, 'localhost', None, 1) + self.tl = transport.TransportLayer(1, 'localhost', None, 1) - # The ok message should not trigger any callbacks def test_on_message_ok(self): + """OK message should trigger no callbacks.""" self.tl.trigger_callbacks = mock.MagicMock( side_effect=AssertionError() ) self.tl._on_message(protocol.ok()) - # Any non-ok message should cause trigger_callbacks to be called with - # the type of message and the message object (dict) def test_on_message_not_ok(self): + """ + Any non-OK message should cause trigger_callbacks to be called with + the type of message and the message object (dict). + """ data = protocol.shout({}) self.tl.trigger_callbacks = mock.MagicMock() self.tl._on_message(data) self.tl.trigger_callbacks.assert_called_with(data['type'], data) - # Invalid serialized messages should be dropped def test_on_raw_message_invalid(self): + """Invalid serialized messages should be dropped.""" self.tl._init_peer = mock.MagicMock() self.tl._on_message = mock.MagicMock() self.tl._on_raw_message('invalid serialization') self.assertFalse(self.tl._init_peer.called) self.assertFalse(self.tl._on_message.called) - # A hello message with no uri should not add a peer def test_on_raw_message_hello_no_uri(self): + """A hello message with no uri should not add a peer.""" self.tl._on_raw_message([json.dumps(protocol.hello_request({}))]) self.assertEqual(0, len(self.tl.peers)) - # A hello message with a uri should result in a new peer def test_on_raw_message_hello_with_uri(self): + """A hello message with a uri should result in a new peer.""" request = protocol.hello_request({ 'uri': 'tcp://localhost:12345' }) @@ -92,11 +85,15 @@ def test_on_raw_message_hello_with_uri(self): class TestTransportLayerProfile(unittest.TestCase): + def test_get_profile(self): - tl = TransportLayer(1, '1.1.1.1', 12345, 1) + tl = transport.TransportLayer(1, '1.1.1.1', 12345, 1) self.assertEqual( tl.get_profile(), protocol.hello_request({ 'uri': 'tcp://1.1.1.1:12345' }) ) + +if __name__ == "__main__": + unittest.main()