Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] HTTPS Support #3380

Merged
merged 27 commits into from
Jan 20, 2025
Merged
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
56f2c04
init
cblmemo Jun 5, 2024
229231e
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Sep 5, 2024
292fc73
disable update ssl
cblmemo Sep 5, 2024
5dd1c0b
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Sep 16, 2024
d8336d4
address comment
cblmemo Sep 16, 2024
dcf6d36
fix
cblmemo Sep 16, 2024
8c01625
fix linter
cblmemo Sep 16, 2024
e2b6af8
remove key & cert file; add smoke test prototype
cblmemo Sep 17, 2024
cca6c02
add byo ssl key
cblmemo Sep 17, 2024
f8686ed
format
cblmemo Sep 17, 2024
f0613bd
comments
cblmemo Sep 17, 2024
80ccacb
smoke test passed
cblmemo Sep 17, 2024
4243a73
check keyfile and certfile with schema; check if the file exists
cblmemo Sep 17, 2024
112e06c
nit
cblmemo Sep 17, 2024
385a658
add column TLS_ENCRYPTED in service table; add schema in sky serve st…
cblmemo Sep 17, 2024
95e3ea9
fix smoke test
cblmemo Sep 17, 2024
ad6768d
fix smoke test
cblmemo Sep 18, 2024
bfd1741
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Sep 19, 2024
f239e68
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Oct 24, 2024
4262353
apply suggestions from code review
cblmemo Oct 24, 2024
846e072
use env vars in the test & example
cblmemo Oct 24, 2024
ba2d038
rename test file
cblmemo Oct 24, 2024
30ef361
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Oct 28, 2024
37e8242
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Jan 14, 2025
c57dd77
Merge remote-tracking branch 'origin/master' into serve-https-example…
cblmemo Jan 17, 2025
1c7c76b
print output for https test
cblmemo Jan 19, 2025
735ce29
upd messages
cblmemo Jan 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address comment
cblmemo committed Sep 16, 2024
commit d8336d49749b02adc0fb35d7f3838d241534fcaf
2 changes: 1 addition & 1 deletion examples/serve/https/service.yaml
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
service:
readiness_probe: /
replicas: 1
ssl:
tls:
# Generated by running:
# openssl req -x509 -newkey rsa:2048 -days 36500 -nodes
keyfile: examples/serve/https/key.pem
49 changes: 37 additions & 12 deletions sky/serve/core.py
Original file line number Diff line number Diff line change
@@ -90,6 +90,37 @@ def _validate_service_task(task: 'sky.Task') -> None:
'Please specify the same port instead.')


def _rewrite_tls_credential_paths(service_name: str,
                                  task: 'sky.Task') -> Dict[str, Any]:
    """Swap the task's local TLS credential paths for remote ones.

    When the service spec carries a TLS credential, this function mutates
    the task in place: ``task.service.tls_credential`` is replaced by a new
    ``TLSCredential`` holding the remote keyfile/certfile paths generated
    for ``service_name``. The original (local) paths are preserved only in
    the returned mapping.

    Args:
        service_name: Name of the service.
        task: sky.Task whose service spec is rewritten.

    Returns:
        The generated template variables for TLS (remote and local
        keyfile/certfile paths), or an empty dict if the service has no
        TLS credential configured.
    """
    spec = task.service
    # _validate_service_task has already ensured the service section exists.
    assert spec is not None
    local_credential = spec.tls_credential
    if local_credential is None:
        return {}

    remote_keyfile = (
        serve_utils.generate_remote_tls_keyfile_name(service_name))
    remote_certfile = (
        serve_utils.generate_remote_tls_certfile_name(service_name))
    # Capture the local paths before they are overwritten below.
    template_vars = {
        'remote_tls_keyfile': remote_keyfile,
        'remote_tls_certfile': remote_certfile,
        'local_tls_keyfile': local_credential.keyfile,
        'local_tls_certfile': local_credential.certfile,
    }
    # From here on the spec refers to the remote copies of the files.
    spec.tls_credential = serve_utils.TLSCredential(remote_keyfile,
                                                    remote_certfile)
    return template_vars


@usage_lib.entrypoint
def up(
task: 'sky.Task',
@@ -127,6 +158,8 @@ def up(
controller_utils.maybe_translate_local_file_mounts_and_sync_up(task,
path='serve')

tls_template_vars = _rewrite_tls_credential_paths(service_name, task)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rename the function? It does not seem to actually rewrite anything.

Suggested change
tls_template_vars = _rewrite_tls_credential_paths(service_name, task)
tls_template_vars = _get_tls_template_vars(service_name, task)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we are rewriting the task.tls_credential in the function. Renamed to _rewrite_tls_credential_paths_and_get_tls_env_vars now.


with tempfile.NamedTemporaryFile(
prefix=f'service-task-{service_name}-',
mode='w',
@@ -146,26 +179,19 @@ def up(
controller_resources = controller_utils.get_controller_resources(
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
task_resources=task.resources)
remote_ssl_keyfile = (
serve_utils.generate_remote_ssl_keyfile_name(service_name))
remote_ssl_certfile = (
serve_utils.generate_remote_ssl_certfile_name(service_name))

service_spec = task.service
# Already validated in _validate_service_task
assert service_spec is not None, 'Service section not found.'
vars_to_fill = {
'remote_task_yaml_path': remote_tmp_task_yaml_path,
'local_task_yaml_path': service_file.name,
'remote_ssl_keyfile': remote_ssl_keyfile,
'remote_ssl_certfile': remote_ssl_certfile,
'local_ssl_keyfile': service_spec.ssl_keyfile,
'local_ssl_certfile': service_spec.ssl_certfile,
'service_name': service_name,
'controller_log_file': controller_log_file,
'remote_user_config_path': remote_config_yaml_path,
'modified_catalogs':
service_catalog_common.get_modified_catalog_file_mounts(),
**tls_template_vars,
**controller_utils.shared_controller_vars_to_fill(
controller=controller_utils.Controllers.SKY_SERVE_CONTROLLER,
remote_user_config_path=remote_config_yaml_path,
@@ -326,11 +352,10 @@ def update(
_validate_service_task(task)

assert task.service is not None
if (task.service.ssl_keyfile is not None or
task.service.ssl_certfile is not None):
logger.warning('Updating SSL keyfile and certfile is not supported. '
if task.service.tls_credential is not None:
logger.warning('Updating TLS keyfile and certfile is not supported. '
'Any updates to the keyfile and certfile will not take '
'effect. To update SSL keyfile and certfile, please '
'effect. To update TLS keyfile and certfile, please '
'tear down the service and spin up a new one.')

handle = backend_utils.is_controller_accessible(
26 changes: 11 additions & 15 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
@@ -28,8 +28,7 @@ class SkyServeLoadBalancer:
"""

def __init__(self, controller_url: str, load_balancer_port: int,
ssl_keyfile: Optional[str],
ssl_certfile: Optional[str]) -> None:
tls_credential: Optional[serve_utils.TLSCredential]) -> None:
"""Initialize the load balancer.

Args:
@@ -43,8 +42,8 @@ def __init__(self, controller_url: str, load_balancer_port: int,
lb_policies.RoundRobinPolicy())
self._request_aggregator: serve_utils.RequestsAggregator = (
serve_utils.RequestTimestamp())
self._ssl_keyfile: Optional[str] = ssl_keyfile
self._ssl_certfile: Optional[str] = ssl_certfile
self._tls_credential: Optional[serve_utils.TLSCredential] = (
tls_credential)
# TODO(tian): httpx.Client has a resource limit of 100 max connections
# for each client. We should wait for feedback on the best max
# connections.
@@ -221,12 +220,10 @@ async def startup():
# Register controller synchronization task
asyncio.create_task(self._sync_with_controller())

ssl_kwargs = {} if self._ssl_keyfile is None else {
'ssl_keyfile': self._ssl_keyfile,
'ssl_certfile': self._ssl_certfile,
}
ssl_kwargs = ({} if self._tls_credential is None else
self._tls_credential.dump_to_uvicorn_arguments())

schema = 'https' if self._ssl_keyfile is not None else 'http'
schema = 'https' if self._tls_credential is not None else 'http'

logger.info('SkyServe Load Balancer started on '
f'{schema}://0.0.0.0:{self._load_balancer_port}')
@@ -237,14 +234,13 @@ async def startup():
**ssl_kwargs)


def run_load_balancer(controller_addr: str,
load_balancer_port: int,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None):
def run_load_balancer(
controller_addr: str,
load_balancer_port: int,
tls_credential: Optional[serve_utils.TLSCredential] = None):
load_balancer = SkyServeLoadBalancer(controller_url=controller_addr,
load_balancer_port=load_balancer_port,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile)
tls_credential=tls_credential)
load_balancer.run()


30 changes: 26 additions & 4 deletions sky/serve/serve_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""User interface with the SkyServe."""
import base64
import collections
import dataclasses
import enum
import os
import pathlib
@@ -86,6 +87,27 @@ class UpdateMode(enum.Enum):
BLUE_GREEN = 'blue_green'


@dataclasses.dataclass
class TLSCredential:
    """A keyfile/certfile pair describing a service's TLS credential.

    The two paths must be given together: setting exactly one of them
    raises a ValueError at construction time. Leaving both as None is
    accepted.
    """
    keyfile: Optional[str]
    certfile: Optional[str]

    def __post_init__(self) -> None:
        """Reject a credential where only one of the two files is set."""
        has_keyfile = self.keyfile is not None
        has_certfile = self.certfile is not None
        if has_keyfile == has_certfile:
            # Either both are set or both are missing; nothing to report.
            return
        if has_keyfile:
            message = 'TLS certfile is required if keyfile is set.'
        else:
            message = 'TLS keyfile is required if certfile is set.'
        with ux_utils.print_exception_no_traceback():
            raise ValueError(message)

    def dump_to_uvicorn_arguments(self) -> Dict[str, Any]:
        """Return the credential as ``ssl_keyfile``/``ssl_certfile`` kwargs."""
        return dict(ssl_keyfile=self.keyfile, ssl_certfile=self.certfile)


DEFAULT_UPDATE_MODE = UpdateMode.ROLLING

_SIGNAL_TO_ERROR = {
@@ -237,16 +259,16 @@ def generate_replica_log_file_name(service_name: str, replica_id: int) -> str:
return os.path.join(dir_name, f'replica_{replica_id}.log')


def generate_remote_ssl_keyfile_name(service_name: str) -> str:
def generate_remote_tls_keyfile_name(service_name: str) -> str:
dir_name = generate_remote_service_dir_name(service_name)
# Don't expand here since it is used for remote machine.
return os.path.join(dir_name, 'ssl_keyfile')
return os.path.join(dir_name, 'tls_keyfile')


def generate_remote_ssl_certfile_name(service_name: str) -> str:
def generate_remote_tls_certfile_name(service_name: str) -> str:
dir_name = generate_remote_service_dir_name(service_name)
# Don't expand here since it is used for remote machine.
return os.path.join(dir_name, 'ssl_certfile')
return os.path.join(dir_name, 'tls_certfile')


def generate_replica_cluster_name(service_name: str, replica_id: int) -> str:
18 changes: 5 additions & 13 deletions sky/serve/service.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import shutil
import time
import traceback
from typing import Dict, List, Optional
from typing import Dict, List

import filelock

@@ -128,8 +128,7 @@ def _cleanup(service_name: str) -> bool:
return failed


def _start(service_name: str, tmp_task_yaml: str, job_id: int,
ssl_keyfile: Optional[str], ssl_certfile: Optional[str]):
def _start(service_name: str, tmp_task_yaml: str, job_id: int):
"""Starts the service."""
# Generate ssh key pair to avoid race condition when multiple sky.launch
# are executed at the same time.
@@ -225,8 +224,8 @@ def _get_host():
target=ux_utils.RedirectOutputForProcess(
load_balancer.run_load_balancer,
load_balancer_log_file).run,
args=(controller_addr, load_balancer_port, ssl_keyfile,
ssl_certfile))
args=(controller_addr, load_balancer_port,
service_spec.tls_credential))
load_balancer_process.start()
serve_state.set_service_load_balancer_port(service_name,
load_balancer_port)
@@ -275,15 +274,8 @@ def _get_host():
required=True,
type=int,
help='Job id for the service job.')
parser.add_argument('--ssl-keyfile',
type=str,
help='Path to the SSL key file')
parser.add_argument('--ssl-certfile',
type=str,
help='Path to the SSL certificate file')
args = parser.parse_args()
# We start process with 'spawn', because 'fork' could result in weird
# behaviors; 'spawn' is also cross-platform.
multiprocessing.set_start_method('spawn', force=True)
_start(args.service_name, args.task_yaml, args.job_id, args.ssl_keyfile,
args.ssl_certfile)
_start(args.service_name, args.task_yaml, args.job_id)
52 changes: 25 additions & 27 deletions sky/serve/service_spec.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import yaml

from sky.serve import constants
from sky.serve import serve_utils
from sky.utils import common_utils
from sky.utils import schemas
from sky.utils import ux_utils
@@ -24,8 +25,7 @@ def __init__(
max_replicas: Optional[int] = None,
target_qps_per_replica: Optional[float] = None,
post_data: Optional[Dict[str, Any]] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
tls_credential: Optional[serve_utils.TLSCredential] = None,
readiness_headers: Optional[Dict[str, str]] = None,
dynamic_ondemand_fallback: Optional[bool] = None,
base_ondemand_fallback_replicas: Optional[int] = None,
@@ -79,22 +79,15 @@ def __init__(
'Currently, SkyServe will cleanup failed replicas'
'and auto restart it to keep the service running.')

if ssl_keyfile is not None and ssl_certfile is None:
with ux_utils.print_exception_no_traceback():
raise ValueError('SSL certfile is required if keyfile is set.')
if ssl_certfile is not None and ssl_keyfile is None:
with ux_utils.print_exception_no_traceback():
raise ValueError('SSL keyfile is required if certfile is set.')

self._readiness_path: str = readiness_path
self._initial_delay_seconds: int = initial_delay_seconds
self._readiness_timeout_seconds: int = readiness_timeout_seconds
self._min_replicas: int = min_replicas
self._max_replicas: Optional[int] = max_replicas
self._target_qps_per_replica: Optional[float] = target_qps_per_replica
self._post_data: Optional[Dict[str, Any]] = post_data
self._ssl_keyfile: Optional[str] = ssl_keyfile
self._ssl_certfile: Optional[str] = ssl_certfile
self._tls_credential: Optional[serve_utils.TLSCredential] = (
tls_credential)
self._readiness_headers: Optional[Dict[str, str]] = readiness_headers
self._dynamic_ondemand_fallback: Optional[
bool] = dynamic_ondemand_fallback
@@ -189,10 +182,12 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec':
service_config['dynamic_ondemand_fallback'] = policy_section.get(
'dynamic_ondemand_fallback', None)

ssl_section = config.get('ssl', None)
if ssl_section is not None:
service_config['ssl_keyfile'] = ssl_section.get('keyfile', None)
service_config['ssl_certfile'] = ssl_section.get('certfile', None)
tls_section = config.get('tls', None)
if tls_section is not None:
service_config['tls_credential'] = serve_utils.TLSCredential(
keyfile=tls_section.get('keyfile', None),
certfile=tls_section.get('certfile', None),
)

return SkyServiceSpec(**service_config)

@@ -249,8 +244,9 @@ def add_if_not_none(section, key, value, no_empty: bool = False):
self.upscale_delay_seconds)
add_if_not_none('replica_policy', 'downscale_delay_seconds',
self.downscale_delay_seconds)
add_if_not_none('ssl', 'keyfile', self.ssl_keyfile)
add_if_not_none('ssl', 'certfile', self.ssl_certfile)
if self.tls_credential is not None:
add_if_not_none('tls', 'keyfile', self.tls_credential.keyfile)
add_if_not_none('tls', 'certfile', self.tls_credential.certfile)
return config

def probe_str(self):
@@ -295,18 +291,19 @@ def autoscaling_policy_str(self):
f'replica{max_plural} (target QPS per replica: '
f'{self.target_qps_per_replica})')

def ssl_str(self):
if self.ssl_keyfile is None and self.ssl_certfile is None:
return 'No SSL Enabled'
return f'Keyfile: {self.ssl_keyfile}, Certfile: {self.ssl_certfile}'
def tls_str(self):
if self.tls_credential is None:
return 'No TLS Enabled'
return (f'Keyfile: {self.tls_credential.keyfile}, '
f'Certfile: {self.tls_credential.certfile}')

def __repr__(self) -> str:
return textwrap.dedent(f"""\
Readiness probe method: {self.probe_str()}
Readiness initial delay seconds: {self.initial_delay_seconds}
Readiness probe timeout seconds: {self.readiness_timeout_seconds}
Replica autoscaling policy: {self.autoscaling_policy_str()}
SSL Certificates: {self.ssl_str()}
TLS Certificates: {self.tls_str()}
Spot Policy: {self.spot_policy_str()}
""")

@@ -340,12 +337,13 @@ def post_data(self) -> Optional[Dict[str, Any]]:
return self._post_data

@property
def ssl_keyfile(self) -> Optional[str]:
return self._ssl_keyfile
def tls_credential(self) -> Optional[serve_utils.TLSCredential]:
return self._tls_credential

@property
def ssl_certfile(self) -> Optional[str]:
return self._ssl_certfile
@tls_credential.setter
def tls_credential(self,
value: Optional[serve_utils.TLSCredential]) -> None:
self._tls_credential = value

@property
def readiness_headers(self) -> Optional[Dict[str, str]]:
10 changes: 3 additions & 7 deletions sky/templates/sky-serve-controller.yaml.j2
Original file line number Diff line number Diff line change
@@ -27,9 +27,9 @@ file_mounts:
{%- for remote_catalog_path, local_catalog_path in modified_catalogs.items() %}
{{remote_catalog_path}}: {{local_catalog_path}}
{%- endfor %}
{%- if local_ssl_keyfile is not none %}
{{remote_ssl_keyfile}}: {{local_ssl_keyfile}}
{{remote_ssl_certfile}}: {{local_ssl_certfile}}
{%- if local_tls_keyfile is not none %}
{{remote_tls_keyfile}}: {{local_tls_keyfile}}
{{remote_tls_certfile}}: {{local_tls_certfile}}
{%- endif %}

run: |
@@ -41,10 +41,6 @@ run: |
--service-name {{service_name}} \
--task-yaml {{remote_task_yaml_path}} \
--job-id $SKYPILOT_INTERNAL_JOB_ID \
{%- if local_ssl_keyfile is not none %}
--ssl-keyfile {{remote_ssl_keyfile}} \
--ssl-certfile {{remote_ssl_certfile}} \
{%- endif %}
>> {{controller_log_file}} 2>&1

envs:
2 changes: 1 addition & 1 deletion sky/utils/schemas.py
Original file line number Diff line number Diff line change
@@ -375,7 +375,7 @@ def get_service_schema():
'replicas': {
'type': 'integer',
},
'ssl': {
'tls': {
'type': 'object',
'additionalProperties': False,
'properties': {