Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 285 additions & 0 deletions python/ray/serve/_private/broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
# This module provides broker clients for querying queue lengths from message brokers.
# Adapted from Flower's broker.py (https://github.com/mher/flower/blob/master/flower/utils/broker.py)
# with the following modification:
# - Added close() method to BrokerBase and RedisBase for resource cleanup

import json
import logging
import numbers
import socket
from urllib.parse import quote, unquote, urljoin, urlparse

from tornado import httpclient, ioloop

from ray.serve._private.constants import SERVE_LOGGER_NAME

try:
import redis
except ImportError:
redis = None


logger = logging.getLogger(SERVE_LOGGER_NAME)


class BrokerBase:
def __init__(self, broker_url, *_, **__):
purl = urlparse(broker_url)
self.host = purl.hostname
self.port = purl.port
self.vhost = purl.path[1:]

username = purl.username
password = purl.password

self.username = unquote(username) if username else username
self.password = unquote(password) if password else password

async def queues(self, names):
raise NotImplementedError

def close(self):
"""Close any open connections. Override in subclasses as needed."""
pass


class RabbitMQ(BrokerBase):
def __init__(self, broker_url, http_api, io_loop=None, **__):
super().__init__(broker_url)
self.io_loop = io_loop or ioloop.IOLoop.instance()

self.host = self.host or "localhost"
self.port = self.port or 15672
self.vhost = quote(self.vhost, "") or "/" if self.vhost != "/" else self.vhost
self.username = self.username or "guest"
self.password = self.password or "guest"

if not http_api:
http_api = f"http://{self.username}:{self.password}@{self.host}:{self.port}/api/{self.vhost}"

try:
self.validate_http_api(http_api)
except ValueError:
logger.error("Invalid broker api url: %s", http_api)

self.http_api = http_api

async def queues(self, names):
url = urljoin(self.http_api, "queues/" + self.vhost)
api_url = urlparse(self.http_api)
username = unquote(api_url.username or "") or self.username
password = unquote(api_url.password or "") or self.password

http_client = httpclient.AsyncHTTPClient()
try:
response = await http_client.fetch(
url,
auth_username=username,
auth_password=password,
connect_timeout=1.0,
request_timeout=2.0,
validate_cert=False,
)
except (socket.error, httpclient.HTTPError) as e:
logger.error("RabbitMQ management API call failed: %s", e)
return []
finally:
http_client.close()

if response.code == 200:
info = json.loads(response.body.decode())
return [x for x in info if x["name"] in names]
response.rethrow()

@classmethod
def validate_http_api(cls, http_api):
url = urlparse(http_api)
if url.scheme not in ("http", "https"):
raise ValueError(f"Invalid http api schema: {url.scheme}")


class RedisBase(BrokerBase):
DEFAULT_SEP = "\x06\x16"
DEFAULT_PRIORITY_STEPS = [0, 3, 6, 9]

def __init__(self, broker_url, *_, **kwargs):
super().__init__(broker_url)
self.redis = None

if not redis:
raise ImportError("redis library is required")

broker_options = kwargs.get("broker_options", {})
self.priority_steps = broker_options.get(
"priority_steps", self.DEFAULT_PRIORITY_STEPS
)
self.sep = broker_options.get("sep", self.DEFAULT_SEP)
self.broker_prefix = broker_options.get("global_keyprefix", "")

def _q_for_pri(self, queue, pri):
if pri not in self.priority_steps:
raise ValueError("Priority not in priority steps")
# pylint: disable=consider-using-f-string
return "{0}{1}{2}".format(*((queue, self.sep, pri) if pri else (queue, "", "")))

async def queues(self, names):
queue_stats = []
for name in names:
priority_names = [
self.broker_prefix + self._q_for_pri(name, pri)
for pri in self.priority_steps
]
queue_stats.append(
{
"name": name,
"messages": sum((self.redis.llen(x) for x in priority_names)),
}
)
return queue_stats

def close(self):
"""Close the Redis connection."""
if self.redis is not None:
self.redis.close()
self.redis = None


class Redis(RedisBase):
def __init__(self, broker_url, *args, **kwargs):
super().__init__(broker_url, *args, **kwargs)
self.host = self.host or "localhost"
self.port = self.port or 6379
self.vhost = self._prepare_virtual_host(self.vhost)
self.redis = self._get_redis_client()

def _prepare_virtual_host(self, vhost):
if not isinstance(vhost, numbers.Integral):
if not vhost or vhost == "/":
vhost = 0
elif vhost.startswith("/"):
vhost = vhost[1:]
try:
vhost = int(vhost)
except ValueError as exc:
raise ValueError(
f"Database is int between 0 and limit - 1, not {vhost}"
) from exc
return vhost

def _get_redis_client_args(self):
return {
"host": self.host,
"port": self.port,
"db": self.vhost,
"username": self.username,
"password": self.password,
}

def _get_redis_client(self):
return redis.Redis(**self._get_redis_client_args())


class RedisSentinel(RedisBase):
def __init__(self, broker_url, *args, **kwargs):
super().__init__(broker_url, *args, **kwargs)
broker_options = kwargs.get("broker_options", {})
broker_use_ssl = kwargs.get("broker_use_ssl", None)
self.host = self.host or "localhost"
self.port = self.port or 26379
self.vhost = self._prepare_virtual_host(self.vhost)
self.master_name = self._prepare_master_name(broker_options)
self.redis = self._get_redis_client(broker_options, broker_use_ssl)

def _prepare_virtual_host(self, vhost):
if not isinstance(vhost, numbers.Integral):
if not vhost or vhost == "/":
vhost = 0
elif vhost.startswith("/"):
vhost = vhost[1:]
try:
vhost = int(vhost)
except ValueError as exc:
raise ValueError(
f"Database is int between 0 and limit - 1, not {vhost}"
) from exc
return vhost

def _prepare_master_name(self, broker_options):
try:
master_name = broker_options["master_name"]
except KeyError as exc:
raise ValueError("master_name is required for Sentinel broker") from exc
return master_name

def _get_redis_client(self, broker_options, broker_use_ssl):
connection_kwargs = {
"password": self.password,
"sentinel_kwargs": broker_options.get("sentinel_kwargs"),
}
if isinstance(broker_use_ssl, dict):
connection_kwargs["ssl"] = True
connection_kwargs.update(broker_use_ssl)
# get all sentinel hosts from Celery App config and use them to initialize Sentinel
sentinel = redis.sentinel.Sentinel(
[(self.host, self.port)], **connection_kwargs
)
redis_client = sentinel.master_for(self.master_name)
return redis_client


class RedisSocket(RedisBase):
def __init__(self, broker_url, *args, **kwargs):
super().__init__(broker_url, *args, **kwargs)
self.redis = redis.Redis(
unix_socket_path="/" + self.vhost, password=self.password
)


class RedisSsl(Redis):
"""
Redis SSL class offering connection to the broker over SSL.
This does not currently support SSL settings through the url, only through
the broker_use_ssl celery configuration.
"""

def __init__(self, broker_url, *args, **kwargs):
if "broker_use_ssl" not in kwargs:
raise ValueError("rediss broker requires broker_use_ssl")
self.broker_use_ssl = kwargs.get("broker_use_ssl", {})
super().__init__(broker_url, *args, **kwargs)

def _get_redis_client_args(self):
client_args = super()._get_redis_client_args()
client_args["ssl"] = True
if isinstance(self.broker_use_ssl, dict):
client_args.update(self.broker_use_ssl)
return client_args


class Broker:
"""Factory returning the appropriate broker client based on URL scheme.

Supported schemes:
``amqp`` or ``amqps`` -> :class:`RabbitMQ`
``redis`` -> :class:`Redis`
``rediss`` -> :class:`RedisSsl`
``redis+socket`` -> :class:`RedisSocket`
``sentinel`` -> :class:`RedisSentinel`
"""

def __new__(cls, broker_url, *args, **kwargs):
scheme = urlparse(broker_url).scheme
if scheme in ("amqp", "amqps"):
return RabbitMQ(broker_url, *args, **kwargs)
if scheme == "redis":
return Redis(broker_url, *args, **kwargs)
if scheme == "rediss":
return RedisSsl(broker_url, *args, **kwargs)
if scheme == "redis+socket":
return RedisSocket(broker_url, *args, **kwargs)
if scheme == "sentinel":
return RedisSentinel(broker_url, *args, **kwargs)
raise NotImplementedError

async def queues(self, names):
raise NotImplementedError
Loading