Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring & testing utils/authentication.py #234

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 51 additions & 81 deletions src/sparsezoo/utils/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,77 +45,8 @@
)


class SparseZooCredentials:
"""
Class wrapping around the sparse zoo credentials file.
"""

def __init__(self):
if os.path.exists(CREDENTIALS_YAML):
_LOGGER.debug(f"Loading sparse zoo credentials from {CREDENTIALS_YAML}")
with open(CREDENTIALS_YAML) as credentials_file:
credentials_yaml = yaml.safe_load(credentials_file)
if credentials_yaml and CREDENTIALS_YAML_TOKEN_KEY in credentials_yaml:
self._token = credentials_yaml[CREDENTIALS_YAML_TOKEN_KEY]["token"]
self._created = credentials_yaml[CREDENTIALS_YAML_TOKEN_KEY][
"created"
]
else:
self._token = None
self._created = None
else:
_LOGGER.debug(
f"No sparse zoo credentials files found at {CREDENTIALS_YAML}"
)
self._token = None
self._created = None

def save_token(self, token: str, created: float):
"""
Save the jwt for accessing sparse zoo APIs. Will create the credentials file
if it does not exist already.

:param token: the jwt for accessing sparse zoo APIs
:param created: the approximate time the token was created
"""
_LOGGER.debug(f"Saving sparse zoo credentials at {CREDENTIALS_YAML}")
if not os.path.exists(CREDENTIALS_YAML):
create_parent_dirs(CREDENTIALS_YAML)
with open(CREDENTIALS_YAML, "w+") as credentials_file:
credentials_yaml = yaml.safe_load(credentials_file)
if credentials_yaml is None:
credentials_yaml = {}
credentials_yaml[CREDENTIALS_YAML_TOKEN_KEY] = {
"token": token,
"created": created,
}
self._token = token
self._created = created

yaml.safe_dump(credentials_yaml, credentials_file)

@property
def token(self):
"""
:return: obtain the token if under 1 day old, else return None
"""
_LOGGER.debug(f"Obtaining sparse zoo credentials from {CREDENTIALS_YAML}")
if self._token and self._created is not None:
creation_date = datetime.fromtimestamp(self._created, tz=timezone.utc)
creation_difference = datetime.now(tz=timezone.utc) - creation_date
if creation_difference.days < 30:
return self._token
else:
_LOGGER.debug(f"Expired sparse zoo credentials at {CREDENTIALS_YAML}")
return None
else:
_LOGGER.debug(f"No sparse zoo credentials found at {CREDENTIALS_YAML}")
return None


def get_auth_header(
authentication_type: str = PUBLIC_AUTH_TYPE,
force_token_refresh: bool = False,
force_token_refresh: bool = False, path: str = CREDENTIALS_YAML
) -> Dict:
"""
Obtain an authentication header token from either credentials file or from APIs
Expand All @@ -124,24 +55,63 @@ def get_auth_header(

Currently only 'public' authentication type is supported.

:param authentication_type: authentication type for generating token
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed authentication_type as a parameter since there was only 1 valid type and nowhere uses the parameter. Seem okay?

:param force_token_refresh: forces a new token to be generated
:return: An authentication header with key 'nm-token-header' containing the header
token
"""
credentials = SparseZooCredentials()
token = credentials.token
if token and not force_token_refresh:
return {NM_TOKEN_HEADER: token}
elif authentication_type.lower() == PUBLIC_AUTH_TYPE:
token = _maybe_load_token(path)
if token is None or force_token_refresh:
_LOGGER.info("Obtaining new sparse zoo credentials token")
created = time.time()
response = requests.post(
url=AUTH_API, data=json.dumps({"authentication_type": PUBLIC_AUTH_TYPE})
)
response.raise_for_status()
token = response.json()["token"]
credentials.save_token(token, created)
return {NM_TOKEN_HEADER: token}
else:
raise Exception(f"Authentication type {PUBLIC_AUTH_TYPE} not supported.")
created = time.time()
_save_token(token, created, path)
return {NM_TOKEN_HEADER: token}


def _maybe_load_token(path: str):
if not os.path.exists(path):
_LOGGER.debug(f"No sparse zoo credentials files found at {path}")
return None

_LOGGER.debug(f"Loading sparse zoo credentials from {path}")

with open(path) as fp:
creds = yaml.safe_load(fp)

if creds is None or CREDENTIALS_YAML_TOKEN_KEY not in creds:
_LOGGER.debug(f"No sparse zoo credentials found at {path}")
return None

info = creds[CREDENTIALS_YAML_TOKEN_KEY]
if "token" not in info or "created" not in info:
_LOGGER.debug(f"No sparse zoo credentials found at {path}")
return None

date_created = datetime.fromtimestamp(info["created"], tz=timezone.utc)
creation_difference = datetime.now(tz=timezone.utc) - date_created

if creation_difference.days > 30:
corey-nm marked this conversation as resolved.
Show resolved Hide resolved
_LOGGER.debug(f"Expired sparse zoo credentials at {path}")
return None

return info["token"]


def _save_token(token: str, created: float, path: str):
"""
Save the jwt for accessing sparse zoo APIs. Will create the credentials file
if it does not exist already.

:param token: the jwt for accessing sparse zoo APIs
:param created: the approximate time the token was created
"""
_LOGGER.debug(f"Saving sparse zoo credentials at {CREDENTIALS_YAML}")
if not os.path.exists(path):
create_parent_dirs(path)
with open(path, "w+") as fp:
auth = {CREDENTIALS_YAML_TOKEN_KEY: dict(token=token, created=created)}
yaml.safe_dump(auth, fp)
85 changes: 85 additions & 0 deletions tests/sparsezoo/utils/test_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sparsezoo.utils.authentication import (
get_auth_header,
_maybe_load_token,
_save_token,
CREDENTIALS_YAML_TOKEN_KEY,
NM_TOKEN_HEADER,
)
import pytest
from datetime import datetime, timedelta
import yaml
from unittest.mock import patch, MagicMock


def test_load_token_no_path(tmp_path):
path = str(tmp_path / "token.yaml")
assert _maybe_load_token(path) is None


def test_load_token_yaml_fail(tmp_path):
path = str(tmp_path / "token.yaml")
with open(path, "w") as fp:
fp.write("asdf")
assert _maybe_load_token(path) is None


_OLD_DATE = (datetime.now() - timedelta(days=40)).timestamp()


@pytest.mark.parametrize(
"content",
[
{},
{CREDENTIALS_YAML_TOKEN_KEY: {}},
{CREDENTIALS_YAML_TOKEN_KEY: {"token": "asdf"}},
{CREDENTIALS_YAML_TOKEN_KEY: {"created": "asdf"}},
{CREDENTIALS_YAML_TOKEN_KEY: {"created": _OLD_DATE}},
],
)
def test_load_token_failure_cases(tmp_path, content):
path = str(tmp_path / "token.yaml")
with open(path, "w") as fp:
yaml.dump(content, fp)
assert _maybe_load_token(path) is None


def test_load_token_valid(tmp_path):
auth = {
CREDENTIALS_YAML_TOKEN_KEY: {
"created": datetime.now().timestamp(),
"token": "asdf",
}
}
path = str(tmp_path / "token.yaml")
with open(path, "w") as fp:
yaml.dump(auth, fp)
assert _maybe_load_token(path) == "asdf"


def test_load_saved_token(tmp_path):
path = str(tmp_path / "some" / "dirs" / "token.yaml")
_save_token("asdf", datetime.now().timestamp(), path)
assert _maybe_load_token(path) == "asdf"


@patch("requests.post", return_value=MagicMock(json=lambda: {"token": "qwer"}))
def test_get_auth_token(post_mock, tmp_path):
path = tmp_path / "creds.yaml"
assert not path.exists()
assert get_auth_header(path=str(path)) == {NM_TOKEN_HEADER: "qwer"}
assert path.exists()
post_mock.assert_called()