Skip to content

Commit

Permalink
Changed up how defaults for redis passwords are evaluated, fixed an i…
Browse files Browse the repository at this point in the history
…ssue where the function that gets a connection to the redis runner was using the baseline redis credentials, added documentation and clarifications to the websocket listener, and made it so that `get_request_id` doesn't rely on external state.
  • Loading branch information
christophertubbs authored and aaraney committed Jun 18, 2024
1 parent 9f97e9e commit 729d30e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ async def enqueue_job(launch_parameters: typing.Dict[str, typing.Any]):
json.dumps(launch_parameters)
)

# `publish` returns an item that can be awaitable or just about anything else. In case the returned value is
# an awaitable, wait for it until all contained async operations are complete
while inspect.isawaitable(publication_response):
publication_response = await publication_response

Expand Down Expand Up @@ -504,27 +506,50 @@ def receive_subscribed_message(self, message: typing.Union[typing.Dict[str, typi
request_id: typing.Optional[str] = None

def is_message_wrapper(possible_wrapper) -> bool:
return isinstance(possible_wrapper, dict) \
"""
Determines if a value might contain event data
Args:
possible_wrapper: A value that might contain event data
Returns:
True if an event definition and its payload may be read from the given object
"""
return isinstance(possible_wrapper, typing.Mapping) \
and not possible_wrapper.get("event") \
and "data" in possible_wrapper

def get_request_id(container: typing.Union[typing.Dict[str, typing.Any], str, bytes]) -> typing.Optional[str]:
def get_request_id(
container: typing.Union[typing.Dict[str, typing.Any], str, bytes],
previous_value: str = None
) -> typing.Optional[str]:
"""
Inspect a given container to get a possible definition for a `request_id` - default to the previous value
if not found
Args:
container: The value that may contain a request_id
previous_value: The previous value of a request_id
Returns:
The most current appropriate value for a request_id
"""
if isinstance(container, typing.Mapping) and "request_id" in container:
new_request_id = container["request_id"] or request_id
new_request_id = container["request_id"] or previous_value

if isinstance(new_request_id, bytes):
new_request_id = new_request_id.decode()

return new_request_id
return request_id
return previous_value

request_id = get_request_id(message)

# The passed message may be a wrapper if it doesn't bear an event, but DOES have a 'data' member.
# If that's the case, use its data member instead
while is_message_wrapper(message):
message: typing.Dict[str, typing.Any] = message['data']
request_id: typing.Optional[str] = get_request_id(message)
request_id = get_request_id(message, request_id)

# If it looks like the passed message might be a string or bytes representation of a dict, attempt to
# convert it to a dict
Expand All @@ -537,13 +562,14 @@ def get_request_id(container: typing.Union[typing.Dict[str, typing.Any], str, by
else:
deserialized_message = message

request_id = get_request_id(deserialized_message)
request_id = get_request_id(deserialized_message, request_id)

while is_message_wrapper(deserialized_message):
# This is only considered a message wrapper if it is a dict; linters may think this could be a string,
# but it will always be a dict here
deserialized_message = deserialized_message['data']
request_id = get_request_id(deserialized_message)
request_id = get_request_id(deserialized_message, request_id)

# The caller requires this function to be synchronous, whereas `send_message` is async;
# we're stuck using async_to_sync here as a result
async_send = async_to_sync(self.send_message)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ def get_redis_password(password_path_variable: str = None, password_variable_nam
The optional environment variables that control this are `REDIS_PASSWORD_FILE` for the secret or `REDIS_PASS`
for the password on its own.
Args:
password_path_variable: The path to the secrets file for the password.
password_variable_name: The name of the environment variable for the password.
Returns:
The optional password to the core redis service
"""
if password_path_variable is None:
password_path_variable = "REDIS_PASSWORD_FILE"

password_filename = os.environ.get(password_path_variable, "/run/secrets/myredis_pass")
password_filename = os.environ.get(
password_path_variable or "REDIS_PASSWORD_FILE",
"/run/secrets/myredis_pass"
)

# If a password file has been identified, try to get a password from that
if os.path.exists(password_filename):
Expand All @@ -66,11 +70,8 @@ def get_redis_password(password_path_variable: str = None, password_variable_nam
# Data couldn't be read? Move on to attempting to read it from the environment variable
pass

if not password_variable_name:
password_variable_name = "REDIS_PASS"

# Fall back to env if no secrets file, further falling back to default if no env value
return os.environ.get(password_variable_name)
return os.environ.get(password_variable_name or "REDIS_PASS")


def get_full_localtimezone():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ def get_runner_connection(
A connection to a redis instance
"""
return redis.Redis(
host=host or application_values.REDIS_HOST,
port=port or application_values.REDIS_PORT,
username=username or application_values.REDIS_USERNAME,
password=password or application_values.REDIS_PASSWORD,
db=db or application_values.REDIS_DB,
host=host or application_values.RUNNER_HOST,
port=port or application_values.RUNNER_PORT,
username=username or application_values.RUNNER_USERNAME,
password=password or application_values.RUNNER_PASSWORD,
db=db or application_values.RUNNER_DB,
)


Expand Down

0 comments on commit 729d30e

Please sign in to comment.