12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import requests
16
- from urllib3 .util .retry import Retry
17
- from requests .adapters import HTTPAdapter
18
-
19
15
import base64
20
16
import datetime
21
17
import errno
22
18
import json
23
19
import os
24
20
import random
25
- import six
26
21
import stat
27
- from hashlib import sha1
28
22
import warnings
23
+ from hashlib import sha1
24
+
25
+ import requests
26
+ import six
27
+ from requests .adapters import HTTPAdapter
28
+ from urllib3 .util .retry import Retry
29
29
30
30
from descarteslabs .client .exceptions import AuthError , OauthError
31
31
@@ -58,9 +58,9 @@ def makedirs_if_not_exists(path):
58
58
59
59
60
60
class Auth :
61
- def __init__ (self , domain = "https://iam .descarteslabs.com" ,
61
+ def __init__ (self , domain = "https://accounts .descarteslabs.com" ,
62
62
scope = None , leeway = 500 , token_info_path = DEFAULT_TOKEN_INFO_PATH ,
63
- client_id = None , client_secret = None , jwt_token = None ):
63
+ client_id = None , client_secret = None , jwt_token = None , refresh_token = None ):
64
64
"""
65
65
Helps retrieve JWT from a client id and refresh token for cli usage.
66
66
:param domain: endpoint for auth0
@@ -70,6 +70,7 @@ def __init__(self, domain="https://iam.descarteslabs.com",
70
70
:param client_id: JWT client id
71
71
:param client_secret: JWT client secret
72
72
:param jwt_token: the JWT token, if we already have one
73
+ :param refresh_token: the refresh token
73
74
"""
74
75
self .token_info_path = token_info_path
75
76
@@ -84,26 +85,26 @@ def __init__(self, domain="https://iam.descarteslabs.com",
84
85
self .client_id = client_id if client_id else os .environ .get ('CLIENT_ID' , token_info .get ('client_id' , None ))
85
86
self .client_secret = client_secret if client_secret else os .environ .get ('CLIENT_SECRET' , token_info .get (
86
87
'client_secret' , None ))
88
+ self .refresh_token = refresh_token if refresh_token \
89
+ else os .environ .get ('DESCARTESLABS_REFRESH_TOKEN' , token_info .get ('refresh_token' , None ))
90
+ self .scope = scope if scope else token_info .get ('scope' )
87
91
self ._token = jwt_token if jwt_token else os .environ .get ('JWT_TOKEN' , token_info .get ('jwt_token' , None ))
88
92
89
93
if token_info :
90
94
# If the token was read from a path but environment variables were set, we may need
91
95
# to reset the token.
92
96
client_id_changed = token_info .get ('client_id' , None ) != self .client_id
93
97
client_secret_changed = token_info .get ('client_secret' , None ) != self .client_secret
98
+ refresh_token_changed = token_info .get ('refresh_token' , None ) != self .refresh_token
94
99
95
- if client_id_changed or client_secret_changed :
100
+ if client_id_changed or client_secret_changed or refresh_token_changed :
96
101
self ._token = None
97
102
98
103
self ._namespace = None
99
-
104
+ self . _session = None
100
105
self .domain = domain
101
- self .scope = scope
102
106
self .leeway = leeway
103
107
104
- if self .scope is None :
105
- self .scope = ['openid' , 'name' , 'groups' ]
106
-
107
108
@classmethod
108
109
def from_environment_or_token_json (cls , ** kwargs ):
109
110
"""
@@ -120,18 +121,18 @@ def from_environment_or_token_json(cls, **kwargs):
120
121
def token (self ):
121
122
if self ._token is None :
122
123
self ._get_token ()
123
-
124
- exp = self .payload .get ('exp' )
125
-
126
- if exp is not None :
127
- now = (datetime .datetime .utcnow () - datetime .datetime (1970 , 1 , 1 )).total_seconds ()
128
- if now + self .leeway > exp :
129
- try :
130
- self ._get_token ()
131
- except AuthError as e :
132
- # Unable to refresh, raise if now > exp
133
- if now > exp :
134
- raise e
124
+ else : # might have token but could be close to expiration
125
+ exp = self .payload .get ('exp' )
126
+
127
+ if exp is not None :
128
+ now = (datetime .datetime .utcnow () - datetime .datetime (1970 , 1 , 1 )).total_seconds ()
129
+ if now + self .leeway > exp :
130
+ try :
131
+ self ._get_token ()
132
+ except AuthError as e :
133
+ # Unable to refresh, raise if now > exp
134
+ if now > exp :
135
+ raise e
135
136
136
137
return self ._token
137
138
@@ -148,38 +149,65 @@ def payload(self):
148
149
claims = token .split (b'.' )[1 ]
149
150
return json .loads (base64url_decode (claims ).decode ('utf-8' ))
150
151
152
+ @property
153
+ def session (self ):
154
+ if self ._session is None :
155
+ self ._session = requests .Session ()
156
+ retries = Retry (total = 5 ,
157
+ backoff_factor = random .uniform (1 , 10 ),
158
+ method_whitelist = frozenset (['GET' , 'POST' ]),
159
+ status_forcelist = [429 , 500 , 502 , 503 , 504 ])
160
+
161
+ self ._session .mount ('https://' , HTTPAdapter (max_retries = retries ))
162
+
163
+ return self ._session
164
+
151
165
def _get_token (self , timeout = 100 ):
152
166
if self .client_id is None :
153
- raise AuthError ("Could not find CLIENT_ID" )
154
-
155
- if self .client_secret is None :
156
- raise AuthError ("Could not find CLIENT_SECRET" )
157
-
158
- s = requests .Session ()
159
- retries = Retry (total = 5 ,
160
- backoff_factor = random .uniform (1 , 10 ),
161
- method_whitelist = frozenset (['GET' , 'POST' ]),
162
- status_forcelist = [429 , 500 , 502 , 503 , 504 ])
163
-
164
- s .mount ('https://' , HTTPAdapter (max_retries = retries ))
165
-
166
- headers = {"content-type" : "application/json" }
167
- params = {
168
- "scope" : " " .join (self .scope ),
169
- "client_id" : self .client_id ,
170
- "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
171
- "target" : self .client_id ,
172
- "api_type" : "app" ,
173
- "refresh_token" : self .client_secret
174
- }
175
- r = s .post (self .domain + "/auth/delegation" , headers = headers , data = json .dumps (params ), timeout = timeout )
167
+ raise AuthError ("Could not find client_id" )
168
+
169
+ if self .client_secret is None and self .refresh_token is None :
170
+ raise AuthError ("Could not find client_secret or refresh token" )
171
+
172
+ if self .client_id in ["ZOBAi4UROl5gKZIpxxlwOEfx8KpqXf2c" ]: # TODO(justin) remove legacy handling
173
+ # TODO (justin) insert deprecation warning
174
+ if self .scope is None :
175
+ scope = ['openid' , 'name' , 'groups' ]
176
+ else :
177
+ scope = self .scope
178
+ params = {
179
+ "scope" : " " .join (scope ),
180
+ "client_id" : self .client_id ,
181
+ "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
182
+ "target" : self .client_id ,
183
+ "api_type" : "app" ,
184
+ "refresh_token" : self .refresh_token if self .refresh_token is not None else self .client_secret
185
+ }
186
+ else :
187
+ params = {
188
+ "client_id" : self .client_id ,
189
+ "grant_type" : "refresh_token" ,
190
+ "refresh_token" : self .refresh_token if self .refresh_token is not None else self .client_secret
191
+ }
192
+
193
+ if self .scope is not None :
194
+ params ["scope" ] = " " .join (self .scope )
195
+
196
+ r = self .session .post (self .domain + "/token" , json = params , timeout = timeout )
176
197
177
198
if r .status_code != 200 :
178
199
raise OauthError ("%s: %s" % (r .status_code , r .text ))
179
200
180
201
data = r .json ()
181
- self ._token = data ['id_token' ]
202
+ access_token = data .get ('access_token' )
203
+ id_token = data .get ('id_token' ) # TODO(justin) remove legacy id_token usage
182
204
205
+ if access_token is not None :
206
+ self ._token = access_token
207
+ elif id_token is not None :
208
+ self ._token = id_token
209
+ else :
210
+ raise OauthError ("could not retrieve token" )
183
211
token_info = {}
184
212
185
213
if self .token_info_path :
0 commit comments