Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for concurrent connections with tunnels #7719

Merged
merged 3 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bastion/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

Release History
===============
1.0.1
+++++
* Added support for concurrent connections.

1.0.0
++++++
* Removing preview flag and update MFA documentation.
Expand Down
39 changes: 27 additions & 12 deletions src/bastion/azext_bastion/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging as logs
from contextlib import closing
from datetime import datetime
from threading import Thread
from threading import Thread, Lock
import requests
import urllib3

Expand Down Expand Up @@ -54,6 +54,8 @@ def __init__(self, cli_ctx, local_addr, local_port, bastion, bastion_endpoint, r
self.node_id = None
self.host_name = None
self.cli_ctx = cli_ctx
self.active_connections = 0
self.connection_lock = Lock()
logger.info('Creating a socket on port: %s', self.local_port)
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
logger.info('Setting socket options')
Expand Down Expand Up @@ -118,8 +120,16 @@ def _listen(self):
self.sock.listen(100)
index = 0
while True:
self.client, _address = self.sock.accept()

client, _address = self.sock.accept()
with self.connection_lock:
self.active_connections += 1
logger.info('Got a connection, starting a new thread')
thread = Thread(target=self._handle_client, args=(client, index))
thread.start()
index += 1

def _handle_client(self, client, index):
try:
auth_token = self._get_auth_token()
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \
self.bastion['sku']['name'] == BastionSku.Developer.name:
Expand All @@ -128,22 +138,27 @@ def _listen(self):
host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}"

verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED
self.ws = create_connection(host,
sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),),
sslopt={'cert_reqs': verify_mode},
enable_multithread=True)
logger.info('Websocket, connected status: %s', self.ws.connected)
index = index + 1
ws = create_connection(host,
sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),),
sslopt={'cert_reqs': verify_mode},
enable_multithread=True)
logger.info('Websocket, connected status: %s', ws.connected)
logger.info('Got debugger connection... index: %s', index)
debugger_thread = Thread(target=self._listen_to_client, args=(self.client, self.ws, index))
web_socket_thread = Thread(target=self._listen_to_web_socket, args=(self.client, self.ws, index))
debugger_thread = Thread(target=self._listen_to_client, args=(client, ws, index))
web_socket_thread = Thread(target=self._listen_to_web_socket, args=(client, ws, index))
debugger_thread.start()
web_socket_thread.start()
logger.info('Both debugger and websocket threads started...')
logger.info('Successfully connected to local server..')
debugger_thread.join()
web_socket_thread.join()
self.cleanup()
except Exception as ex: # pylint: disable=broad-except
logger.info('Exception in handling client: %s', ex)
finally:
with self.connection_lock:
self.active_connections -= 1
if self.active_connections == 0:
self.cleanup()
logger.info('Both debugger and websocket threads stopped...')
logger.info('Stopped local server..')

Expand Down
2 changes: 1 addition & 1 deletion src/bastion/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


# HISTORY.rst entry.
VERSION = '1.0.0'
VERSION = '1.0.1'

# The full list of classifiers is available at
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
Expand Down
Loading