1818from kbase .auth .exceptions import InvalidTokenError , InvalidUserError
1919
2020# TODO PUBLISH make a pypi kbase org and publish there
21+ # TODO RELIABILITY could add retries for these methods, tenacity looks useful
22+ # should be safe since they're all read only
23+ # TODO NOW CODE make a kbase/auth.py module, move other code into _auth, and import everything
24+ # TODO NOW CODE move Token and User into a common class
25+ # We might want to expand exceptions to include the request ID for debugging purposes
2126
2227
2328@dataclass
@@ -75,14 +80,14 @@ def _check_response(r: httpx.Response):
7580 if err == 30010 : # Illegal username
7681 # The auth server does some goofy stuff when propagating errors, should be cleaned up
7782 # at some point
78- raise InvalidUserError (resjson ["error" ]["message" ].split (":" , 3 )[- 1 ])
83+ raise InvalidUserError (resjson ["error" ]["message" ].split (":" , 3 )[- 1 ]. strip () )
7984 # don't really see any other error codes we need to worry about - maybe disabled?
8085 # worry about it later.
8186 raise IOError ("Error from KBase auth server: " + resjson ["error" ]["message" ])
8287 return resjson
8388
8489
85- class AsyncClient :
90+ class AsyncKBaseAuthClient :
8691 """
8792 A client for the KBase Authentication service.
8893 """
@@ -111,10 +116,6 @@ async def create(
111116 except :
112117 await cli .close ()
113118 raise
114- # TODO CLIENT look through the myriad of auth clients to see what functionality we need
115- # TODO CLIENT cache valid user names using cachefor value from token
116- # TODO RELIABILITY could add retries for these methods, tenacity looks useful
117- # should be safe since they're all reads only
118119 return cli
119120
120121 def __init__ (self , base_url : str , cache_max_size : int , timer : Callable [[[]], int | float ]):
@@ -123,12 +124,14 @@ def __init__(self, base_url: str, cache_max_size: int, timer: Callable[[[]], int
123124 self ._base_url = base_url
124125 self ._token_url = base_url + "api/V2/token"
125126 self ._me_url = base_url + "api/V2/me"
127+ self ._users_url = base_url + "api/V2/users/?list="
126128 if cache_max_size < 1 :
127129 raise ValueError ("cache_max_size must be > 0" )
128130 if not timer :
129131 raise ValueError ("timer is required" )
130132 self ._token_cache = LRUCache (maxsize = cache_max_size , timer = timer )
131133 self ._user_cache = LRUCache (maxsize = cache_max_size , timer = timer )
134+ self ._username_cache = LRUCache (maxsize = cache_max_size , timer = timer )
132135 self ._cli = httpx .AsyncClient ()
133136
134137 async def __aenter__ (self ):
@@ -168,8 +171,6 @@ async def get_token(self, token: str, on_cache_miss: Callable[[], None]=None) ->
168171 res = await self ._get (self ._token_url , headers = {"Authorization" : token })
169172 tk = Token (** {k : v for k , v in res .items () if k in _VALID_TOKEN_FIELDS })
170173 # TODO TEST later may want to add tests that change the cachefor value.
171- # Cleanest way to do this is update the auth2 service to allow setting it
172- # in test mode
173174 self ._token_cache .set (token , tk , ttl = tk .cachefor / 1000 )
174175 return tk
175176
@@ -194,7 +195,56 @@ async def get_user(self, token: str, on_cache_miss: Callable[[], None]=None) ->
194195 res = await self ._get (self ._me_url , headers = {"Authorization" : token })
195196 u = User (** {k : v for k , v in res .items () if k in _VALID_USER_FIELDS })
196197 # TODO TEST later may want to add tests that change the cachefor value.
197- # Cleanest way to do this is update the auth2 service to allow setting it
198- # in test mode
199198 self ._user_cache .set (token , u , ttl = tk .cachefor / 1000 )
200199 return u
200+
201+ async def validate_usernames (
202+ self ,
203+ token : str ,
204+ * usernames : str ,
205+ on_cache_miss : Callable [[str ], None ] = None
206+ ) -> dict [str , bool ]:
207+ """
208+ Validate that one or more usernames exist in the auth service.
209+
210+ If any of the names are illegal, an error is thrown.
211+
212+ token - a valid KBase token for any user.
213+ usernames - one or more usernames to query.
214+ on_cache_miss - a function to call if a cache miss occurs. The single argument is the
215+ username that was not in the cache
216+
217+ Returns a dict of username -> boolean which is True if the username exists.
218+ """
219+ _require_string (token , "token" )
220+ if not usernames :
221+ return {}
222+ # use a dict to preserve ordering for testing purposes
223+ uns = {u .strip (): 1 for u in usernames if u .strip ()}
224+ to_return = {}
225+ to_query = set ()
226+ for u in uns .keys ():
227+ if self ._username_cache .get (u , default = False ):
228+ to_return [u ] = True
229+ else :
230+ if on_cache_miss :
231+ on_cache_miss (u )
232+ to_query .add (u )
233+ if not to_query :
234+ return to_return
235+ res = await self ._get (
236+ self ._users_url + "," .join (to_query ),
237+ headers = {"Authorization" : token }
238+ )
239+ tk = None
240+ for u in to_query :
241+ to_return [u ] = u in res
242+ if to_return [u ]:
243+ if not tk : # minor optimization, don't get the token until it's needed
244+ tk = await self .get_token (token )
245+ # Usernames are permanent but can be disabled, so we expire based on time
246+ # Don't cache non-existent names, could be created at any time and would
247+ # be terrible UX for new users
248+ # TODO TEST later may want to add tests that change the cachefor value.
249+ self ._username_cache .set (u , True , ttl = tk .cachefor / 1000 )
250+ return to_return
0 commit comments