Skip to content

Commit

Permalink
Added support for synchronous deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
yannbouteiller committed Feb 4, 2023
1 parent 338414c commit 743b8f4
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
python -m tlspyo --generate
- name: Test with pytest
run: |
pytest --timeout=20
pytest --timeout=30
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

setup(name='tlspyo',
packages=[package for package in find_packages()],
version='0.2.4',
download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.4.tar.gz',
version='0.2.5',
download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.5.tar.gz',
license='MIT',
description='Secure transport of python objects using TLS encryption',
long_description=long_description,
Expand Down
198 changes: 198 additions & 0 deletions test/test_groups_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import unittest
import time

from utils import HelperTester


def same_lists_no_order(l1, l2):
for elt in l1:
if elt in l2:
l2.remove(elt)
else:
return False
return len(l2) == 0


class TestGroupsSync(unittest.TestCase):

# Set up the server and all endpoints for all tests
def setUp(self):
self.ht = HelperTester(deserializer_mode="synchronous")

def test_groups_accept_all(self):
sr = self.ht.spawn_relay
se = self.ht.spawn_endpoint
relay = sr(accepted_groups=None)
ep1 = se(groups='group1')
ep2 = se(groups=('group1', 'group2'))
ep3 = se(groups='group3')
ep4 = se(groups='group5')
ep5 = se(groups=('group6', 'group5', 'group1'))
time.sleep(1.0) # let everyone handshake the relay so that broadcasts don't get overwritten before that

# test broadcasting

ep5.send_object(obj='test1', destination='group1')
r = ep1.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test1', f"r:{r}")

ep5.send_object(obj='test2', destination='group1')
r = ep1.receive_all(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test2', f"r:{r}")

ep2.send_object(obj='test3', destination='group1')
r = ep1.receive_all(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test3', f"r:{r}")
r = []
while len(r) < 3:
r += ep2.receive_all(blocking=True)
self.assertEqual(len(r), 3, f"r:{r}")
self.assertEqual(r[0], 'test1', f"r:{r}")
self.assertEqual(r[1], 'test2', f"r:{r}")
self.assertEqual(r[2], 'test3', f"r:{r}")

ep1.send_object(obj='test4', destination=('group1', 'group5', 'group6'))
r = ep1.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test4', f"r:{r}")
r = ep2.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test4', f"r:{r}")
r = ep4.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test4', f"r:{r}")
r = []
while len(r) < 6:
r += ep5.receive_all(blocking=True)
self.assertEqual(len(r), 6, f"r:{r}")
self.assertEqual(r[0], 'test1', f"r:{r}")
self.assertEqual(r[1], 'test2', f"r:{r}")
self.assertEqual(r[2], 'test3', f"r:{r}")
self.assertEqual(r[3], 'test4', f"r:{r}")
self.assertEqual(r[4], 'test4', f"r:{r}")
self.assertEqual(r[5], 'test4', f"r:{r}")

ep1.send_object(obj='test5', destination='group3')
ep1.send_object(obj='test6', destination='group3')
r = ep3.get_last(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
if r[0] == 'test5':
r = ep3.get_last(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test6', f"r:{r}")

# test producing
ep1.produce(obj='test7', group='group3')
ep3.notify(groups='group3')
r = ep3.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test7', f"r:{r}")

ep1.send_object(obj='test8', destination={'group1': 3})
ep1.notify(groups='group1')
ep2.notify(groups='group1')
ep5.notify(groups='group1')
r = ep1.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test8', f"r:{r}")
r = ep2.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test8', f"r:{r}")
r = ep5.pop(blocking=True)
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test8', f"r:{r}")

# test mixed broadcasting / producing

ep5.send_object(obj='test9', destination={'group1': 1, 'group2': 1}) # this produces to group 1 and 2
ep5.send_object(obj='test10', destination={'group1': -1, 'group2': -1}) # this broadcasts to groups 1 and 2
ep1.notify(groups=('group1', 'group2', 'group999')) # note that ep1 is in group 1 only
r = []
while len(r) < 2:
r += ep1.receive_all(blocking=True)
self.assertEqual(len(r), 2, f"r:{r}")
self.assertIn('test9', r, f"r:{r}")
self.assertIn('test10', r, f"r:{r}")
ep2.notify(groups=('group1', 'group2')) # note that ep2 is in group 1 and group 2, but group 1 is empty
r = []
while len(r) < 3:
r += ep2.receive_all(blocking=True)
self.assertEqual(len(r), 3, f"r:{r}")
self.assertTrue(same_lists_no_order(r, ['test9', 'test10', 'test10']), f"r:{r}")
ep5.send_object(obj='test11', destination={'group1': 1}) # let us send one more consumable to group 1
r = ep2.receive_all(blocking=True) # now the notification sent by ep2 can be fulfilled
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test11', f"r:{r}")

# test multiple producing / consuming

ep5.send_object(obj='test12', destination={'group1': 10}) # this produces to group 1
ep5.send_object(obj='test13', destination={'group2': 10}) # this produces to group 2
ep2.notify(groups={'group1': 1}) # retrieve 1 elt in group 1
r = []
while len(r) < 11: # one consumable in group1 and all consumables is group 2
ep2.notify(groups={'group2': -1}) # ask for all elts in group 2
r += ep2.receive_all(blocking=True)
self.assertEqual(len(r), 11, f"r:{r}")
self.assertTrue(same_lists_no_order(r, ['test12', ] + ['test13', ] * 10), f"r:{r}")
r = []
while len(r) < 9: # all remaining consumables in group 1
ep1.notify(groups={'group1': -1}) # ask for all elts in group 1
r += ep1.receive_all(blocking=True)
self.assertEqual(len(r), 9, f"r:{r}")
self.assertTrue(same_lists_no_order(r, ['test12', ] * 9), f"r:{r}")

def test_groups_accept_some(self):
sr = self.ht.spawn_relay
se = self.ht.spawn_endpoint
accepted_groups = {
'group1': {'max_count': 2, 'max_consumables': 2},
'group2': {'max_count': 1, 'max_consumables': 1},
'group3': {'max_count': 1, 'max_consumables': 0},
'group4': {'max_count': None, 'max_consumables': None}
}
relay = sr(accepted_groups=accepted_groups)
# nobody is in group 4, and groups 5 and 6 are not accepted
ep1 = se(groups=('group6', 'group5', 'group1')) # should not connect as group 5 and 6 are not allowed
ep2 = se(groups='group1') # should connect
ep3 = se(groups=('group1', 'group2')) # should connect
ep4 = se(groups='group3') # should connect
time.sleep(1.0) # let everyone connect so that old broadcasts are not lost for new clients
ep5 = se(groups='group1') # should not connect as group1 is full
time.sleep(0.5)

# test broadcasting

ep5.send_object(obj='test1', destination='group1') # should not send as ep5 is not connected
# (the previous line should also output a warning)
time.sleep(0.5)
r = ep1.pop(blocking=False) # not connected so should not receive
self.assertEqual(len(r), 0, f"r:{r}")
r = ep2.pop(blocking=False) # should not receive since nothing should have been be sent
self.assertEqual(len(r), 0, f"r:{r}")

ep4.send_object(obj='test2', destination='group1') # should send
time.sleep(0.5)
r = ep1.pop(blocking=False) # not connected so should not receive
self.assertEqual(len(r), 0, f"r:{r}")
r = ep2.pop(blocking=True) # should receive
self.assertEqual(len(r), 1, f"r:{r}")
self.assertEqual(r[0], 'test2')

ep2.produce(obj='test3', group='group2')
ep3.notify(groups='group2')
r = []
while len(r) < 2:
r += ep3.receive_all(blocking=True)
self.assertEqual(len(r), 2, f"r:{r}")
self.assertTrue(same_lists_no_order(r, ['test2', 'test3']), f"r:{r}")

def tearDown(self):
self.ht.clear()


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def custom_deserializer(bytestring):
return obj


class TestGroups(unittest.TestCase):
class TestSerialization(unittest.TestCase):

# Set up the server and all endpoints for all tests
def setUp(self):
Expand Down
6 changes: 4 additions & 2 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@


class HelperTester:
def __init__(self, serializer=None, deserializer=None):
def __init__(self, serializer=None, deserializer=None, deserializer_mode="asynchronous"):
self.next_local_port = TEST_LOCAL_PORT_START
self.endpoints = []
self.relays = []
self.serializer = serializer
self.deserializer = deserializer
self.deserializer_mode = deserializer_mode

def spawn_endpoint(self, groups):
ep = Endpoint(
Expand All @@ -25,7 +26,8 @@ def spawn_endpoint(self, groups):
local_com_port=self.next_local_port,
header_size=TEST_HEADER_SIZE,
serializer=self.serializer,
deserializer=self.deserializer
deserializer=self.deserializer,
deserializer_mode=self.deserializer_mode
)
self.next_local_port += 1
self.endpoints.append(ep)
Expand Down
30 changes: 26 additions & 4 deletions tlspyo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def __init__(self,
recon_max_delay=60.0,
recon_initial_delay=10.0,
recon_factor=1.5,
recon_jitter=0.1):
recon_jitter=0.1,
deserializer_mode="asynchronous"):
"""
``tlspyo`` Endpoint.
Expand All @@ -156,6 +157,12 @@ def __init__(self,
recon_initial_delay (float): in case of network failure, initial delay between reconnection attempts
recon_factor (float): in case of network failure, delay will increase by this factor between attempts
recon_jitter (float): in case of network failure, jitter factor of the delay between attempts
deserializer_mode (str): one of ("synchronous", "asynchronous"); ("sync", "async") are also accepted;
in asynchronous mode, objects are deserialized by the receiver thread as soon as they arrive, such that
they become available to the calling thread as soon as it needs to retrieve them;
in synchronous mode, objects are deserialized by the calling thread upon object retrieval;
synchronous mode removes the need for potentially useless, randomly timed deserialization in the
background, at the cost of performing deserialization upon object retrieval instead
"""

assert security in (None, "TLS"), f"Unsupported security: {security}"
Expand All @@ -181,6 +188,8 @@ def __init__(self,
self._local_com_srv = socket(AF_INET, SOCK_STREAM)
self._local_com_srv.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)

self._deserialize_locally = deserializer_mode in ("synchronous", "sync")

keys_dir = os.path.abspath(keys_dir) if keys_dir is not None else keys_dir
serializer = serializer if serializer is not None else DEFAULT_SERIALIZER
deserializer = deserializer if deserializer is not None else DEFAULT_DESERIALIZER
Expand Down Expand Up @@ -215,6 +224,9 @@ def __init__(self,
def __del__(self):
self.stop()

def _deserialize(self, obj):
return self._client.deserializer(obj)

def _manage_received_objects(self):
"""
Called in its own thread.
Expand All @@ -230,9 +242,10 @@ def _manage_received_objects(self):
buf += self._local_com_conn.recv(self._max_buf_len)
i, j = self._process_header(buf)
while j <= len(buf):
stamp, cmd, obj = self._client.deserializer(buf[i:j])
stamp, cmd, obj = self._deserialize(buf[i:j])
if cmd == "OBJ":
self.__obj_buffer.put(self._client.deserializer(obj)) # TODO: maxlen
to_put = obj if self._deserialize_locally else self._deserialize(obj)
self.__obj_buffer.put(to_put) # TODO: maxlen
buf = buf[j:]
i, j = self._process_header(buf)

Expand Down Expand Up @@ -361,6 +374,12 @@ def stop(self):
self._local_com_srv.close()
self._local_com_addr = None

def _process_received_list(self, received_list):
if self._deserialize_locally:
for i, obj in enumerate(received_list):
received_list[i] = self._deserialize(obj)
return received_list

def receive_all(self, blocking=False):
"""
Returns all received objects in a list, from oldest to newest.
Expand All @@ -376,6 +395,7 @@ def receive_all(self, blocking=False):
while len(elem) > 0:
cpy += elem
elem = get_from_queue(self.__obj_buffer, blocking=False)
cpy = self._process_received_list(cpy)
return cpy

def pop(self, max_items=1, blocking=False):
Expand All @@ -402,6 +422,7 @@ def pop(self, max_items=1, blocking=False):
if len(cpy) >= max_items:
break
elem = get_from_queue(self.__obj_buffer, blocking=False)
cpy = self._process_received_list(cpy)
return cpy

def get_last(self, max_items=1, blocking=False):
Expand All @@ -427,4 +448,5 @@ def get_last(self, max_items=1, blocking=False):
while len(elem) > 0:
cpy += elem
elem = get_from_queue(self.__obj_buffer, blocking=False)
return cpy[-max_items:]
cpy = self._process_received_list(cpy[-max_items:])
return cpy

0 comments on commit 743b8f4

Please sign in to comment.