77from  typing  import  Any , Callable , Union 
88
99import  cbor2 
10+ import  pandas 
1011import  pyarrow 
1112import  websockets .exceptions 
1213import  websockets .protocol 
@@ -74,19 +75,19 @@ def __enter__(self):
7475    def  __exit__ (self , exc_type , exc_val , exc_tb ):
7576        self .close ()
7677
77-     def  close (self ):
78+     def  close (self )  ->   None :
7879        self .__ws .close ()
7980
80-     def  commit (self ):
81+     def  commit (self )  ->   None :
8182        raise  NotSupportedError 
8283
83-     def  rollback (self ):
84+     def  rollback (self )  ->   None :
8485        raise  NotSupportedError 
8586
8687    def  cursor (self ) ->  Cursor :
8788        return  Cursor (self .__execute_sql , self .__cancel_query )
8889
89-     def  __main_loop (self ):
90+     def  __main_loop (self )  ->   None :
9091        """Main background loop listening for messages from the SQL session.""" 
9192        logging .info ("Starting background connection handling loop..." )
9293        while  self .__ws .protocol .state  <  websockets .protocol .State .CLOSING :
@@ -101,7 +102,7 @@ def __main_loop(self):
101102            except  Exception  as  e :
102103                logging .exception ("Error handling message from SQL session" , exc_info = e )
103104
104-     def  __listen (self ):
105+     def  __listen (self )  ->   None :
105106        """Waits for the next message from the SQL session and processes it. 
106107
107108        The code in this method is purposefully defensive to avoid unexpected situations killing the thread. 
@@ -120,61 +121,70 @@ def __listen(self):
120121            )
121122            return 
122123
123-         if  kind  ==  EventKind .STATE_UPDATED :
124+         # Incoming state transitions are handled here. 
125+         if  kind  ==  EventKind .STATE_UPDATED  or  kind  ==  EventKind .EXECUTION_RESULT :
124126            try :
125127                query .state  =  ExecutionState [message ["state" ].upper ()]
126128                logging .info ("Query %s is now %s." , execution_id , query .state )
127129            except  KeyError :
128130                logging .warning ("Invalid state update message for %s" , execution_id )
129131                return 
130132
131-             # Incoming state transitions are handled here. 
132133            if  query .state  ==  ExecutionState .SUCCEEDED :
133-                 self .__request_results (execution_id )
134+                 # On a state_updated event telling us the query succeeded, 
135+                 # ask for results. 
136+                 if  kind  ==  EventKind .STATE_UPDATED :
137+                     self .__request_results (execution_id )
138+                     return 
139+ 
140+                 # Otherwise, process the results from the execution_result event. 
141+                 results  =  message .get ("results" )
142+                 if  not  results  or  not  isinstance (results , dict ):
143+                     logging .warning ("Got no results back from %s." , execution_id )
144+                     return 
145+ 
146+                 query .state  =  ExecutionState .COMPLETED 
147+                 query .handler (self ._handle_results (execution_id , results ))
134148            elif  query .state  ==  ExecutionState .CANCELLED :
135-                 logging .info ("Query %s has been cancelled." , execution_id )
149+                 logging .info (
150+                     "Query %s has been cancelled; returning empty results." ,
151+                     execution_id ,
152+                 )
153+                 query .handler (pandas .DataFrame ())
136154                self .__queries .pop (execution_id )
137155            elif  query .state  ==  ExecutionState .FAILED :
138156                # Don't do anything here; the ERROR event is coming with more 
139157                # details. 
140158                pass 
141- 
142-         elif  kind  ==  EventKind .EXECUTION_RESULT :
143-             results  =  message .get ("results" )
144-             if  not  results  or  not  isinstance (results , dict ):
145-                 logging .warning ("Got no results back from %s." , execution_id )
146-                 return 
147- 
148-             result_bytes  =  results .get ("result_bytes" )
149-             result_format  =  results .get ("format" )
150-             result_compression  =  results .get ("compression" )
151-             logging .info (
152-                 "Received %d bytes of %s-compressed %s results from %s." ,
153-                 len (result_bytes ),
154-                 result_compression ,
155-                 result_format ,
156-                 execution_id ,
157-             )
158- 
159-             query .state  =  ExecutionState .COMPLETED 
160-             if  result_format  ==  ResultsFormat .JSON :
161-                 query .handler (json .loads (result_bytes .decode ("utf-8" )))
162-             elif  result_format  ==  ResultsFormat .ARROW :
163-                 buffer  =  pyarrow .py_buffer (result_bytes )
164-                 stream  =  pyarrow .input_stream (buffer , result_compression )
165-                 with  pyarrow .ipc .open_stream (stream ) as  reader :
166-                     query .handler (reader .read_pandas ())
167-             else :
168-                 query .handler (
169-                     OperationalError (f"Unsupported results format { result_format }  )
170-                 )
171159        elif  kind  ==  EventKind .ERROR :
172160            query .state  =  ExecutionState .FAILED 
173161            error  =  message .get ("message" )
174162            query .handler (OperationalError (error ))
175163        else :
176164            logging .warning ("Received unknown %s event!" , kind )
177165
166+     def  _handle_results (self , execution_id : str , results : dict [str , Any ]) ->  Any :
167+         result_bytes  =  results .get ("result_bytes" )
168+         result_format  =  results .get ("format" )
169+         result_compression  =  results .get ("compression" )
170+         logging .info (
171+             "Received %d bytes of %s-compressed %s results from %s." ,
172+             len (result_bytes ),
173+             result_compression ,
174+             result_format ,
175+             execution_id ,
176+         )
177+ 
178+         if  result_format  ==  ResultsFormat .JSON :
179+             return  json .loads (result_bytes .decode ("utf-8" ))
180+         elif  result_format  ==  ResultsFormat .ARROW :
181+             buffer  =  pyarrow .py_buffer (result_bytes )
182+             stream  =  pyarrow .input_stream (buffer , result_compression )
183+             with  pyarrow .ipc .open_stream (stream ) as  reader :
184+                 return  reader .read_pandas ()
185+         else :
186+             return  OperationalError (f"Unsupported results format { result_format }  )
187+ 
178188    def  __send (self , message : dict [str , Any ]) ->  None :
179189        request  =  json .dumps (message )
180190        logging .debug ("Request: %s" , request )
0 commit comments