@@ -660,7 +660,7 @@ def stream(self, method, url, params, headers):
660660
661661 def test_too_many_bytes (self ):
662662 @contextmanager
663- def server ():
663+ def server (port ):
664664 def _run (server_sock ):
665665 while True :
666666 try :
@@ -671,52 +671,54 @@ def _run(server_sock):
671671 connection_t .start ()
672672
673673 with shutdown (get_new_socket ()) as server_sock :
674- server_sock .bind (('127.0.0.1' , 9001 ))
674+ server_sock .bind (('127.0.0.1' , port ))
675675 server_sock .listen (socket .IPPROTO_TCP )
676676 threading .Thread (target = _run , args = (server_sock ,)).start ()
677677 yield server_sock
678678
679- def get_http_client ():
679+ def get_http_client (port ):
680680 @contextmanager
681681 def client ():
682682 with httpx .Client () as original_client :
683683 class Client ():
684684 @contextmanager
685685 def stream (self , method , url , params , headers ):
686686 parsed_url = urllib .parse .urlparse (url )
687- url = urllib .parse .urlunparse (parsed_url ._replace (netloc = 'localhost:9001 ' ))
687+ url = urllib .parse .urlunparse (parsed_url ._replace (netloc = f 'localhost:{ port } ' ))
688688 range_query = dict (headers ).get ('range' )
689- is_query = range_query and range_query != 'bytes=0-99'
689+ yield_extra = not only_after_header or ( range_query and range_query != 'bytes=0-99' )
690690 headers_proxy_host = tuple ((key , value ) for key , value in headers if key != 'host' ) + (('host' , 'localhost:9000' ),)
691691 with original_client .stream (method , url ,
692692 params = params , headers = headers_proxy_host
693693 ) as response :
694694 chunks = response .iter_bytes ()
695695 def iter_bytes (chunk_size = None ):
696696 yield from chunks
697- if is_query :
697+ if yield_extra :
698698 yield b'e'
699699 response .iter_bytes = iter_bytes
700700 yield response
701701 yield Client ()
702702 return client ()
703703
704- with server () as server_sock :
705- with get_db ([
706- ("CREATE TABLE my_table (my_col_a text, my_col_b text);" ,()),
707- ] + [
708- ("INSERT INTO my_table VALUES " + ',' .join (["('some-text-a', 'some-text-b')" ] * 500 ),()),
709- ]) as db :
710- put_object_with_versioning ('my-bucket' , 'my.db' , db )
711-
712- with sqlite_s3_query ('http://localhost:9000/my-bucket/my.db' , get_credentials = lambda now : (
713- 'us-east-1' ,
714- 'AKIAIOSFODNN7EXAMPLE' ,
715- 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY' ,
716- None ,
717- ), get_http_client = get_http_client , get_libsqlite3 = get_libsqlite3 ) as query :
718- with self .assertRaisesRegex (SQLiteError , 'disk I/O error' ):
719- query ('SELECT my_col_a FROM my_table' ).__enter__ ()
704+ for only_after_header , port in [(False , 9001 ), (True , 9002 )]:
705+ with self .subTest ((only_after_header , port )):
706+ with server (port ) as server_sock :
707+ with get_db ([
708+ ("CREATE TABLE my_table (my_col_a text, my_col_b text);" ,()),
709+ ] + [
710+ ("INSERT INTO my_table VALUES " + ',' .join (["('some-text-a', 'some-text-b')" ] * 500 ),()),
711+ ]) as db :
712+ put_object_with_versioning ('my-bucket' , 'my.db' , db )
713+
714+ with sqlite_s3_query ('http://localhost:9000/my-bucket/my.db' , get_credentials = lambda now : (
715+ 'us-east-1' ,
716+ 'AKIAIOSFODNN7EXAMPLE' ,
717+ 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY' ,
718+ None ,
719+ ), get_http_client = functools .partial (get_http_client , port ), get_libsqlite3 = get_libsqlite3 ) as query :
720+ with self .assertRaisesRegex (SQLiteError , 'disk I/O error' ):
721+ query ('SELECT my_col_a FROM my_table' ).__enter__ ()
720722
721723 def test_disconnection (self ):
722724 @contextmanager
@@ -732,7 +734,7 @@ def _run(server_sock):
732734 connection_t .start ()
733735
734736 with shutdown (get_new_socket ()) as server_sock :
735- server_sock .bind (('127.0.0.1' , 9001 ))
737+ server_sock .bind (('127.0.0.1' , 9003 ))
736738 server_sock .listen (socket .IPPROTO_TCP )
737739 threading .Thread (target = _run , args = (server_sock ,)).start ()
738740 yield server_sock
@@ -744,7 +746,7 @@ def client():
744746 class Client ():
745747 def stream (self , method , url , headers , params ):
746748 parsed_url = urllib .parse .urlparse (url )
747- url = urllib .parse .urlunparse (parsed_url ._replace (netloc = 'localhost:9001 ' ))
749+ url = urllib .parse .urlunparse (parsed_url ._replace (netloc = 'localhost:9003 ' ))
748750 headers_proxy_host = tuple ((key , value ) for key , value in headers if key != 'host' ) + (('host' , 'localhost:9000' ),)
749751 return original_client .stream (method , url , headers = headers_proxy_host )
750752 yield Client ()
0 commit comments