From 3af67ab77b2e5ec5ef8a335990fc5fb7c734d2bb Mon Sep 17 00:00:00 2001 From: Yann Bouteiller Date: Fri, 5 Jan 2024 22:58:45 -0500 Subject: [PATCH] Version 0.3.0 (#5) * SubjectAtlName support --- setup.py | 6 ++--- tlspyo/api.py | 60 ++++++++++++++++++++++++++----------------- tlspyo/client.py | 2 +- tlspyo/credentials.py | 40 ++++++++++++++++++++++------- tlspyo/utils.py | 10 ++++---- 5 files changed, 75 insertions(+), 43 deletions(-) diff --git a/setup.py b/setup.py index 00fdce4..927c63e 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,6 @@ from setuptools import setup, find_packages import sys -from pathlib import Path - if sys.version_info < (3, 7): sys.exit('Sorry, Python < 3.7 is not supported, upgrade your python installation to use tlspyo.') @@ -14,8 +12,8 @@ setup(name='tlspyo', packages=[package for package in find_packages()], - version='0.2.5', - download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.5.tar.gz', + version='0.3.0', + download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.3.0.tar.gz', license='MIT', description='Secure transport of python objects using TLS encryption', long_description=long_description, diff --git a/tlspyo/api.py b/tlspyo/api.py index 473e016..8a5b2e0 100644 --- a/tlspyo/api.py +++ b/tlspyo/api.py @@ -67,7 +67,6 @@ def __init__(self, assert accepted_groups is None or isinstance(accepted_groups, dict), "Invalid format for accepted_groups." - self._stopped = False self._header_size = header_size self._local_com_port = local_com_port self._local_com_srv = socket(AF_INET, SOCK_STREAM) @@ -93,6 +92,9 @@ def __init__(self, self._local_com_conn, self._local_com_addr = self._local_com_srv.accept() self._send_local('TEST') + self._stop_lock = Lock() + self._stopped = False + def __del__(self): self.stop() @@ -105,14 +107,19 @@ def stop(self): """ Stop the Relay. """ - if not self._stopped: - self._stopped = True - self._send_local('STOP') + try: + with self._stop_lock: + if not self._stopped: + self._send_local('STOP') - self._p.join() - self._local_com_conn.close() - self._local_com_srv.close() - self._local_com_addr = None + self._p.join() + self._local_com_conn.close() + self._local_com_srv.close() + self._local_com_addr = None + self._stopped = True + except KeyboardInterrupt as e: + self.stop() + raise e class Endpoint: @@ -172,8 +179,6 @@ def __init__(self, elif security == "SSL": security = "TLS" - self._stopped = False - # threading for local object receiving self.__obj_buffer = queue.Queue() self.__socket_closed_lock = Lock() @@ -221,6 +226,9 @@ def __init__(self, self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=True) self._t_manage_received_objects.start() + self._stop_lock = Lock() + self._stopped = False + def __del__(self): self.stop() @@ -236,7 +244,6 @@ def _manage_received_objects(self): # Check if socket is still open with self.__socket_closed_lock: if self.__socket_closed_flag: - self._local_com_conn.close() return buf += self._local_com_conn.recv(self._max_buf_len) @@ -357,22 +364,27 @@ def stop(self): """ Stop the Endpoint. """ - if not self._stopped: - self._stopped = True - # send STOP to the local server - self._send_local(cmd='STOP', dest=None, obj=None) + try: + with self._stop_lock: + if not self._stopped: + # send STOP to the local server + self._send_local(cmd='STOP', dest=None, obj=None) - # Join the message reading thread - with self.__socket_closed_lock: - self.__socket_closed_flag = True - self._t_manage_received_objects.join() + # Join the message reading thread + with self.__socket_closed_lock: + self.__socket_closed_flag = True + self._t_manage_received_objects.join() - # join Twisted process and stop local server - self._p.join() + # join Twisted process and stop local server + self._p.join() - self._local_com_conn.close() - self._local_com_srv.close() - self._local_com_addr = None + self._local_com_conn.close() + self._local_com_srv.close() + self._local_com_addr = None + self._stopped = True + except KeyboardInterrupt as e: + self.stop() + raise e def _process_received_list(self, received_list): if self._deserialize_locally: diff --git a/tlspyo/client.py b/tlspyo/client.py index 844e6b7..3e56bc7 100644 --- a/tlspyo/client.py +++ b/tlspyo/client.py @@ -49,7 +49,7 @@ def dataReceived(self, data): stamp, cmd, obj = self._client.deserializer(self._buffer[i:j]) if cmd == 'ACK': try: - logger.info(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.") + logger.debug(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.") del self._client.pending_acks[stamp] # delete pending ACK except KeyError: logger.warning(f"Received ACK for stamp {stamp} not present in pending ACKs.") diff --git a/tlspyo/credentials.py b/tlspyo/credentials.py index 68426a6..5a11d1f 100644 --- a/tlspyo/credentials.py +++ b/tlspyo/credentials.py @@ -28,6 +28,7 @@ def generate_tls_credentials( folder_path, email_address="emailAddress", common_name="default", + subject_alt_name=('DNS:default',), country_name="CA", locality_name="localityName", state_or_province_name="stateOrProvinceName", @@ -42,6 +43,7 @@ def generate_tls_credentials( folder_path (path-like object): path were the files will be created email_address (str): your email address common_name (str): your hostname + subject_alt_name (tuple of str): your subject alt name list country_name (str): your country code locality_name (str): your locality name state_or_province_name (str): your state name @@ -57,18 +59,26 @@ def generate_tls_credentials( k = crypto.PKey() k.generate_key(crypto.TYPE_RSA, 4096) cert = crypto.X509() - cert.get_subject().C = country_name - cert.get_subject().ST = state_or_province_name - cert.get_subject().L = locality_name - cert.get_subject().O = organization_name - cert.get_subject().OU = organization_unit_name - cert.get_subject().CN = common_name - cert.get_subject().emailAddress = email_address - cert.set_serial_number(serial_number) + + subject = cert.get_subject() + subject.commonName = common_name + subject.emailAddress = email_address + subject.organizationName = organization_name + subject.organizationalUnitName = organization_unit_name + subject.localityName = locality_name + subject.stateOrProvinceName = state_or_province_name + subject.countryName = country_name + + cert.set_issuer(subject) cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notAfter(validity_end_in_seconds) - cert.set_issuer(cert.get_subject()) cert.set_pubkey(k) + cert.set_serial_number(serial_number) + cert.set_version(2) # for SAN + cert.add_extensions([ + crypto.X509Extension(b'subjectAltName', False, ','.join(subject_alt_name).encode()) + ]) + cert.sign(k, 'sha512') with open(cert_file, "wt") as f: f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8")) @@ -87,6 +97,7 @@ def credentials_generator_tool(custom=False): folder_path = get_default_keys_folder() email_address = "emailAddress" common_name = "default" + subject_alt_name = ["DNS:" + common_name] country_name = "CA" locality_name = "localityName" state_or_province_name = "stateOrProvinceName" @@ -118,6 +129,16 @@ def credentials_generator_tool(custom=False): common_name = inp print(common_name) + subject_alt_name = ["DNS:" + common_name] + print(f"\nSubject alternative name (hostnames, leave empty to stop adding) {subject_alt_name}:") + inp = input() + if inp != "": + subject_alt_name = [] + while inp != "": + subject_alt_name.append(inp) + inp = input() + print(subject_alt_name) + print(f"\nCountry code [{country_name}]:") inp = input() if inp != "": @@ -163,6 +184,7 @@ def credentials_generator_tool(custom=False): generate_tls_credentials(folder_path=folder_path, email_address=email_address, common_name=common_name, + subject_alt_name=tuple(subject_alt_name), country_name=country_name, locality_name=locality_name, state_or_province_name=state_or_province_name, diff --git a/tlspyo/utils.py b/tlspyo/utils.py index b43ead3..1b9e2d0 100644 --- a/tlspyo/utils.py +++ b/tlspyo/utils.py @@ -1,11 +1,11 @@ -import signal +# import signal import queue -try: - signal.signal(signal.SIGINT, signal.SIG_DFL) -except Exception as e: - pass +# try: +# signal.signal(signal.SIGINT, signal.SIG_DFL) +# except Exception as e: +# pass def wait_event(event):