Skip to content

Commit

Permalink
add test: revoked refresh token triggers reauth
Browse files Browse the repository at this point in the history
  • Loading branch information
awoimbee committed Sep 13, 2024
1 parent 92df5aa commit d7fb65f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/ansys/simai/core/utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class _AuthTokens(BaseModel):

@model_validator(mode="before")
@classmethod
def check_card_number_omitted(cls, data: Any) -> dict:
def expires_in_to_datetime(cls, data: Any) -> dict:
assert isinstance(data, dict)
if "expiration" not in data:
# We want and store "expiration" but API responses contain "expires_in"
Expand Down
93 changes: 83 additions & 10 deletions tests/test_utils_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@


@responses.activate
def test_request_auth_tokens_direct_grant(mocker, tmpdir):
default_params = {"client_id": "sdk", "grant_type": "password", "scope": "openid"}
def test_request_auth_tokens_direct_grant_bad_credentials_raises(mocker, tmpdir):
mocker.patch("ansys.simai.core.utils.auth.get_cache_dir", return_value=tmpdir)
responses.add(
responses.POST,
"http://myauthserver.com/protocol/openid-connect/token",
Expand All @@ -58,34 +58,107 @@ def test_request_auth_tokens_direct_grant(mocker, tmpdir):
status=401,
match=[
urlencoded_params_matcher(
dict({"username": "macron", "password": "explosion"}, **default_params)
{
"username": "macron",
"password": "explosion",
"client_id": "sdk",
"grant_type": "password",
"scope": "openid",
}
)
],
)
tokens_retriever = _AuthTokensRetriever(
credentials=None,
session=requests.Session(),
realm_url="http://myauthserver.com",
auth_cache_hash="rando",
)
with pytest.raises(SimAIError, match="Invalid user credentials"):
tokens_retriever.credentials = Credentials(username="macron", password="explosion")
tokens_retriever.get_tokens()


@responses.activate
def test_request_auth_tokens_direct_grant(mocker, tmpdir):
mocker.patch("ansys.simai.core.utils.auth.get_cache_dir", return_value=tmpdir)
responses.add(
responses.POST,
"http://myauthserver.com/protocol/openid-connect/token",
json=DEFAULT_TOKENS,
status=200,
match=[urlencoded_params_matcher(dict({"username": "timmy"}, **default_params))],
match=[
urlencoded_params_matcher(
{
"username": "timmy",
"client_id": "sdk",
"grant_type": "password",
"scope": "openid",
}
)
],
)
mocker.patch("ansys.simai.core.utils.auth.get_cache_dir", return_value=tmpdir)
tokens_retriever = _AuthTokensRetriever(
credentials=None,
session=requests.Session(),
realm_url="http://myauthserver.com",
auth_cache_hash="rando",
)
tokens_retriever.credentials = Credentials(username="timmy", password="")
tokens = tokens_retriever.get_tokens()
assert tokens.refresh_token == DEFAULT_TOKENS["refresh_token"]
assert tokens.access_token == DEFAULT_TOKENS["access_token"]

with pytest.raises(SimAIError):
tokens_retriever.credentials = Credentials(username="macron", password="explosion")
tokens_retriever.get_tokens()

tokens_retriever.credentials = Credentials(username="timmy", password="")
tokens_retriever.cache_file_path += "tests-2"
@responses.activate
def test_token_refresh_failure_triggers_reauth(mocker, tmpdir):
mocker.patch("ansys.simai.core.utils.auth.get_cache_dir", return_value=tmpdir)
resps_refresh = responses.add(
responses.POST,
"http://myauthserver.com/protocol/openid-connect/token",
status=418,
match=[
urlencoded_params_matcher(
{"client_id": "sdk", "grant_type": "refresh_token", "refresh_token": "revoked"}
)
],
)
resps_direct_grant = responses.add(
responses.POST,
"http://myauthserver.com/protocol/openid-connect/token",
json=DEFAULT_TOKENS,
status=200,
match=[
urlencoded_params_matcher(
{
"username": "timmy",
"client_id": "sdk",
"grant_type": "password",
"scope": "openid",
}
)
],
)
with open(tmpdir / "tokens-rando.json", "w") as f:
f.write(
_AuthTokens(
access_token="",
expires_in=0,
refresh_expires_in=999,
refresh_token="revoked",
).model_dump_json()
)
tokens_retriever = _AuthTokensRetriever(
credentials=Credentials(username="timmy", password=""),
session=requests.Session(),
realm_url="http://myauthserver.com",
auth_cache_hash="rando",
)
tokens = tokens_retriever.get_tokens()
assert tokens.refresh_token == DEFAULT_TOKENS["refresh_token"]
assert tokens.access_token == DEFAULT_TOKENS["access_token"]
assert resps_refresh.call_count == 1
assert resps_direct_grant.call_count == 1


@responses.activate
Expand Down

0 comments on commit d7fb65f

Please sign in to comment.