77from typing import Any , Callable , Union
88
99import cbor2
10+ import pandas
1011import pyarrow
1112import websockets .exceptions
1213import websockets .protocol
@@ -74,19 +75,19 @@ def __enter__(self):
7475 def __exit__ (self , exc_type , exc_val , exc_tb ):
7576 self .close ()
7677
77- def close (self ):
78+ def close (self ) -> None :
7879 self .__ws .close ()
7980
80- def commit (self ):
81+ def commit (self ) -> None :
8182 raise NotSupportedError
8283
83- def rollback (self ):
84+ def rollback (self ) -> None :
8485 raise NotSupportedError
8586
8687 def cursor (self ) -> Cursor :
8788 return Cursor (self .__execute_sql , self .__cancel_query )
8889
89- def __main_loop (self ):
90+ def __main_loop (self ) -> None :
9091 """Main background loop listening for messages from the SQL session."""
9192 logging .info ("Starting background connection handling loop..." )
9293 while self .__ws .protocol .state < websockets .protocol .State .CLOSING :
@@ -101,7 +102,7 @@ def __main_loop(self):
101102 except Exception as e :
102103 logging .exception ("Error handling message from SQL session" , exc_info = e )
103104
104- def __listen (self ):
105+ def __listen (self ) -> None :
105106 """Waits for the next message from the SQL session and processes it.
106107
107108 The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
@@ -120,61 +121,70 @@ def __listen(self):
120121 )
121122 return
122123
123- if kind == EventKind .STATE_UPDATED :
124+ # Incoming state transitions are handled here.
125+ if kind == EventKind .STATE_UPDATED or kind == EventKind .EXECUTION_RESULT :
124126 try :
125127 query .state = ExecutionState [message ["state" ].upper ()]
126128 logging .info ("Query %s is now %s." , execution_id , query .state )
127129 except KeyError :
128130 logging .warning ("Invalid state update message for %s" , execution_id )
129131 return
130132
131- # Incoming state transitions are handled here.
132133 if query .state == ExecutionState .SUCCEEDED :
133- self .__request_results (execution_id )
134+ # On a state_updated event telling us the query succeeded,
135+ # ask for results.
136+ if kind == EventKind .STATE_UPDATED :
137+ self .__request_results (execution_id )
138+ return
139+
140+ # Otherwise, process the results from the execution_result event.
141+ results = message .get ("results" )
142+ if not results or not isinstance (results , dict ):
143+ logging .warning ("Got no results back from %s." , execution_id )
144+ return
145+
146+ query .state = ExecutionState .COMPLETED
147+ query .handler (self ._handle_results (execution_id , results ))
134148 elif query .state == ExecutionState .CANCELLED :
135- logging .info ("Query %s has been cancelled." , execution_id )
149+ logging .info (
150+ "Query %s has been cancelled; returning empty results." ,
151+ execution_id ,
152+ )
153+ query .handler (pandas .DataFrame ())
136154 self .__queries .pop (execution_id )
137155 elif query .state == ExecutionState .FAILED :
138156 # Don't do anything here; the ERROR event is coming with more
139157 # details.
140158 pass
141-
142- elif kind == EventKind .EXECUTION_RESULT :
143- results = message .get ("results" )
144- if not results or not isinstance (results , dict ):
145- logging .warning ("Got no results back from %s." , execution_id )
146- return
147-
148- result_bytes = results .get ("result_bytes" )
149- result_format = results .get ("format" )
150- result_compression = results .get ("compression" )
151- logging .info (
152- "Received %d bytes of %s-compressed %s results from %s." ,
153- len (result_bytes ),
154- result_compression ,
155- result_format ,
156- execution_id ,
157- )
158-
159- query .state = ExecutionState .COMPLETED
160- if result_format == ResultsFormat .JSON :
161- query .handler (json .loads (result_bytes .decode ("utf-8" )))
162- elif result_format == ResultsFormat .ARROW :
163- buffer = pyarrow .py_buffer (result_bytes )
164- stream = pyarrow .input_stream (buffer , result_compression )
165- with pyarrow .ipc .open_stream (stream ) as reader :
166- query .handler (reader .read_pandas ())
167- else :
168- query .handler (
169- OperationalError (f"Unsupported results format { result_format } " )
170- )
171159 elif kind == EventKind .ERROR :
172160 query .state = ExecutionState .FAILED
173161 error = message .get ("message" )
174162 query .handler (OperationalError (error ))
175163 else :
176164 logging .warning ("Received unknown %s event!" , kind )
177165
166+ def _handle_results (self , execution_id : str , results : dict [str , Any ]) -> Any :
167+ result_bytes = results .get ("result_bytes" )
168+ result_format = results .get ("format" )
169+ result_compression = results .get ("compression" )
170+ logging .info (
171+ "Received %d bytes of %s-compressed %s results from %s." ,
172+ len (result_bytes ),
173+ result_compression ,
174+ result_format ,
175+ execution_id ,
176+ )
177+
178+ if result_format == ResultsFormat .JSON :
179+ return json .loads (result_bytes .decode ("utf-8" ))
180+ elif result_format == ResultsFormat .ARROW :
181+ buffer = pyarrow .py_buffer (result_bytes )
182+ stream = pyarrow .input_stream (buffer , result_compression )
183+ with pyarrow .ipc .open_stream (stream ) as reader :
184+ return reader .read_pandas ()
185+ else :
186+ return OperationalError (f"Unsupported results format { result_format } " )
187+
178188 def __send (self , message : dict [str , Any ]) -> None :
179189 request = json .dumps (message )
180190 logging .debug ("Request: %s" , request )
0 commit comments