Skip to content

Commit

Permalink
Updated defaults for redis connections to use connections tailored sp…
Browse files Browse the repository at this point in the history
…ecifically for their purpose in plainer language.
  • Loading branch information
christophertubbs authored and aaraney committed Jun 18, 2024
1 parent 8465473 commit 9f97e9e
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 83 deletions.
12 changes: 6 additions & 6 deletions python/gui/utilities/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def get_redis_connection(
A connection to a redis instance
"""
construction_arguments = {
key: value
for key, value in kwargs.items()
"host": host or application_values.REDIS_HOST,
"port": port or application_values.REDIS_PORT,
**kwargs
}

construction_arguments['host'] = host or application_values.REDIS_HOST
construction_arguments['port'] = port or application_values.REDIS_PORT
construction_arguments['password'] = password or application_values.REDIS_PASSWORD
if password or application_values.REDIS_PASSWORD:
construction_arguments["password"] = password or application_values.REDIS_PASSWORD

if db or application_values.REDIS_DB:
construction_arguments['db'] = db or application_values.REDIS_DB
Expand All @@ -60,4 +60,4 @@ def get_redis_connection(
raise Exception(
f"Could not connect to redis instance at {failing_address}. "
f"Make sure it is running and ready to receive connections."
) from e
) from e
Original file line number Diff line number Diff line change
Expand Up @@ -449,14 +449,32 @@ def items(self) -> typing.ItemsView:
"""
return self.__scope.items()

@property
def caller(self) -> str:
"""
Formatted data about whoever established the connection and scope
"""
return f"{self.user or 'Anonymous'}{'@' + self.client if self.client else ''} => {self.path}"


async def enqueue_job(launch_parameters: typing.Dict[str, typing.Any]):
with utilities.get_runner_connection() as runner_connection:
publication_response = runner_connection.publish(
EVALUATION_QUEUE_NAME,
json.dumps(launch_parameters)
)

while inspect.isawaitable(publication_response):
publication_response = await publication_response


class LaunchConsumer(AsyncWebsocketConsumer, ActionDescriber):
"""
Web Socket consumer that forwards messages to and from redis PubSub
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_connection: redis.Redis = utilities.get_redis_connection()
self.redis_connection: redis.Redis = utilities.get_runner_connection()
self.publisher_and_subscriber: typing.Optional[redis_client.PubSub] = self.redis_connection.pubsub()
self.subscribed_to_channel = False
self.listener: typing.Optional[redis.client.PubSubWorkerThread] = None
Expand All @@ -476,43 +494,66 @@ def scope_data(self) -> ConcreteScope:

return self.__scope

def receive_subscribed_message(self, message):
def receive_subscribed_message(self, message: typing.Union[typing.Dict[str, typing.Any], str, bytes]):
"""
Interprets and transforms messages sent along the redis channel.
Args:
message: A message that was published from redis
"""
request_id: typing.Optional[str] = None

def is_message_wrapper(possible_wrapper) -> bool:
return isinstance(possible_wrapper, dict) \
and not possible_wrapper.get("event") \
and "data" in possible_wrapper

def get_request_id(container: typing.Union[typing.Dict[str, typing.Any], str, bytes]) -> typing.Optional[str]:
if isinstance(container, typing.Mapping) and "request_id" in container:
new_request_id = container["request_id"] or request_id

if isinstance(new_request_id, bytes):
new_request_id = new_request_id.decode()

return new_request_id
return request_id

request_id = get_request_id(message)

# The passed message may be a wrapper if it doesn't bear an event, but DOES have a 'data' member.
# If that's the case, use its data member instead
while is_message_wrapper(message):
message = message['data']
message: typing.Dict[str, typing.Any] = message['data']
request_id: typing.Optional[str] = get_request_id(message)

# If it looks like the passed message might be a string or bytes representation of a dict, attempt to
# convert it to a dict
if isinstance(message, (str, bytes)) and utilities.string_might_be_json(message):
try:
deserialized_message = json.loads(message)
deserialized_message: typing.Union[typing.Dict[str, typing.Any], str, bytes] = json.loads(message)
except:
# It couldn't be converted, so go ahead use the passed in value
deserialized_message = message
deserialized_message: typing.Union[typing.Dict[str, typing.Any], str, bytes] = message
else:
deserialized_message = message

request_id = get_request_id(deserialized_message)

while is_message_wrapper(deserialized_message):
# This is only considered a message wrapper if it is a dict; linters may think this could be a string,
# but it will always be a dict here
deserialized_message = deserialized_message['data']

request_id = get_request_id(deserialized_message)
# The caller requires this function to be synchronous, whereas `send_message` is async;
# we're stuck using async_to_sync here as a result
async_send = async_to_sync(self.send_message)
async_send(deserialized_message, event="subscribed_message_received")
async_send(
deserialized_message,
event="subscribed_message_received",
request_id=request_id,
logger=None,
response_type=None
)

async def connect(self):
"""
Expand Down Expand Up @@ -565,19 +606,34 @@ async def subscribe_to_channel(self, payload: typing.Dict[str, typing.Any] = Non
request_id=payload.get(REQUEST_ID_KEY)
)

async def receive(self, text_data=None, **kwargs):
async def receive(self, text_data: str = None, bytes_data: bytes = None, **kwargs):
"""
Processes messages received via the socket.
Called when the other end of the socket sends a message
Args:
text_data: The data sent over the socket
**kwargs:
text_data: Data sent over the socket in the form of text data
bytes_data: Data sent over the socket in the form of binary data
"""
request_id: typing.Optional[str] = kwargs.get(REQUEST_ID_KEY)

if not text_data and bytes_data:
try:
text_data = bytes_data.decode("utf-8")
except BaseException as e:
message = f"{str(self)}: Only text data may be received. Error: {e}"
await self.send_error(
event="receive",
message=message,
request_id=request_id
)
SOCKET_LOGGER.error(f"{str(self)}: {message}")
return

if not text_data:
message = f"{str(self)}: No data was received"
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
SOCKET_LOGGER.debug(f"{str(self)}: {message}")
return

Expand All @@ -586,26 +642,26 @@ async def receive(self, text_data=None, **kwargs):
except Exception as error:
message = f"Only JSON strings may be received and processed. Received data was {type(text_data)}"
SOCKET_LOGGER.error(message, error)
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
return

if payload is None:
message = f"No payload could be read from: '{text_data}'"
SOCKET_LOGGER.error(message)
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
return

try:
if not payload.get('action'):
message = f"{str(self)}: No action was received; expected action cannot be performed"
SOCKET_LOGGER.error(message)
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
return

if payload['action'] not in self.get_action_handlers():
message = f"{str(self)}: '{payload['action']}' is an invalid function"
SOCKET_LOGGER.debug(message)
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
return

action = payload['action']
Expand All @@ -618,10 +674,10 @@ async def receive(self, text_data=None, **kwargs):
if parameters and not action_parameters:
message = f"{str(self)}: '{action}' cannot be performed; no 'action_parameters' object was received"
SOCKET_LOGGER.error(message)
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
return

missing_parameters = list()
missing_parameters = []
for parameter_name, parameter_type in getattr(handler, "required_parameters").items():
if parameter_name not in action_parameters:
missing_parameters.append(f"{parameter_name}: {parameter_type}")
Expand All @@ -630,10 +686,10 @@ async def receive(self, text_data=None, **kwargs):
message = f"{str(self)}: '{action}' cannot be performed; " \
f"the following required parameters are missing: {', '.join(missing_parameters)}"
SOCKET_LOGGER.error(message=message)
await self.send_error(event='receive', message=message, request_id=kwargs.get(REQUEST_ID_KEY))
await self.send_error(event='receive', message=message, request_id=request_id)
return
except Exception as exception:
await self.send_error(message=exception, event="receive")
await self.send_error(message=exception, event="receive", request_id=request_id)
SOCKET_LOGGER.error(message=exception)
return

Expand Down Expand Up @@ -687,7 +743,7 @@ async def launch(self, payload: typing.Dict[str, typing.Any] = None):
}

# Send the job parameters through the channel that actively listens for jobs
self.redis_connection.publish(EVALUATION_QUEUE_NAME, json.dumps(launch_parameters))
await enqueue_job(launch_parameters=launch_parameters)

await self.send_message(
result=data,
Expand All @@ -707,14 +763,14 @@ async def launch(self, payload: typing.Dict[str, typing.Any] = None):
async def search(self, payload: typing.Dict[str, typing.Any] = None):
try:
if payload is None:
payload = dict()
payload = {}

payload = {
key.lower(): value
for key, value in payload.items()
}

filter_arguments = dict()
filter_arguments = {}

if "author" in payload:
filter_arguments['author__icontains'] = payload['author']
Expand All @@ -726,7 +782,7 @@ async def search(self, payload: typing.Dict[str, typing.Any] = None):
**filter_arguments
)

definitions_to_return = list()
definitions_to_return = []

for saved_definition in saved_definitions:
definitions_to_return.append({
Expand Down Expand Up @@ -797,7 +853,7 @@ async def get_all_templates(self, payload: typing.Dict[str, typing.Any] = None):
}

for specification_type, specification_type_name in self.template_manager.get_specification_types():
templates: typing.List[typing.Dict[str, typing.Union[str, int]]] = list()
templates: typing.List[typing.Dict[str, typing.Union[str, int]]] = []

matching_templates: typing.Sequence[models.SpecificationTemplate] = models.SpecificationTemplateCommunicator.filter(
template_specification_type=specification_type
Expand All @@ -821,7 +877,7 @@ async def get_all_templates(self, payload: typing.Dict[str, typing.Any] = None):
@required_parameters(configuration=REQUIRED_PARAMETER_TYPES.text)
async def validate_configuration(self, payload: typing.Dict[str, typing.Any] = None):

messages: typing.List[str] = list()
messages: typing.List[str] = []

try:
EvaluationSpecification.create(
Expand Down Expand Up @@ -862,7 +918,7 @@ async def get_template_by_id(self, payload: typing.Dict[str, typing.Any] = None)
template_id = payload.get("template_id")
possible_template = models.SpecificationTemplateCommunicator.filter(id=template_id)
response_type = None
response_data = dict()
response_data = {}

if possible_template:
template_entry: models.SpecificationTemplate = possible_template[0]
Expand Down Expand Up @@ -927,7 +983,7 @@ async def save(self, payload: typing.Dict[str, typing.Any] = None):
author = payload['author']
instructions = payload['instructions']

definition, was_created = models.EvaluationDefinitionCommunicator.update_or_create(
_, was_created = models.EvaluationDefinitionCommunicator.update_or_create(
name=name,
description=description,
author=author,
Expand Down Expand Up @@ -1082,7 +1138,7 @@ async def disconnect(self, close_code):

def __str__(self):
return f"[{self.__class__.__name__}] {self.channel_name} <=> " \
f"{':'.join([str(entry) for entry in self.scope['client']])}"
f"{self.scope_data.caller}"

def __repr__(self):
"""
Expand Down Expand Up @@ -1124,7 +1180,7 @@ class ChannelConsumer(AsyncWebsocketConsumer):
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_connection: redis.Redis = utilities.get_redis_connection()
self.redis_connection: redis.Redis = utilities.get_channel_connection()
self.publisher_and_subscriber: typing.Optional[redis_client.PubSub] = None
self.listener: typing.Optional[redis.client.PubSubWorkerThread] = None
self.connection_group_id: typing.Optional[str] = None
Expand Down
Loading

0 comments on commit 9f97e9e

Please sign in to comment.