Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.3.0 #5

Merged
merged 6 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from setuptools import setup, find_packages
import sys

from pathlib import Path


if sys.version_info < (3, 7):
sys.exit('Sorry, Python < 3.7 is not supported, upgrade your python installation to use tlspyo.')
Expand All @@ -14,8 +12,8 @@

setup(name='tlspyo',
packages=[package for package in find_packages()],
version='0.2.5',
download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.2.5.tar.gz',
version='0.3.0',
download_url='https://github.com/MISTLab/tls-python-object/archive/refs/tags/v0.3.0.tar.gz',
license='MIT',
description='Secure transport of python objects using TLS encryption',
long_description=long_description,
Expand Down
60 changes: 36 additions & 24 deletions tlspyo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def __init__(self,

assert accepted_groups is None or isinstance(accepted_groups, dict), "Invalid format for accepted_groups."

self._stopped = False
self._header_size = header_size
self._local_com_port = local_com_port
self._local_com_srv = socket(AF_INET, SOCK_STREAM)
Expand All @@ -93,6 +92,9 @@ def __init__(self,
self._local_com_conn, self._local_com_addr = self._local_com_srv.accept()
self._send_local('TEST')

self._stop_lock = Lock()
self._stopped = False

def __del__(self):
self.stop()

Expand All @@ -105,14 +107,19 @@ def stop(self):
"""
Stop the Relay.
"""
if not self._stopped:
self._stopped = True
self._send_local('STOP')
try:
with self._stop_lock:
if not self._stopped:
self._send_local('STOP')

self._p.join()
self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._p.join()
self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._stopped = True
except KeyboardInterrupt as e:
self.stop()
raise e


class Endpoint:
Expand Down Expand Up @@ -172,8 +179,6 @@ def __init__(self,
elif security == "SSL":
security = "TLS"

self._stopped = False

# threading for local object receiving
self.__obj_buffer = queue.Queue()
self.__socket_closed_lock = Lock()
Expand Down Expand Up @@ -221,6 +226,9 @@ def __init__(self,
self._t_manage_received_objects = Thread(target=self._manage_received_objects, daemon=True)
self._t_manage_received_objects.start()

self._stop_lock = Lock()
self._stopped = False

def __del__(self):
self.stop()

Expand All @@ -236,7 +244,6 @@ def _manage_received_objects(self):
# Check if socket is still open
with self.__socket_closed_lock:
if self.__socket_closed_flag:
self._local_com_conn.close()
return

buf += self._local_com_conn.recv(self._max_buf_len)
Expand Down Expand Up @@ -357,22 +364,27 @@ def stop(self):
"""
Stop the Endpoint.
"""
if not self._stopped:
self._stopped = True
# send STOP to the local server
self._send_local(cmd='STOP', dest=None, obj=None)
try:
with self._stop_lock:
if not self._stopped:
# send STOP to the local server
self._send_local(cmd='STOP', dest=None, obj=None)

# Join the message reading thread
with self.__socket_closed_lock:
self.__socket_closed_flag = True
self._t_manage_received_objects.join()
# Join the message reading thread
with self.__socket_closed_lock:
self.__socket_closed_flag = True
self._t_manage_received_objects.join()

# join Twisted process and stop local server
self._p.join()
# join Twisted process and stop local server
self._p.join()

self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._local_com_conn.close()
self._local_com_srv.close()
self._local_com_addr = None
self._stopped = True
except KeyboardInterrupt as e:
self.stop()
raise e

def _process_received_list(self, received_list):
if self._deserialize_locally:
Expand Down
2 changes: 1 addition & 1 deletion tlspyo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def dataReceived(self, data):
stamp, cmd, obj = self._client.deserializer(self._buffer[i:j])
if cmd == 'ACK':
try:
logger.info(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.")
logger.debug(f"ACK received after {time.monotonic() - self._client.pending_acks[stamp][0]}s.")
del self._client.pending_acks[stamp] # delete pending ACK
except KeyError:
logger.warning(f"Received ACK for stamp {stamp} not present in pending ACKs.")
Expand Down
40 changes: 31 additions & 9 deletions tlspyo/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def generate_tls_credentials(
folder_path,
email_address="emailAddress",
common_name="default",
subject_alt_name=('DNS:default',),
country_name="CA",
locality_name="localityName",
state_or_province_name="stateOrProvinceName",
Expand All @@ -42,6 +43,7 @@ def generate_tls_credentials(
folder_path (path-like object): path were the files will be created
email_address (str): your email address
common_name (str): your hostname
subject_alt_name (tuple of str): your subject alt name list
country_name (str): your country code
locality_name (str): your locality name
state_or_province_name (str): your state name
Expand All @@ -57,18 +59,26 @@ def generate_tls_credentials(
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 4096)
cert = crypto.X509()
cert.get_subject().C = country_name
cert.get_subject().ST = state_or_province_name
cert.get_subject().L = locality_name
cert.get_subject().O = organization_name
cert.get_subject().OU = organization_unit_name
cert.get_subject().CN = common_name
cert.get_subject().emailAddress = email_address
cert.set_serial_number(serial_number)

subject = cert.get_subject()
subject.commonName = common_name
subject.emailAddress = email_address
subject.organizationName = organization_name
subject.organizationalUnitName = organization_unit_name
subject.localityName = locality_name
subject.stateOrProvinceName = state_or_province_name
subject.countryName = country_name

cert.set_issuer(subject)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(validity_end_in_seconds)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(k)
cert.set_serial_number(serial_number)
cert.set_version(2) # for SAN
cert.add_extensions([
crypto.X509Extension(b'subjectAltName', False, ','.join(subject_alt_name).encode())
])

cert.sign(k, 'sha512')
with open(cert_file, "wt") as f:
f.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode("utf-8"))
Expand All @@ -87,6 +97,7 @@ def credentials_generator_tool(custom=False):
folder_path = get_default_keys_folder()
email_address = "emailAddress"
common_name = "default"
subject_alt_name = ["DNS:" + common_name]
country_name = "CA"
locality_name = "localityName"
state_or_province_name = "stateOrProvinceName"
Expand Down Expand Up @@ -118,6 +129,16 @@ def credentials_generator_tool(custom=False):
common_name = inp
print(common_name)

subject_alt_name = ["DNS:" + common_name]
print(f"\nSubject alternative name (hostnames, leave empty to stop adding) {subject_alt_name}:")
inp = input()
if inp != "":
subject_alt_name = []
while inp != "":
subject_alt_name.append(inp)
inp = input()
print(subject_alt_name)

print(f"\nCountry code [{country_name}]:")
inp = input()
if inp != "":
Expand Down Expand Up @@ -163,6 +184,7 @@ def credentials_generator_tool(custom=False):
generate_tls_credentials(folder_path=folder_path,
email_address=email_address,
common_name=common_name,
subject_alt_name=tuple(subject_alt_name),
country_name=country_name,
locality_name=locality_name,
state_or_province_name=state_or_province_name,
Expand Down
10 changes: 5 additions & 5 deletions tlspyo/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import signal
# import signal
import queue


try:
signal.signal(signal.SIGINT, signal.SIG_DFL)
except Exception as e:
pass
# try:
# signal.signal(signal.SIGINT, signal.SIG_DFL)
# except Exception as e:
# pass


def wait_event(event):
Expand Down