Skip to content

Commit

Permalink
Support for concurrent connections with tunnels
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardsp committed Jun 17, 2024
1 parent d78a1cb commit 8fccede
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions src/bastion/azext_bastion/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging as logs
from contextlib import closing
from datetime import datetime
from threading import Thread
from threading import Thread, Lock
import requests
import urllib3

Expand Down Expand Up @@ -54,6 +54,8 @@ def __init__(self, cli_ctx, local_addr, local_port, bastion, bastion_endpoint, r
self.node_id = None
self.host_name = None
self.cli_ctx = cli_ctx
self.active_connections = 0
self.connection_lock = Lock()
logger.info('Creating a socket on port: %s', self.local_port)
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
logger.info('Setting socket options')
Expand Down Expand Up @@ -118,8 +120,16 @@ def _listen(self):
self.sock.listen(100)
index = 0
while True:
self.client, _address = self.sock.accept()

client, _address = self.sock.accept()
with self.connection_lock:
self.active_connections += 1
logger.info('Got a connection, starting a new thread')
thread = Thread(target=self._handle_client, args=(client, index))
thread.start()
index += 1

def _handle_client(self, client, index):
try:
auth_token = self._get_auth_token()
if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or \
self.bastion['sku']['name'] == BastionSku.Developer.name:
Expand All @@ -128,22 +138,27 @@ def _listen(self):
host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}"

verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED
self.ws = create_connection(host,
sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),),
sslopt={'cert_reqs': verify_mode},
enable_multithread=True)
logger.info('Websocket, connected status: %s', self.ws.connected)
index = index + 1
ws = create_connection(host,
sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),),
sslopt={'cert_reqs': verify_mode},
enable_multithread=True)
logger.info('Websocket, connected status: %s', ws.connected)
logger.info('Got debugger connection... index: %s', index)
debugger_thread = Thread(target=self._listen_to_client, args=(self.client, self.ws, index))
web_socket_thread = Thread(target=self._listen_to_web_socket, args=(self.client, self.ws, index))
debugger_thread = Thread(target=self._listen_to_client, args=(client, ws, index))
web_socket_thread = Thread(target=self._listen_to_web_socket, args=(client, ws, index))
debugger_thread.start()
web_socket_thread.start()
logger.info('Both debugger and websocket threads started...')
logger.info('Successfully connected to local server..')
debugger_thread.join()
web_socket_thread.join()
self.cleanup()
except Exception as ex:
logger.info('Exception in handling client: %s', ex)
finally:
with self.connection_lock:
self.active_connections -= 1
if self.active_connections == 0:
self.cleanup()
logger.info('Both debugger and websocket threads stopped...')
logger.info('Stopped local server..')

Expand Down

0 comments on commit 8fccede

Please sign in to comment.