From b735854642263dd997c8ede049a009fd77fdbdbf Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Fri, 15 Dec 2023 07:21:39 -0500 Subject: [PATCH 1/6] robustness against keyboardinterrupt --- setup.py | 4 +-- tlspyo/api.py | 97 +++++++++++++++++++++++++++++---------------------- 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/setup.py b/setup.py index 00fdce4..c34e70c 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,8 @@ setup(name='tlspyo', packages=[package for package in find_packages()], - version='0.2.5', - download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.5.tar.gz', + version='0.2.6', + download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.6.tar.gz', license='MIT', description='Secure transport of python objects using TLS encryption', long_description=long_description, diff --git a/tlspyo/api.py b/tlspyo/api.py index 473e016..719312a 100644 --- a/tlspyo/api.py +++ b/tlspyo/api.py @@ -4,6 +4,7 @@ from threading import Thread, Lock from multiprocessing import Process import os +import weakref from tlspyo.server import Server from tlspyo.client import Client @@ -93,7 +94,9 @@ def __init__(self, self._local_com_conn, self._local_com_addr = self._local_com_srv.accept() self._send_local('TEST') - def __del__(self): + self._finalizer = weakref.finalize(self, self._finalize) + + def _finalize(self): self.stop() def _send_local(self, cmd): @@ -105,14 +108,17 @@ def stop(self): """ Stop the Relay. """ - if not self._stopped: - self._stopped = True - self._send_local('STOP') + try: + if not self._stopped: + self._send_local('STOP') - self._p.join() - self._local_com_conn.close() - self._local_com_srv.close() - self._local_com_addr = None + self._p.join() + self._local_com_conn.close() + self._local_com_srv.close() + self._local_com_addr = None + self._stopped = True + except KeyboardInterrupt: + self.stop() class Endpoint: @@ -218,10 +224,12 @@ def __init__(self, self._local_com_conn, self._local_com_addr = self._local_com_srv.accept() self._send_local(cmd='TEST') - self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=True) + self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=False) self._t_manage_received_objects.start() - def __del__(self): + self._finalizer = weakref.finalize(self, self._finalize) + + def _finalize(self): self.stop() def _deserialize(self, obj): @@ -231,23 +239,25 @@ def _manage_received_objects(self): """ Called in its own thread. """ - buf = b"" - while True: - # Check if socket is still open - with self.__socket_closed_lock: - if self.__socket_closed_flag: - self._local_com_conn.close() - return - - buf += self._local_com_conn.recv(self._max_buf_len) - i, j = self._process_header(buf) - while j <= len(buf): - stamp, cmd, obj = self._deserialize(buf[i:j]) - if cmd == "OBJ": - to_put = obj if self._deserialize_locally else self._deserialize(obj) - self.__obj_buffer.put(to_put) # TODO: maxlen - buf = buf[j:] + try: + buf = b"" + while True: + # Check if socket is still open + with self.__socket_closed_lock: + if self.__socket_closed_flag: + return + + buf += self._local_com_conn.recv(self._max_buf_len) i, j = self._process_header(buf) + while j <= len(buf): + stamp, cmd, obj = self._deserialize(buf[i:j]) + if cmd == "OBJ": + to_put = obj if self._deserialize_locally else self._deserialize(obj) + self.__obj_buffer.put(to_put) # TODO: maxlen + buf = buf[j:] + i, j = self._process_header(buf) + finally: + self._local_com_conn.close() def _process_header(self, buf): i = self._header_size @@ -357,22 +367,25 @@ def stop(self): """ Stop the Endpoint. """ - if not self._stopped: - self._stopped = True - # send STOP to the local server - self._send_local(cmd='STOP', dest=None, obj=None) - - # Join the message reading thread - with self.__socket_closed_lock: - self.__socket_closed_flag = True - self._t_manage_received_objects.join() - - # join Twisted process and stop local server - self._p.join() - - self._local_com_conn.close() - self._local_com_srv.close() - self._local_com_addr = None + try: + if not self._stopped: + # send STOP to the local server + self._send_local(cmd='STOP', dest=None, obj=None) + + # Join the message reading thread + with self.__socket_closed_lock: + self.__socket_closed_flag = True + self._t_manage_received_objects.join() + + # join Twisted process and stop local server + self._p.join() + + self._local_com_conn.close() + self._local_com_srv.close() + self._local_com_addr = None + self._stopped = True + except KeyboardInterrupt: + self.stop() def _process_received_list(self, received_list): if self._deserialize_locally: From c4526d752996aaac7853e5ebe73d5abb116f9d97 Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Sun, 17 Dec 2023 12:49:21 -0500 Subject: [PATCH 2/6] reduced verbosity --- tlspyo/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tlspyo/client.py b/tlspyo/client.py index 844e6b7..3e56bc7 100644 --- a/tlspyo/client.py +++ b/tlspyo/client.py @@ -49,7 +49,7 @@ def dataReceived(self, data): stamp, cmd, obj = self._client.deserializer(self._buffer[i:j]) if cmd == 'ACK': try: - logger.info(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.") + logger.debug(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.") del self._client.pending_acks[stamp] # delete pending ACK except KeyError: logger.warning(f"Received ACK for stamp {stamp} not present in pending ACKs.") From 553de74394172e1d6b77096018c0b9e2d7f9b7f2 Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Tue, 26 Dec 2023 08:31:22 -0500 Subject: [PATCH 3/6] hotfix service-identity version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c34e70c..996fbd7 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ install_requires=[ 'twisted', 'pyOpenSSL>22.1.0', - 'service_identity', + 'service_identity==21.1.0', 'platformdirs' ], extras_requires={ From e2649cc7a8481cc839c43ae4259f7622ce74a473 Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Fri, 5 Jan 2024 15:08:44 -0500 Subject: [PATCH 4/6] SAN support in TLS certificates, version 0.3 --- setup.py | 8 +++----- tlspyo/credentials.py | 40 +++++++++++++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 996fbd7..927c63e 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,6 @@ from setuptools import setup, find_packages import sys -from pathlib import Path - if sys.version_info < (3, 7): sys.exit('Sorry, Python < 3.7 is not supported, upgrade your python installation to use tlspyo.') @@ -14,8 +12,8 @@ setup(name='tlspyo', packages=[package for package in find_packages()], - version='0.2.6', - download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.6.tar.gz', + version='0.3.0', + download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.3.0.tar.gz', license='MIT', description='Secure transport of python objects using TLS encryption', long_description=long_description, @@ -26,7 +24,7 @@ install_requires=[ 'twisted', 'pyOpenSSL>22.1.0', - 'service_identity==21.1.0', + 'service_identity', 'platformdirs' ], extras_requires={ diff --git a/tlspyo/credentials.py b/tlspyo/credentials.py index 68426a6..5a11d1f 100644 --- a/tlspyo/credentials.py +++ b/tlspyo/credentials.py @@ -28,6 +28,7 @@ def generate_tls_credentials( folder_path, email_address="emailAddress", common_name="default", + subject_alt_name=('DNS:default',), country_name="CA", locality_name="localityName", state_or_province_name="stateOrProvinceName", @@ -42,6 +43,7 @@ def generate_tls_credentials( folder_path (path-like object): path were the files will be created email_address (str): your email address common_name (str): your hostname + subject_alt_name (tuple of str): your subject alt name list country_name (str): your country code locality_name (str): your locality name state_or_province_name (str): your state name @@ -57,18 +59,26 @@ def generate_tls_credentials( k = crypto.PKey() k.generate_key(crypto.TYPE_RSA, 4096) cert = crypto.X509() - cert.get_subject().C = country_name - cert.get_subject().ST = state_or_province_name - cert.get_subject().L = locality_name - cert.get_subject().O = organization_name - cert.get_subject().OU = organization_unit_name - cert.get_subject().CN = common_name - cert.get_subject().emailAddress = email_address - cert.set_serial_number(serial_number) + + subject = cert.get_subject() + subject.commonName = common_name + subject.emailAddress = email_address + subject.organizationName = organization_name + subject.organizationalUnitName = organization_unit_name + subject.localityName = locality_name + subject.stateOrProvinceName = state_or_province_name + subject.countryName = country_name + + cert.set_issuer(subject) cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notAfter(validity_end_in_seconds) - cert.set_issuer(cert.get_subject()) cert.set_pubkey(k) + cert.set_serial_number(serial_number) + cert.set_version(2) # for SAN + cert.add_extensions([ + crypto.X509Extension(b'subjectAltName', False, ','.join(subject_alt_name).encode()) + ]) + cert.sign(k, 'sha512') with open(cert_file, "wt") as f: f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) @@ -87,6 +97,7 @@ def credentials_generator_tool(custom=False): folder_path = get_default_keys_folder() email_address = "emailAddress" common_name = "default" + subject_alt_name = ["DNS:" + common_name] country_name = "CA" locality_name = "localityName" state_or_province_name = "stateOrProvinceName" @@ -118,6 +129,16 @@ def credentials_generator_tool(custom=False): common_name = inp print(common_name) + subject_alt_name = ["DNS:" + common_name] + print(f"\nSubject alternative name (hostnames, leave empty to stop adding) {subject_alt_name}:") + inp = input() + if inp != "": + subject_alt_name = [] + while inp != "": + subject_alt_name.append(inp) + inp = input() + print(subject_alt_name) + print(f"\nCountry code [{country_name}]:") inp = input() if inp != "": @@ -163,6 +184,7 @@ def credentials_generator_tool(custom=False): generate_tls_credentials(folder_path=folder_path, email_address=email_address, common_name=common_name, + subject_alt_name=tuple(subject_alt_name), country_name=country_name, locality_name=locality_name, state_or_province_name=state_or_province_name, From e40426d55dbf99b5d4319b4e4d6f28cbeeeca821 Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Fri, 5 Jan 2024 15:56:06 -0500 Subject: [PATCH 5/6] thread-safe stop --- tlspyo/api.py | 63 +++++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/tlspyo/api.py b/tlspyo/api.py index 719312a..4499213 100644 --- a/tlspyo/api.py +++ b/tlspyo/api.py @@ -68,7 +68,6 @@ def __init__(self, assert accepted_groups is None or isinstance(accepted_groups, dict), "Invalid format for accepted_groups." - self._stopped = False self._header_size = header_size self._local_com_port = local_com_port self._local_com_srv = socket(AF_INET, SOCK_STREAM) @@ -95,6 +94,8 @@ def __init__(self, self._send_local('TEST') self._finalizer = weakref.finalize(self, self._finalize) + self._stop_lock = Lock() + self._stopped = False def _finalize(self): self.stop() @@ -109,16 +110,18 @@ def stop(self): Stop the Relay. """ try: - if not self._stopped: - self._send_local('STOP') - - self._p.join() - self._local_com_conn.close() - self._local_com_srv.close() - self._local_com_addr = None - self._stopped = True - except KeyboardInterrupt: + with self._stop_lock: + if not self._stopped: + self._send_local('STOP') + + self._p.join() + self._local_com_conn.close() + self._local_com_srv.close() + self._local_com_addr = None + self._stopped = True + except KeyboardInterrupt as e: self.stop() + raise e class Endpoint: @@ -178,8 +181,6 @@ def __init__(self, elif security == "SSL": security = "TLS" - self._stopped = False - # threading for local object receiving self.__obj_buffer = queue.Queue() self.__socket_closed_lock = Lock() @@ -228,6 +229,8 @@ def __init__(self, self._t_manage_received_objects.start() self._finalizer = weakref.finalize(self, self._finalize) + self._stop_lock = Lock() + self._stopped = False def _finalize(self): self.stop() @@ -368,24 +371,26 @@ def stop(self): Stop the Endpoint. """ try: - if not self._stopped: - # send STOP to the local server - self._send_local(cmd='STOP', dest=None, obj=None) - - # Join the message reading thread - with self.__socket_closed_lock: - self.__socket_closed_flag = True - self._t_manage_received_objects.join() - - # join Twisted process and stop local server - self._p.join() - - self._local_com_conn.close() - self._local_com_srv.close() - self._local_com_addr = None - self._stopped = True - except KeyboardInterrupt: + with self._stop_lock: + if not self._stopped: + # send STOP to the local server + self._send_local(cmd='STOP', dest=None, obj=None) + + # Join the message reading thread + with self.__socket_closed_lock: + self.__socket_closed_flag = True + self._t_manage_received_objects.join() + + # join Twisted process and stop local server + self._p.join() + + self._local_com_conn.close() + self._local_com_srv.close() + self._local_com_addr = None + self._stopped = True + except KeyboardInterrupt as e: self.stop() + raise e def _process_received_list(self, received_list): if self._deserialize_locally: From f9e7b5c3744cd80bc8f6e0d26f18b38e080d8fe6 Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Fri, 5 Jan 2024 22:46:47 -0500 Subject: [PATCH 6/6] Remove weakref --- tlspyo/api.py | 42 ++++++++++++++++++------------------------ tlspyo/utils.py | 10 +++++----- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/tlspyo/api.py b/tlspyo/api.py index 4499213..8a5b2e0 100644 --- a/tlspyo/api.py +++ b/tlspyo/api.py @@ -4,7 +4,6 @@ from threading import Thread, Lock from multiprocessing import Process import os -import weakref from tlspyo.server import Server from tlspyo.client import Client @@ -93,11 +92,10 @@ def __init__(self, self._local_com_conn, self._local_com_addr = self._local_com_srv.accept() self._send_local('TEST') - self._finalizer = weakref.finalize(self, self._finalize) self._stop_lock = Lock() self._stopped = False - def _finalize(self): + def __del__(self): self.stop() def _send_local(self, cmd): @@ -225,14 +223,13 @@ def __init__(self, self._local_com_conn, self._local_com_addr = self._local_com_srv.accept() self._send_local(cmd='TEST') - self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=False) + self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=True) self._t_manage_received_objects.start() - self._finalizer = weakref.finalize(self, self._finalize) self._stop_lock = Lock() self._stopped = False - def _finalize(self): + def __del__(self): self.stop() def _deserialize(self, obj): @@ -242,25 +239,22 @@ def _manage_received_objects(self): """ Called in its own thread. """ - try: - buf = b"" - while True: - # Check if socket is still open - with self.__socket_closed_lock: - if self.__socket_closed_flag: - return - - buf += self._local_com_conn.recv(self._max_buf_len) + buf = b"" + while True: + # Check if socket is still open + with self.__socket_closed_lock: + if self.__socket_closed_flag: + return + + buf += self._local_com_conn.recv(self._max_buf_len) + i, j = self._process_header(buf) + while j <= len(buf): + stamp, cmd, obj = self._deserialize(buf[i:j]) + if cmd == "OBJ": + to_put = obj if self._deserialize_locally else self._deserialize(obj) + self.__obj_buffer.put(to_put) # TODO: maxlen + buf = buf[j:] i, j = self._process_header(buf) - while j <= len(buf): - stamp, cmd, obj = self._deserialize(buf[i:j]) - if cmd == "OBJ": - to_put = obj if self._deserialize_locally else self._deserialize(obj) - self.__obj_buffer.put(to_put) # TODO: maxlen - buf = buf[j:] - i, j = self._process_header(buf) - finally: - self._local_com_conn.close() def _process_header(self, buf): i = self._header_size diff --git a/tlspyo/utils.py b/tlspyo/utils.py index b43ead3..1b9e2d0 100644 --- a/tlspyo/utils.py +++ b/tlspyo/utils.py @@ -1,11 +1,11 @@ -import signal +# import signal import queue -try: - signal.signal(signal.SIGINT, signal.SIG_DFL) -except Exception as e: - pass +# try: +# signal.signal(signal.SIGINT, signal.SIG_DFL) +# except Exception as e: +# pass def wait_event(event):