Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ephemeral services. #1302

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions cms/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import errno
import ipaddress
import json
import logging
import os
import socket
import sys
from collections import namedtuple
from contextlib import closing

from .log import set_detailed_logs

Expand All @@ -44,6 +47,7 @@
service (thus identifying it).

"""

def __repr__(self):
return "%s,%d" % (self.name, self.shard)

Expand All @@ -53,6 +57,75 @@
pass


class EphemeralServiceConfig:
    """Configuration of an ephemeral service.

    An ephemeral service is a normal service whose shard is chosen
    depending on its address and port. The port is assigned inside a
    range and the address must be inside the subnet.

    """
    # Offset added to every computed shard: shards greater than or equal
    # to this value are the ones treated as ephemeral.
    EPHEMERAL_SHARD_OFFSET = 10000

    def __init__(self, subnet, min_port, max_port):
        """Create the configuration of an ephemeral service.

        subnet: the subnet the services live in, in any form accepted
            by ipaddress.ip_network (e.g. "10.0.0.0/24").
        min_port (int): first port of the range (inclusive).
        max_port (int): last port of the range (inclusive).

        raise (ConfigError): if the port range is empty.

        """
        self.subnet = ipaddress.ip_network(subnet)
        self.min_port = min_port
        self.max_port = max_port
        if min_port > max_port:
            raise ConfigError("Invalid port range: [%s, %s]"
                              % (min_port, max_port))

    def get_shard(self, address, port):
        """Get the ephemeral shard for a service given its address and port.

        address (IPv4Address|IPv6Address): address of the service.
        port (int): port of the service.

        return (int): shard of the service

        raise (ValueError): if the address is outside the subnet or the
            port is outside the configured range.

        """
        if address not in self.subnet:
            raise ValueError("The address is not inside the subnet")
        # A port outside the range would alias the shard of another
        # (host, port) pair, so get_address could not map it back.
        if not self.min_port <= port <= self.max_port:
            raise ValueError("The port is not inside the port range")
        # Index of the host inside the subnet (host bits of the address).
        host_id = int(address) & int(self.subnet.hostmask)
        num_ports = self.max_port - self.min_port + 1
        shard = host_id * num_ports + (port - self.min_port)
        return shard + self.EPHEMERAL_SHARD_OFFSET

    def get_address(self, shard):
        """Get the address and port of a service given its shard.

        This is the inverse of get_shard.

        shard (int): shard of the service

        return (Address): address and port of the service

        raise (ValueError): if the shard does not map to an address
            inside the subnet.

        """
        shard -= self.EPHEMERAL_SHARD_OFFSET
        if shard < 0:
            # Shards below the offset never identify an ephemeral service.
            raise ValueError("The shard is not valid")
        num_ports = self.max_port - self.min_port + 1
        host_id, port_offset = divmod(shard, num_ports)

        port = self.min_port + port_offset
        addr = self.subnet.network_address + host_id
        if addr not in self.subnet:
            raise ValueError("The shard is not valid")
        return Address(str(addr), port)

    def find_free_port(self, address):
        """Find the first port of the range that can be bound.

        The probing socket is closed before returning, so the port is
        only known to be free at the time of the check.

        address (IPv4Address|IPv6Address): local address to bind to

        return (int): the first port in [min_port, max_port] on which a
            bind on the given address succeeded.

        raise (ValueError): if no port of the range could be bound.

        """
        if address.version == 4:
            family = socket.AF_INET
        else:
            family = socket.AF_INET6
        for port in range(self.min_port, self.max_port + 1):
            with closing(socket.socket(family, socket.SOCK_STREAM)) as sock:
                try:
                    sock.bind((str(address), port))
                except OSError:
                    # Port busy (or not bindable): try the next one.
                    continue
                return port
        raise ValueError("No free port found in range [%s, %s] "
                         "for address %s"
                         % (self.min_port, self.max_port, address))


class AsyncConfig:
"""This class will contain the configuration for the
services. This needs to be populated at the initialization stage.
Expand All @@ -69,6 +142,7 @@
"""
core_services = {}
other_services = {}
ephemeral_services = {} # type: dict[str, EphemeralServiceConfig]


async_config = AsyncConfig()
Expand All @@ -81,6 +155,7 @@
directory for information on the meaning of the fields.

"""

def __init__(self):
"""Default values for configuration, plus decide if this
instance is running from the system path or from the source
Expand Down Expand Up @@ -274,6 +349,19 @@
self.async_config.other_services[coord] = Address(*shard)
del data["other_services"]

if 'ephemeral_services' in data:
for service_name in data['ephemeral_services']:
if service_name.startswith("_"):
continue

Check warning on line 355 in cms/conf.py

View check run for this annotation

Codecov / codecov/patch

cms/conf.py#L355

Added line #L355 was not covered by tests
service = data["ephemeral_services"][service_name]
self.async_config.ephemeral_services[service_name] = \
EphemeralServiceConfig(
service["subnet"],
service["min_port"],
service["max_port"],
)
del data["ephemeral_services"]

# Put everything else in self.
for key, value in data.items():
setattr(self, key, value)
Expand Down
2 changes: 2 additions & 0 deletions cms/io/web_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(self, listen_port, handlers, parameters, shard=0,
if num_proxies_used > 0:
self.wsgi_app = ProxyFix(self.wsgi_app, num_proxies_used)

logger.info("%s listening on '%s' at port %d",
type(self).__name__, listen_address, listen_port)
self.web_server = WSGIServer((listen_address, listen_port), self)

def __call__(self, environ, start_response):
Expand Down
9 changes: 7 additions & 2 deletions cms/server/contest/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from cms.io import WebService
from cms.locale import get_translations
from cms.server.contest.jinja2_toolbox import CWS_ENVIRONMENT
from cms.util import is_shard_ephemeral
from cmscommon.binary import hex_to_bin
from .handlers import HANDLERS
from .handlers.base import ContestListHandler
Expand Down Expand Up @@ -73,8 +74,12 @@
}

try:
listen_address = config.contest_listen_address[shard]
listen_port = config.contest_listen_port[shard]
if is_shard_ephemeral(shard):
index = 0

Check warning on line 78 in cms/server/contest/server.py

View check run for this annotation

Codecov / codecov/patch

cms/server/contest/server.py#L78

Added line #L78 was not covered by tests
else:
index = shard
listen_address = config.contest_listen_address[index]
listen_port = config.contest_listen_port[index]
except IndexError:
raise ConfigError("Wrong shard number for %s, or missing "
"address/port configuration. Please check "
Expand Down
15 changes: 14 additions & 1 deletion cms/service/EvaluationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def enqueue(self, item, priority, timestamp):
item_entry = item.to_dict()
del item_entry["testcase_codename"]
item_entry["multiplicity"] = 1
entry = {"item": item_entry, "priority": priority, "timestamp": make_timestamp(timestamp)}
entry = {"item": item_entry, "priority": priority,
"timestamp": make_timestamp(timestamp)}
self.queue_status_cumulative[key] = entry
return success

Expand Down Expand Up @@ -197,6 +198,11 @@ def _remove_from_cumulative_status(self, queue_entry):
if self.queue_status_cumulative[key]["item"]["multiplicity"] == 0:
del self.queue_status_cumulative[key]

def add_worker(self, worker_coord):
    """Register a worker in the pool, marking it as ephemeral.

    worker_coord (ServiceCoord): the coordinates of the worker.

    """
    pool = self.pool
    pool.add_worker(worker_coord, ephemeral=True)


def with_post_finish_lock(func):
"""Decorator for locking on self.post_finish_lock.
Expand Down Expand Up @@ -379,6 +385,13 @@ def workers_status(self):
"""
return self.get_executor().pool.get_status()

@rpc_method
def add_worker(self, coord):
    """RPC to add a worker to the ones handled by the executor.

    coord: pair (service name, shard) identifying the worker.

    """
    name, shard_number = coord
    worker_coord = ServiceCoord(name, shard_number)
    self.get_executor().add_worker(worker_coord)

def check_workers_timeout(self):
"""We ask WorkerPool for the unresponsive workers, and we put
again their operations in the queue.
Expand Down
8 changes: 8 additions & 0 deletions cms/service/Worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import gevent.lock

from cms import ServiceCoord
from cms.db import SessionGen, Contest, enumerate_files
from cms.db.filecacher import FileCacher, TombstoneError
from cms.grading import JobException
Expand Down Expand Up @@ -64,6 +65,13 @@ def __init__(self, shard, fake_worker_time=None):

self._fake_worker_time = fake_worker_time

self.evaluation_service = self.connect_to(
ServiceCoord("EvaluationService", 0),
on_connect=self.on_es_connection)

def on_es_connection(self, address):
    """Announce this worker to EvaluationService.

    Used as the on_connect callback of the connection to
    EvaluationService: asks it to add this worker's coordinates.

    address: passed by the connection callback; not used here.

    """
    evaluation_service = self.evaluation_service
    evaluation_service.add_worker(coord=self._my_coord)

@rpc_method
def precache_files(self, contest_id):
"""RPC to ask the worker to precache of files in the contest.
Expand Down
25 changes: 23 additions & 2 deletions cms/service/workerpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,20 @@
"""Wait until a worker might be available."""
self._workers_available_event.wait()

def add_worker(self, worker_coord):
def add_worker(self, worker_coord, ephemeral=False):
"""Add a new worker to the worker pool.

worker_coord (ServiceCoord): the coordinates of the worker.
ephemeral (bool): remove the worker from the pool after the
disconnection.

"""
shard = worker_coord.shard
# Instruct GeventLibrary to connect ES to the Worker.
self._worker[shard] = self._service.connect_to(
worker_coord,
on_connect=self.on_worker_connected)
on_connect=self.on_worker_connected,
on_disconnect=lambda: self.on_worker_disconnected(worker_coord, ephemeral))

# And we fill all data.
self._operations[shard] = WorkerPool.WORKER_INACTIVE
Expand Down Expand Up @@ -183,6 +186,24 @@
# so we wake up the consumers.
self._workers_available_event.set()

def on_worker_disconnected(self, worker_coord, ephemeral):
"""If the worker is ephemeral, disable and the remove the worker
from the pool.
"""
if not ephemeral:
return
shard = worker_coord.shard
if self._operations[shard] != WorkerPool.WORKER_DISABLED:
# disable the worker and re-enqueue the lost operations
lost_operations = self.disable_worker(shard)
for operation in lost_operations:
logger.info("Operation %s put again in the queue because "

Check warning on line 200 in cms/service/workerpool.py

View check run for this annotation

Codecov / codecov/patch

cms/service/workerpool.py#L200

Added line #L200 was not covered by tests
"the worker disconnected.", operation)
priority, timestamp = operation.side_data
self._service.enqueue(operation, priority, timestamp)

Check warning on line 203 in cms/service/workerpool.py

View check run for this annotation

Codecov / codecov/patch

cms/service/workerpool.py#L202-L203

Added lines #L202 - L203 were not covered by tests
del self._worker[shard]
logger.info("Worker %s removed", worker_coord)

def acquire_worker(self, operations):
"""Tries to assign an operation to an available worker. If no workers
are available then this returns None, otherwise this returns
Expand Down
38 changes: 32 additions & 6 deletions cms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import argparse
import itertools
import ipaddress
import logging
import netifaces
import os
Expand All @@ -35,6 +36,7 @@
import gevent.socket

from cms import ServiceCoord, ConfigError, async_config, config
from cms.conf import EphemeralServiceConfig


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -136,8 +138,19 @@
raise (ValueError): if no safe shard can be returned.

"""
addrs = _find_local_addresses()
# Try to assign an ephemeral shard first. This needs to be done before
# autodetecting the shard using the IP, since here we cannot detect if
# the service is already running on that port.
if provided_shard is None and service in config.async_config.ephemeral_services:
ephemeral_config = config.async_config.ephemeral_services[service]
for addr in addrs:
addr = ipaddress.ip_address(addr[1])
if addr in ephemeral_config.subnet:
port = ephemeral_config.find_free_port(addr)
shard = ephemeral_config.get_shard(addr, port)
return shard

Check warning on line 152 in cms/util.py

View check run for this annotation

Codecov / codecov/patch

cms/util.py#L146-L152

Added lines #L146 - L152 were not covered by tests
if provided_shard is None:
addrs = _find_local_addresses()
computed_shard = _get_shard_from_addresses(service, addrs)
if computed_shard is None:
logger.critical("Couldn't autodetect shard number and "
Expand All @@ -157,17 +170,30 @@
return provided_shard


def is_shard_ephemeral(shard):
    """Tell whether a shard number falls in the ephemeral range.

    shard (int): the shard to check.

    return (bool): True if the shard is not lower than
        EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET.

    """
    first_ephemeral_shard = EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET
    return shard >= first_ephemeral_shard


def get_service_address(key):
    """Give the Address of a ServiceCoord.

    Services listed explicitly in the configuration (core and other
    services) are looked up first; otherwise, if the service name has
    an ephemeral configuration, the address is computed from the
    shard.

    key (ServiceCoord): the service needed.
    returns (Address): listening address of key.

    raise (KeyError): if the service is not configured, or if its
        shard is not a valid ephemeral shard for that service.

    """
    service, shard = key
    if key in async_config.core_services:
        return async_config.core_services[key]
    if key in async_config.other_services:
        return async_config.other_services[key]
    if service in async_config.ephemeral_services:
        try:
            return async_config.ephemeral_services[service].get_address(shard)
        except ValueError as error:
            # get_address signals an invalid shard with ValueError;
            # report it like any other unknown service, so that
            # callers only have to handle KeyError.
            raise KeyError("Service not found.") from error
    raise KeyError("Service not found.")

Expand All @@ -179,11 +205,11 @@
returns (int): the number of shards defined in the configuration.

"""
for i in itertools.count():
try:
get_service_address(ServiceCoord(service, i))
except KeyError:
return i
count = 0
for services in (async_config.core_services, async_config.other_services):
count += len([0 for s in services if s.name == service])

return count


def default_argument_parser(description, cls, ask_contest=None):
Expand Down
Loading