diff --git a/src/bastion/HISTORY.rst b/src/bastion/HISTORY.rst index f75ca6a433f..3fa88b93865 100644 --- a/src/bastion/HISTORY.rst +++ b/src/bastion/HISTORY.rst @@ -2,6 +2,10 @@ Release History =============== +1.0.1 ++++++ +* Added support for concurrent connections. + 1.0.0 ++++++ * Removing preview flag and update MFA documentation. diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index 9106526499c..0096a251392 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -18,7 +18,7 @@ import logging as logs from contextlib import closing from datetime import datetime -from threading import Thread +from threading import Thread, Lock import requests import urllib3 @@ -54,6 +54,8 @@ def __init__(self, cli_ctx, local_addr, local_port, bastion, bastion_endpoint, r self.node_id = None self.host_name = None self.cli_ctx = cli_ctx + self.active_connections = 0 + self.connection_lock = Lock() logger.info('Creating a socket on port: %s', self.local_port) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) logger.info('Setting socket options') @@ -118,8 +120,16 @@ def _listen(self): self.sock.listen(100) index = 0 while True: - self.client, _address = self.sock.accept() - + client, _address = self.sock.accept() + with self.connection_lock: + self.active_connections += 1 + logger.info('Got a connection, starting a new thread') + thread = Thread(target=self._handle_client, args=(client, index)) + thread.start() + index += 1 + + def _handle_client(self, client, index): + try: auth_token = self._get_auth_token() if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \ self.bastion['sku']['name'] == BastionSku.Developer.name: @@ -128,22 +138,27 @@ def _listen(self): host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED - self.ws = create_connection(host, - sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), - sslopt={'cert_reqs': verify_mode}, - enable_multithread=True) - logger.info('Websocket, connected status: %s', self.ws.connected) - index = index + 1 + ws = create_connection(host, + sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), + sslopt={'cert_reqs': verify_mode}, + enable_multithread=True) + logger.info('Websocket, connected status: %s', ws.connected) logger.info('Got debugger connection... index: %s', index) - debugger_thread = Thread(target=self._listen_to_client, args=(self.client, self.ws, index)) - web_socket_thread = Thread(target=self._listen_to_web_socket, args=(self.client, self.ws, index)) + debugger_thread = Thread(target=self._listen_to_client, args=(client, ws, index)) + web_socket_thread = Thread(target=self._listen_to_web_socket, args=(client, ws, index)) debugger_thread.start() web_socket_thread.start() logger.info('Both debugger and websocket threads started...') logger.info('Successfully connected to local server..') debugger_thread.join() web_socket_thread.join() - self.cleanup() + except Exception as ex: # pylint: disable=broad-except + logger.info('Exception in handling client: %s', ex) + finally: + with self.connection_lock: + self.active_connections -= 1 + if self.active_connections == 0: + self.cleanup() logger.info('Both debugger and websocket threads stopped...') logger.info('Stopped local server..') diff --git a/src/bastion/setup.py b/src/bastion/setup.py index 8bbb595a172..b719f9e0311 100644 --- a/src/bastion/setup.py +++ b/src/bastion/setup.py @@ -10,7 +10,7 @@ # HISTORY.rst entry. -VERSION = '1.0.0' +VERSION = '1.0.1' # The full list of classifiers is available at # https://pypi.python.org/pypi?%3Aaction=list_classifiers