Skip to content

Commit 4423f65

Browse files
authored
[QUO-1390] allow for multiple api_keys (#92)
* allow for multiple api_keys * update tests to reflect new behaviour
1 parent b79e06b commit 4423f65

File tree

4 files changed

+72
-27
lines changed

4 files changed

+72
-27
lines changed

quotientai/async_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def __init__(self, api_key: str):
3131
self.api_key = api_key
3232
self.token = None
3333
self.token_expiry = 0
34-
self._token_path = token_dir / ".quotient" / "auth_token.json"
34+
self.token_api_key = None
35+
self._token_path = token_dir / ".quotient" / f"{api_key[-6:]+'_' if api_key else ''}auth_token.json"
3536

3637
# Try to load existing token
3738
self._load_token()
@@ -59,7 +60,7 @@ def _save_token(self, token: str, expiry: int):
5960
return None
6061
# Save to disk
6162
with open(self._token_path, "w") as f:
62-
json.dump({"token": token, "expires_at": expiry}, f)
63+
json.dump({"token": token, "expires_at": expiry, "api_key": self.api_key}, f)
6364

6465
def _load_token(self):
6566
"""Load token from disk if available"""
@@ -71,14 +72,20 @@ def _load_token(self):
7172
data = json.load(f)
7273
self.token = data.get("token")
7374
self.token_expiry = data.get("expires_at", 0)
75+
self.token_api_key = data.get("api_key")
7476
except Exception:
7577
# If loading fails, token remains None
7678
pass
7779

7880
def _is_token_valid(self):
7981
"""Check if token exists and is not expired"""
82+
self._load_token()
83+
8084
if not self.token:
8185
return False
86+
87+
if self.token_api_key != self.api_key:
88+
return False
8289

8390
# With 5-minute buffer
8491
return time.time() < (self.token_expiry - 300)

quotientai/client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __init__(self, api_key: str):
3333
self.api_key = api_key
3434
self.token = None
3535
self.token_expiry = 0
36-
self._token_path = token_dir / ".quotient" / "auth_token.json"
36+
self.token_api_key = None
37+
self._token_path = token_dir / ".quotient" / f"{api_key[-6:]+'_' if api_key else ''}auth_token.json"
3738

3839
# Try to load existing token
3940
self._load_token()
@@ -62,7 +63,7 @@ def _save_token(self, token: str, expiry: int):
6263

6364
# Save to disk
6465
with open(self._token_path, "w") as f:
65-
json.dump({"token": token, "expires_at": expiry}, f)
66+
json.dump({"token": token, "expires_at": expiry, "api_key": self.api_key}, f)
6667

6768
def _load_token(self):
6869
"""Load token from disk if available"""
@@ -74,14 +75,20 @@ def _load_token(self):
7475
data = json.load(f)
7576
self.token = data.get("token")
7677
self.token_expiry = data.get("expires_at", 0)
78+
self.token_api_key = data.get("api_key")
7779
except Exception:
7880
# If loading fails, token remains None
7981
pass
8082

8183
def _is_token_valid(self):
8284
"""Check if token exists and is not expired"""
85+
self._load_token()
86+
8387
if not self.token:
8488
return False
89+
90+
if self.token_api_key != self.api_key:
91+
return False
8592

8693
# With 5-minute buffer
8794
return time.time() < (self.token_expiry - 300)

tests/test_async_client.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,26 +62,15 @@ def test_initialization(self, tmp_path):
6262
# Use a clean temporary directory for token storage
6363
token_dir = tmp_path / ".quotient"
6464

65-
# Test successful home directory case
6665
with patch('pathlib.Path.home', return_value=tmp_path):
6766
client = _AsyncQuotientClient(api_key)
67+
6868
assert client.api_key == api_key
6969
assert client.token is None
7070
assert client.token_expiry == 0
71+
assert client.token_api_key is None
7172
assert client.headers["Authorization"] == f"Bearer {api_key}"
72-
assert client._token_path == tmp_path / ".quotient" / "auth_token.json"
73-
74-
# Test fallback to /root when home fails
75-
with patch('pathlib.Path.home', side_effect=Exception("Test error")), \
76-
patch('pathlib.Path.exists', return_value=True):
77-
client = _AsyncQuotientClient(api_key)
78-
assert client._token_path == Path("/root/.quotient/auth_token.json")
79-
80-
# Test fallback to cwd when home fails and /root doesn't exist
81-
with patch('pathlib.Path.home', side_effect=Exception("Test error")), \
82-
patch('pathlib.Path.exists', return_value=False):
83-
client = _AsyncQuotientClient(api_key)
84-
assert client._token_path == Path.cwd() / ".quotient" / "auth_token.json"
73+
assert client._token_path == tmp_path / ".quotient" / f"{api_key[-6:]}_auth_token.json"
8574

8675
def test_handle_jwt_response(self):
8776
"""Test that _handle_response properly processes JWT tokens"""
@@ -119,11 +108,12 @@ def test_save_token(self, tmp_path):
119108
assert client.token_expiry == test_expiry
120109

121110
# Verify token was saved to disk
122-
token_file = tmp_path / ".quotient" / "auth_token.json"
111+
token_file = tmp_path / ".quotient" / f"{client.api_key[-6:]}_auth_token.json"
123112
assert token_file.exists()
124113
stored_data = json.loads(token_file.read_text())
125114
assert stored_data["token"] == test_token
126115
assert stored_data["expires_at"] == test_expiry
116+
assert stored_data["api_key"] == client.api_key
127117

128118
def test_load_token(self, tmp_path):
129119
"""Test that _load_token reads token data correctly"""
@@ -135,17 +125,19 @@ def test_load_token(self, tmp_path):
135125
# Write a token file
136126
token_dir = tmp_path / ".quotient"
137127
token_dir.mkdir(parents=True)
138-
token_file = token_dir / "auth_token.json"
128+
token_file = token_dir / f"{client.api_key[-6:]}_auth_token.json"
139129
token_file.write_text(json.dumps({
140130
"token": test_token,
141-
"expires_at": test_expiry
131+
"expires_at": test_expiry,
132+
"api_key": client.api_key
142133
}))
143134

144135
# Load the token
145136
client._load_token()
146137

147138
assert client.token == test_token
148139
assert client.token_expiry == test_expiry
140+
assert client.token_api_key == client.api_key
149141

150142
def test_is_token_valid(self, tmp_path):
151143
"""Test token validity checking"""
@@ -159,16 +151,25 @@ def test_is_token_valid(self, tmp_path):
159151
# Test with expired token
160152
client.token = "expired.token"
161153
client.token_expiry = int(time.time()) - 3600 # 1 hour ago
154+
client.token_api_key = client.api_key
162155
assert not client._is_token_valid()
163156

164157
# Test with valid token
165158
client.token = "valid.token"
166159
client.token_expiry = int(time.time()) + 3600 # 1 hour from now
160+
client.token_api_key = client.api_key
167161
assert client._is_token_valid()
168162

169163
# Test with token about to expire (within 5 minute buffer)
170164
client.token = "about.to.expire"
171165
client.token_expiry = int(time.time()) + 200 # Less than 5 minutes
166+
client.token_api_key = client.api_key
167+
assert not client._is_token_valid()
168+
169+
# Test with mismatched API key
170+
client.token = "valid.token"
171+
client.token_expiry = int(time.time()) + 3600
172+
client.token_api_key = "different-api-key"
172173
assert not client._is_token_valid()
173174

174175
def test_update_auth_header(self, tmp_path):
@@ -185,13 +186,21 @@ def test_update_auth_header(self, tmp_path):
185186
test_token = "test.jwt.token"
186187
client.token = test_token
187188
client.token_expiry = int(time.time()) + 3600
189+
client.token_api_key = client.api_key
188190
client._update_auth_header()
189191
assert client.headers["Authorization"] == f"Bearer {test_token}"
190192

191193
# Should revert to API key when token expires
192194
client.token_expiry = int(time.time()) - 3600
193195
client._update_auth_header()
194196
assert client.headers["Authorization"] == f"Bearer {client.api_key}"
197+
198+
# Should revert to API key when API key doesn't match
199+
client.token = test_token
200+
client.token_expiry = int(time.time()) + 3600
201+
client.token_api_key = "different-api-key"
202+
client._update_auth_header()
203+
assert client.headers["Authorization"] == f"Bearer {client.api_key}"
195204

196205
def test_token_directory_creation_failure(self, tmp_path, caplog):
197206
"""Test that appropriate error is raised when token directory creation fails"""

tests/test_client.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def test_initialization(self, tmp_path):
7373
assert client.api_key == api_key
7474
assert client.token is None
7575
assert client.token_expiry == 0
76+
assert client.token_api_key is None
7677
assert client.headers["Authorization"] == f"Bearer {api_key}"
78+
assert client._token_path == tmp_path / ".quotient" / f"{api_key[-6:]}_auth_token.json"
7779

7880
def test_handle_jwt_response(self):
7981
"""Test that _handle_response properly processes JWT tokens"""
@@ -111,11 +113,12 @@ def test_save_token(self, tmp_path):
111113
assert client.token_expiry == test_expiry
112114

113115
# Verify token was saved to disk
114-
token_file = tmp_path / ".quotient" / "auth_token.json"
116+
token_file = tmp_path / ".quotient" / f"{client.api_key[-6:]}_auth_token.json"
115117
assert token_file.exists()
116118
stored_data = json.loads(token_file.read_text())
117119
assert stored_data["token"] == test_token
118120
assert stored_data["expires_at"] == test_expiry
121+
assert stored_data["api_key"] == client.api_key
119122

120123
def test_load_token(self, tmp_path):
121124
"""Test that _load_token reads token data correctly"""
@@ -127,17 +130,19 @@ def test_load_token(self, tmp_path):
127130
# Write a token file
128131
token_dir = tmp_path / ".quotient"
129132
token_dir.mkdir(parents=True)
130-
token_file = token_dir / "auth_token.json"
133+
token_file = token_dir / f"{client.api_key[-6:]}_auth_token.json"
131134
token_file.write_text(json.dumps({
132135
"token": test_token,
133-
"expires_at": test_expiry
136+
"expires_at": test_expiry,
137+
"api_key": client.api_key
134138
}))
135139

136140
# Load the token
137141
client._load_token()
138142

139143
assert client.token == test_token
140144
assert client.token_expiry == test_expiry
145+
assert client.token_api_key == client.api_key
141146

142147
def test_is_token_valid(self, tmp_path):
143148
"""Test token validity checking"""
@@ -151,16 +156,25 @@ def test_is_token_valid(self, tmp_path):
151156
# Test with expired token
152157
client.token = "expired.token"
153158
client.token_expiry = int(time.time()) - 3600 # 1 hour ago
159+
client.token_api_key = client.api_key
154160
assert not client._is_token_valid()
155161

156162
# Test with valid token
157163
client.token = "valid.token"
158164
client.token_expiry = int(time.time()) + 3600 # 1 hour from now
165+
client.token_api_key = client.api_key
159166
assert client._is_token_valid()
160167

161168
# Test with token about to expire (within 5 minute buffer)
162169
client.token = "about.to.expire"
163170
client.token_expiry = int(time.time()) + 200 # Less than 5 minutes
171+
client.token_api_key = client.api_key
172+
assert not client._is_token_valid()
173+
174+
# Test with mismatched API key
175+
client.token = "valid.token"
176+
client.token_expiry = int(time.time()) + 3600
177+
client.token_api_key = "different-api-key"
164178
assert not client._is_token_valid()
165179

166180
def test_update_auth_header(self, tmp_path):
@@ -177,19 +191,27 @@ def test_update_auth_header(self, tmp_path):
177191
test_token = "test.jwt.token"
178192
client.token = test_token
179193
client.token_expiry = int(time.time()) + 3600
194+
client.token_api_key = client.api_key
180195
client._update_auth_header()
181196
assert client.headers["Authorization"] == f"Bearer {test_token}"
182197

183198
# Should revert to API key when token expires
184199
client.token_expiry = int(time.time()) - 3600
185200
client._update_auth_header()
186201
assert client.headers["Authorization"] == f"Bearer {client.api_key}"
202+
203+
# Should revert to API key when API key doesn't match
204+
client.token = test_token
205+
client.token_expiry = int(time.time()) + 3600
206+
client.token_api_key = "different-api-key"
207+
client._update_auth_header()
208+
assert client.headers["Authorization"] == f"Bearer {client.api_key}"
187209

188210
def test_token_path_uses_home(self):
189211
with patch('pathlib.Path.home') as mock_home:
190212
mock_home.return_value = Path('/home/user')
191213
client = _BaseQuotientClient('test-key')
192-
assert client._token_path == Path('/home/user/.quotient/auth_token.json')
214+
assert client._token_path == Path('/home/user/.quotient/st-key_auth_token.json')
193215

194216
def test_token_path_fallback_to_root(self):
195217
with patch('pathlib.Path.home') as mock_home, \
@@ -200,7 +222,7 @@ def test_token_path_fallback_to_root(self):
200222
mock_exists.return_value = True
201223

202224
client = _BaseQuotientClient('test-key')
203-
assert client._token_path == Path('/root/.quotient/auth_token.json')
225+
assert client._token_path == Path('/root/.quotient/st-key_auth_token.json')
204226

205227
def test_token_path_fallback_to_cwd(self):
206228
with patch('pathlib.Path.home') as mock_home, \
@@ -214,7 +236,7 @@ def test_token_path_fallback_to_cwd(self):
214236
mock_cwd.return_value = Path('/current/dir')
215237

216238
client = _BaseQuotientClient('test-key')
217-
assert client._token_path == Path('/current/dir/.quotient/auth_token.json')
239+
assert client._token_path == Path('/current/dir/.quotient/st-key_auth_token.json')
218240

219241
def test_handle_jwt_token_success(self):
220242
client = _BaseQuotientClient('test-key')

0 commit comments

Comments
 (0)