Skip to content

Commit

Permalink
Add port-range support for local kernels
Browse files Browse the repository at this point in the history
Since local kernels (by default) use the LocalProcessProxy class, we can
add port-range support for them as well by refactoring the current logic
such that its made available to all process proxies and then taking ownership
of port assignment prior to writing out the connection file.
  • Loading branch information
kevin-bates authored and lresende committed Mar 29, 2018
1 parent 5e42b47 commit 3ddf0e3
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 65 deletions.
7 changes: 7 additions & 0 deletions enterprise_gateway/services/kernels/remotemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,11 @@ def write_connection_file(self):
# If this is a remote kernel that's using a response address, we should skip the write_connection_file
# since it will create 5 useless ports that would not adhere to port-range restrictions if configured.
if not isinstance(self.process_proxy, RemoteProcessProxy) or not self.response_address:
# However, since we *may* want to limit the selected ports, go ahead and get the ports using
# the process proxy (will be LocalPropcessProxy for default case) since the port selection will
# handle the default case when the member ports aren't set anyway.
ports = self.process_proxy.select_ports(5)
self.shell_port=ports[0]; self.iopub_port=ports[1]; self.stdin_port=ports[2]; \
self.hb_port=ports[3]; self.control_port=ports[4]

return super(RemoteKernelManager, self).write_connection_file()
130 changes: 65 additions & 65 deletions enterprise_gateway/services/processproxies/processproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def __init__(self, kernel_manager, proxy_config):
self.kernel_id = os.path.basename(self.kernel_manager.connection_file). \
replace('kernel-', '').replace('.json', '')
self.kernel_launch_timeout = default_kernel_launch_timeout
self.lower_port = 0
self.upper_port = 0
self._validate_port_range(proxy_config)

# Handle authorization sets...
# Take union of unauthorized users...
Expand Down Expand Up @@ -356,6 +359,65 @@ def load_process_info(self, process_info):
self.pgid = process_info['pgid']
self.ip = process_info['ip']

def _validate_port_range(self, proxy_config):
# Let port_range override global value - if set on kernelspec...
port_range = self.kernel_manager.parent.parent.port_range
if proxy_config.get('port_range'):
port_range = proxy_config.get('port_range')

try:
port_ranges = port_range.split("..")
self.lower_port = int(port_ranges[0])
self.upper_port = int(port_ranges[1])

port_range_size = self.upper_port - self.lower_port
if port_range_size != 0:
if port_range_size < min_port_range_size:
raise RuntimeError(
"Port range validation failed for range: '{}'. Range size must be at least {} as specified by"
" env EG_MIN_PORT_RANGE_SIZE".format(port_range, min_port_range_size))
except ValueError as ve:
raise RuntimeError("Port range validation failed for range: '{}'. Error was: {}".format(port_range, ve))
except IndexError as ie:
raise RuntimeError("Port range validation failed for range: '{}'. Error was: {}".format(port_range, ie))

self.kernel_manager.port_range = port_range

def select_ports(self, count):
"""Select and return n random ports that are available and adhere to the given port range, if applicable."""
ports = []
sockets = []
for i in range(count):
sock = self.select_socket()
ports.append(sock.getsockname()[1])
sockets.append(sock)
for sock in sockets:
sock.close()
return ports

def select_socket(self, ip=''):
"""Create and return a socket whose port is available and adheres to the given port range, if applicable."""
sock = socket(AF_INET, SOCK_STREAM)
found_port = False
retries = 0
while not found_port:
try:
sock.bind((ip, self._get_candidate_port()))
found_port = True
except Exception as e:
retries = retries + 1
if retries > max_port_range_retries:
raise RuntimeError(
"Failed to locate port within range {} after {} retries!".
format(self.kernel_manager.port_range, max_port_range_retries))
return sock

def _get_candidate_port(self):
range_size = self.upper_port - self.lower_port
if range_size == 0:
return 0
return random.randint(self.lower_port, self.upper_port)


class LocalProcessProxy(BaseProcessProxyABC):

Expand Down Expand Up @@ -391,10 +453,7 @@ def __init__(self, kernel_manager, proxy_config):
self.comm_ip = None
self.comm_port = 0
self.dest_comm_port = 0
self.lower_port = 0
self.upper_port = 0
self.tunnel_processes = {}
self._validate_port_range(proxy_config)
self._prepare_response_socket()

def launch_process(self, kernel_cmd, **kw):
Expand All @@ -416,32 +475,8 @@ def handle_timeout(self):
def confirm_remote_startup(self, kernel_cmd, **kw):
pass

def _validate_port_range(self, proxy_config):
# Let port_range override global value - if set on kernelspec...
port_range = self.kernel_manager.parent.parent.port_range
if proxy_config.get('port_range'):
port_range = proxy_config.get('port_range')

try:
port_ranges = port_range.split("..")
self.lower_port = int(port_ranges[0])
self.upper_port = int(port_ranges[1])

port_range_size = self.upper_port - self.lower_port
if port_range_size != 0:
if port_range_size < min_port_range_size:
raise RuntimeError(
"Port range validation failed for range: '{}'. Range size must be at least {} as specified by"
" env EG_MIN_PORT_RANGE_SIZE".format(port_range, min_port_range_size))
except ValueError as ve:
raise RuntimeError("Port range validation failed for range: '{}'. Error was: {}".format(port_range, ve))
except IndexError as ie:
raise RuntimeError("Port range validation failed for range: '{}'. Error was: {}".format(port_range, ie))

self.kernel_manager.port_range = port_range

def _prepare_response_socket(self):
s = self._select_socket(local_ip)
s = self.select_socket(local_ip)
port = s.getsockname()[1]
self.log.debug("Response socket bound to port: {} using {}s timeout".format(port, socket_timeout))
s.listen(1)
Expand All @@ -457,7 +492,7 @@ def _tunnel_to_kernel(self, connection_info, server, port=ssh_port, key=None):
"""
cf = connection_info

lports = self._select_ports(5)
lports = self.select_ports(5)

rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port'], cf['control_port']

Expand All @@ -479,7 +514,7 @@ def _tunnel_to_port(self, kernel_channel, remote_ip, remote_port, server, port=s
any one-off ports that require tunnelling. Note - this method assumes that passwordless ssh is
in use and has been previously validated.
"""
local_port = self._select_ports(1)[0]
local_port = self.select_ports(1)[0]
self._create_ssh_tunnel(kernel_channel, local_port, remote_port, remote_ip, server, port, key)
return local_port

Expand Down Expand Up @@ -538,41 +573,6 @@ def _get_keep_alive_interval(self, kernel_channel):
# interval for the rest of the kernel channels.
return cull_idle_timeout + 60

def _select_ports(self, count):
"""Select and return n random ports that are available and adhere to the given port range, if applicable."""
ports = []
sockets = []
for i in range(count):
sock = self._select_socket()
ports.append(sock.getsockname()[1])
sockets.append(sock)
for sock in sockets:
sock.close()
return ports

def _select_socket(self, ip=''):
"""Create and return a socket whose port is available and adheres to the given port range, if applicable."""
sock = socket(AF_INET, SOCK_STREAM)
found_port = False
retries = 0
while not found_port:
try:
sock.bind((ip, self._get_candidate_port()))
found_port = True
except Exception as e:
retries = retries + 1
if retries > max_port_range_retries:
raise RuntimeError(
"Failed to locate port within range {} after {} retries!".
format(self.kernel_manager.port_range, max_port_range_retries))
return sock

def _get_candidate_port(self):
range_size = self.upper_port - self.lower_port
if range_size == 0:
return 0
return random.randint(self.lower_port, self.upper_port)

def _decrypt(self, data):
decryptAES = lambda c, e: c.decrypt(base64.b64decode(e))
key = self.kernel_id[0:16]
Expand Down

0 comments on commit 3ddf0e3

Please sign in to comment.