Skip to content

Commit 1e050fe

Browse files
committed
Make blocking set keyspace query to fail by timeout
1 parent 7e0b02d commit 1e050fe

File tree

5 files changed

+52
-10
lines changed

5 files changed

+52
-10
lines changed

cassandra/cluster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2388,7 +2388,7 @@ def _prepare_all_queries(self, host):
23882388
else:
23892389
for keyspace, ks_statements in groupby(statements, lambda s: s.keyspace):
23902390
if keyspace is not None:
2391-
connection.set_keyspace_blocking(keyspace)
2391+
connection.set_keyspace_blocking(keyspace, self.control_connection_timeout)
23922392

23932393
# prepare 10 statements at a time
23942394
ks_statements = list(ks_statements)

cassandra/connection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1498,14 +1498,14 @@ def _handle_auth_response(self, auth_response):
14981498
log.error(msg, self.endpoint, auth_response)
14991499
raise ProtocolError(msg % (self.endpoint, auth_response))
15001500

1501-
def set_keyspace_blocking(self, keyspace):
1501+
def set_keyspace_blocking(self, keyspace, timeout=None):
15021502
if not keyspace or keyspace == self.keyspace:
15031503
return
15041504

15051505
query = QueryMessage(query='USE "%s"' % (keyspace,),
15061506
consistency_level=ConsistencyLevel.ONE)
15071507
try:
1508-
result = self.wait_for_response(query)
1508+
result = self.wait_for_response(query, timeout=timeout)
15091509
except InvalidRequestException as ire:
15101510
# the keyspace probably doesn't exist
15111511
raise ire.to_exception()

cassandra/pool.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):
435435
self._keyspace = session.keyspace
436436

437437
if self._keyspace:
438-
first_connection.set_keyspace_blocking(self._keyspace)
438+
first_connection.set_keyspace_blocking(self._keyspace, session.cluster.control_connection_timeout)
439439
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
440440
self.host.sharding_info = first_connection.features.sharding_info
441441
self._open_connections_for_all_shards(first_connection.features.shard_id)
@@ -615,7 +615,7 @@ def _replace(self, connection):
615615
connection = self._session.cluster.connection_factory(self.host.endpoint,
616616
on_orphaned_stream_released=self.on_orphaned_stream_released)
617617
if self._keyspace:
618-
connection.set_keyspace_blocking(self._keyspace)
618+
connection.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
619619
self._connections[connection.features.shard_id] = connection
620620
except Exception:
621621
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
@@ -766,7 +766,7 @@ def _open_connection_to_missing_shard(self, shard_id):
766766
self.host
767767
)
768768
if self._keyspace:
769-
conn.set_keyspace_blocking(self._keyspace)
769+
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
770770

771771
self._connections[conn.features.shard_id] = conn
772772
if old_conn is not None:
@@ -953,7 +953,7 @@ def __init__(self, host, host_distance, session):
953953
self._keyspace = session.keyspace
954954
if self._keyspace:
955955
for conn in self._connections:
956-
conn.set_keyspace_blocking(self._keyspace)
956+
conn.set_keyspace_blocking(self._keyspace, self._session.cluster.control_connection_timeout)
957957

958958
self._trash = set()
959959
self._next_trash_allowed_at = time.time()
@@ -1053,7 +1053,7 @@ def _add_conn_if_under_max(self):
10531053
try:
10541054
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
10551055
if self._keyspace:
1056-
conn.set_keyspace_blocking(self._session.keyspace)
1056+
conn.set_keyspace_blocking(self._session.keyspace, self._session.cluster.control_connection_timeout)
10571057
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
10581058
with self._lock:
10591059
new_connections = self._connections[:] + [conn]

tests/integration/standard/test_cluster.py

+42
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
RetryPolicy, SimpleConvictionPolicy, HostDistance,
3333
AddressTranslator, TokenAwarePolicy, HostFilterPolicy)
3434
from cassandra import ConsistencyLevel
35+
from cassandra.protocol import ProtocolHandler, QueryMessage
3536

3637
from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
3738
from cassandra.auth import PlainTextAuthProvider, SaslAuthProvider
@@ -484,6 +485,47 @@ def test_refresh_schema_table(self):
484485
self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query())
485486
cluster.shutdown()
486487

488+
def test_use_keyspace_blocking(self):
489+
ks = "test_refresh_schema_type"
490+
491+
cluster = TestCluster()
492+
493+
class ConnectionWrapper(cluster.connection_class):
494+
def __init__(self, *args, **kwargs):
495+
super(ConnectionWrapper, self).__init__(*args, **kwargs)
496+
497+
def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
498+
decoder=ProtocolHandler.decode_message, result_metadata=None):
499+
if isinstance(msg, QueryMessage) and f'USE "{ks}"' in msg.query:
500+
orig_decoder = decoder
501+
502+
def decode_patched(protocol_version, protocol_features, user_type_map, stream_id, flags, opcode,
503+
body,
504+
decompressor, result_metadata):
505+
time.sleep(cluster.control_connection_timeout + 0.1)
506+
return orig_decoder(protocol_version, protocol_features, user_type_map, stream_id, flags,
507+
opcode, body, decompressor, result_metadata)
508+
509+
decoder = decode_patched
510+
511+
return super(ConnectionWrapper, self).send_msg(msg, request_id, cb, encoder, decoder, result_metadata)
512+
513+
cluster.connection_class = ConnectionWrapper
514+
515+
cluster.connect().execute("""
516+
CREATE KEYSPACE IF NOT EXISTS %s
517+
WITH replication = { 'class': 'SimpleStrategy', 'replication_factor': '2' }
518+
""" % ks)
519+
520+
try:
521+
cluster.connect(ks)
522+
except NoHostAvailable:
523+
pass
524+
except Exception as e:
525+
self.fail(f"got unexpected exception {e}")
526+
else:
527+
self.fail("connection should fail, but was not")
528+
487529
def test_refresh_schema_type(self):
488530
if get_server_versions()[0] < (2, 1, 0):
489531
raise unittest.SkipTest('UDTs were introduced in Cassandra 2.1')

tests/unit/test_host_connection_pool.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_borrow_and_return(self):
5656
c, request_id = pool.borrow_connection(timeout=0.01)
5757
self.assertIs(c, conn)
5858
self.assertEqual(1, conn.in_flight)
59-
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
59+
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout)
6060

6161
pool.return_connection(conn)
6262
self.assertEqual(0, conn.in_flight)
@@ -256,7 +256,7 @@ def get_conn():
256256
c, request_id = pool.borrow_connection(1.0)
257257
self.assertIs(conn, c)
258258
self.assertEqual(1, conn.in_flight)
259-
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace')
259+
conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace', session.cluster.control_connection_timeout)
260260
pool.return_connection(c)
261261

262262
t = Thread(target=get_conn)

0 commit comments

Comments
 (0)