11# copyright (c) 2020, Matthias Dellweg 
22# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) 
33
4+ import  asyncio 
45import  base64 
56import  datetime 
67import  json 
78import  os 
9+ import  ssl 
810import  typing  as  t 
911from  collections  import  defaultdict 
1012from  contextlib  import  suppress 
1113from  io  import  BufferedReader 
1214from  urllib .parse  import  urljoin 
1315
16+ import  aiohttp 
1417import  requests 
1518import  urllib3 
1619
@@ -174,6 +177,9 @@ def __init__(
174177        self ._safe_calls_only : bool  =  safe_calls_only 
175178        self ._headers  =  headers  or  {}
176179        self ._verify  =  verify 
180+         # Shall we make that a parameter? 
181+         self ._ssl_context : t .Optional [t .Union [ssl .SSLContext , bool ]] =  None 
182+ 
177183        self ._auth_provider  =  auth_provider 
178184        self ._cert  =  cert 
179185        self ._key  =  key 
@@ -225,6 +231,22 @@ def base_url(self) -> str:
225231    def  cid (self ) ->  t .Optional [str ]:
226232        return  self ._headers .get ("Correlation-Id" )
227233
234+     @property  
235+     def  ssl_context (self ) ->  t .Union [ssl .SSLContext , bool ]:
236+         if  self ._ssl_context  is  None :
237+             if  self ._verify  is  False :
238+                 self ._ssl_context  =  False 
239+             else :
240+                 if  isinstance (self ._verify , str ):
241+                     self ._ssl_context  =  ssl .create_default_context (cafile = self ._verify )
242+                 else :
243+                     self ._ssl_context  =  ssl .create_default_context ()
244+                 if  self ._cert  is  not   None :
245+                     self ._ssl_context .load_cert_chain (self ._cert , self ._key )
246+             # Type inference is failing here. 
247+             self ._ssl_context  =  t .cast (t .Union [ssl .SSLContext  |  bool ], self ._ssl_context )
248+         return  self ._ssl_context 
249+ 
228250    def  load_api (self , refresh_cache : bool  =  False ) ->  None :
229251        # TODO: Find a way to invalidate caches on upstream change 
230252        xdg_cache_home : str  =  os .environ .get ("XDG_CACHE_HOME" ) or  "~/.cache" 
@@ -242,7 +264,7 @@ def load_api(self, refresh_cache: bool = False) -> None:
242264            self ._parse_api (data )
243265        except  Exception :
244266            # Try again with a freshly downloaded version 
245-             data  =  self ._download_api ()
267+             data  =  asyncio . run ( self ._download_api () )
246268            self ._parse_api (data )
247269            # Write to cache as it seems to be valid 
248270            os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
@@ -262,15 +284,18 @@ def _parse_api(self, data: bytes) -> None:
262284            if  method  in  {"get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" }
263285        }
264286
265-     def  _download_api (self ) ->  bytes :
287+     async   def  _download_api (self ) ->  bytes :
266288        try :
267-             response : requests .Response  =  self ._session .get (urljoin (self ._base_url , self ._doc_path ))
268-         except  requests .RequestException  as  e :
289+             connector  =  aiohttp .TCPConnector (ssl = self .ssl_context )
290+             async  with  aiohttp .ClientSession (connector = connector , headers = self ._headers ) as  session :
291+                 async  with  session .get (urljoin (self ._base_url , self ._doc_path )) as  response :
292+                     response .raise_for_status ()
293+                     data  =  await  response .read ()
294+         except  aiohttp .ClientError  as  e :
269295            raise  OpenAPIError (str (e ))
270-         response .raise_for_status ()
271296        if  "Correlation-ID"  in  response .headers :
272297            self ._set_correlation_id (response .headers ["Correlation-ID" ])
273-         return  response . content 
298+         return  data 
274299
275300    def  _set_correlation_id (self , correlation_id : str ) ->  None :
276301        if  "Correlation-ID"  in  self ._headers :
0 commit comments