From e8d7151d615eeaaabd76ed178f373bbdd0489aaf Mon Sep 17 00:00:00 2001
From: Sylwia Szunejko
Date: Wed, 3 Jan 2024 09:01:36 +0100
Subject: [PATCH 1/4] Add parsing TABLETS_ROUTING_V1 extension to
 ProtocolFeatures

In order for Scylla to send tablet info, the driver must tell the database
during the connection handshake that it is able to interpret it. This
negotiation is added as part of the ProtocolFeatures class.
---
 cassandra/protocol_features.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/cassandra/protocol_features.py b/cassandra/protocol_features.py
index fc7c5b060e..4eb7019f84 100644
--- a/cassandra/protocol_features.py
+++ b/cassandra/protocol_features.py
@@ -6,22 +6,26 @@
 RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
+TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"
 
 
 class ProtocolFeatures(object):
     rate_limit_error = None
     shard_id = 0
     sharding_info = None
+    tablets_routing_v1 = False
 
-    def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None):
+    def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False):
         self.rate_limit_error = rate_limit_error
         self.shard_id = shard_id
         self.sharding_info = sharding_info
+        self.tablets_routing_v1 = tablets_routing_v1
 
     @staticmethod
     def parse_from_supported(supported):
         rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
         shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
-        return ProtocolFeatures(rate_limit_error, shard_id, sharding_info)
+        tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
+        return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1)
 
     @staticmethod
     def maybe_parse_rate_limit_error(supported):
@@ -43,6 +47,8 @@ def get_cql_extension_field(vals, key):
     def add_startup_options(self, options):
         if self.rate_limit_error is not None:
             options[RATE_LIMIT_ERROR_EXTENSION] = ""
+        if self.tablets_routing_v1:
+            options[TABLETS_ROUTING_V1] = ""
 
     @staticmethod
     def parse_sharding_info(options):
@@ -63,3 +69,6 @@ def parse_sharding_info(options):
                                  shard_aware_port, shard_aware_port_ssl)
 
+    @staticmethod
+    def parse_tablets_info(options):
+        return TABLETS_ROUTING_V1 in options

From 02e7ce969c859c305f09e13e852aba0f2f6c47e4 Mon Sep 17 00:00:00 2001
From: sylwiaszunejko
Date: Fri, 28 Jul 2023 09:26:54 +0200
Subject: [PATCH 2/4] Use tablets in token and shard awareness

Add a mechanism to parse and store tablet information sent by ScyllaDB.
In TokenAwarePolicy, check whether the keyspace uses tablets and, if so,
try to use them to find replicas. Make shard awareness work when using
tablets.

Everything is wrapped in an experimental setting, because tablets are still
experimental in ScyllaDB and changes to the tablets format are possible.
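For illustration only (not part of the diff below), a minimal sketch of how the
new tablet metadata is consulted when routing a request. The keyspace/table
names, host id and token values are made up; the real driver uses host UUIDs
and Murmur3 tokens derived from the routing key:

    from cassandra.tablets import Tablet, Tablets

    class FakeToken(object):
        # Stand-in for the driver's Token type, which exposes a .value attribute.
        def __init__(self, value):
            self.value = value

    tablets = Tablets({})
    # One tablet owning the token interval (-100, 100], replicated on host "h1", shard 0.
    tablets.add_tablet("ks", "tbl", Tablet(-100, 100, [("h1", 0)]))

    tablet = tablets.get_tablet_for_key("ks", "tbl", FakeToken(42))
    if tablet is not None:
        host_id, shard_id = tablet.replicas[0]
        # TokenAwarePolicy keeps only hosts whose host_id appears in tablet.replicas;
        # HostConnection then picks the connection for the matching shard.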
--- cassandra/cluster.py | 34 ++++++++- cassandra/metadata.py | 2 + cassandra/policies.py | 16 ++++- cassandra/pool.py | 28 ++++++-- cassandra/query.py | 12 +++- cassandra/tablets.py | 107 +++++++++++++++++++++++++++++ tests/unit/test_policies.py | 5 ++ tests/unit/test_response_future.py | 10 +-- 8 files changed, 199 insertions(+), 15 deletions(-) create mode 100644 cassandra/tablets.py diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 6ec04521c7..e3ddc74709 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -41,7 +41,7 @@ import weakref from weakref import WeakValueDictionary -from cassandra import (ConsistencyLevel, AuthenticationFailed, +from cassandra import (ConsistencyLevel, AuthenticationFailed, InvalidRequest, OperationTimedOut, UnsupportedOperation, SchemaTargetType, DriverException, ProtocolVersion, UnresolvableContactPoints) @@ -51,6 +51,7 @@ EndPoint, DefaultEndPoint, DefaultEndPointFactory, ContinuousPagingState, SniEndPointFactory, ConnectionBusy) from cassandra.cqltypes import UserType +import cassandra.cqltypes as types from cassandra.encoder import Encoder from cassandra.protocol import (QueryMessage, ResultMessage, ErrorMessage, ReadTimeoutErrorMessage, @@ -79,6 +80,7 @@ named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET, HostTargetingStatement) from cassandra.marshal import int64_pack +from cassandra.tablets import Tablet, Tablets from cassandra.timestamps import MonotonicTimestampGenerator from cassandra.compat import Mapping from cassandra.util import _resolve_contact_points_to_string_map, Version @@ -1775,6 +1777,14 @@ def connect(self, keyspace=None, wait_for_all_pools=False): self.shutdown() raise + # Update the information about tablet support after connection handshake. + self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1 + child_policy = self.load_balancing_policy.child_policy if hasattr(self.load_balancing_policy, 'child_policy') else None + while child_policy is not None: + if hasattr(child_policy, '_tablet_routing_v1'): + child_policy._tablet_routing_v1 = self.control_connection._tablets_routing_v1 + child_policy = child_policy.child_policy if hasattr(child_policy, 'child_policy') else None + self.profile_manager.check_supported() # todo: rename this method if self.idle_heartbeat_interval: @@ -2389,7 +2399,6 @@ def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement - class Session(object): """ A collection of connection pools for each host in the cluster. @@ -3541,6 +3550,7 @@ class PeersQueryType(object): _schema_meta_page_size = 1000 _uses_peers_v2 = True + _tablets_routing_v1 = False # for testing purposes _time = time @@ -3674,6 +3684,8 @@ def _try_connect(self, host): # If sharding information is available, it's a ScyllaDB cluster, so do not use peers_v2 table. 
if connection.features.sharding_info is not None: self._uses_peers_v2 = False + + self._tablets_routing_v1 = connection.features.tablets_routing_v1 # use weak references in both directions # _clear_watcher will be called when this ControlConnection is about to be finalized @@ -4600,7 +4612,10 @@ def _query(self, host, message=None, cb=None): connection = None try: # TODO get connectTimeout from cluster settings - connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key if self.query else None) + if self.query: + connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table) + else: + connection, request_id = pool.borrow_connection(timeout=2.0) self._connection = connection result_meta = self.prepared_statement.result_metadata if self.prepared_statement else [] @@ -4719,6 +4734,19 @@ def _set_result(self, host, connection, pool, response): self._warnings = getattr(response, 'warnings', None) self._custom_payload = getattr(response, 'custom_payload', None) + if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload: + protocol = self.session.cluster.protocol_version + info = self._custom_payload.get('tablets-routing-v1') + ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))') + tablet_routing_info = ctype.from_binary(info, protocol) + first_token = tablet_routing_info[0] + last_token = tablet_routing_info[1] + tablet_replicas = tablet_routing_info[2] + tablet = Tablet.from_row(first_token, last_token, tablet_replicas) + keyspace = self.query.keyspace + table = self.query.table + self.session.cluster.metadata._tablets.add_tablet(keyspace, table, tablet) + if isinstance(response, ResultMessage): if response.kind == RESULT_KIND_SET_KEYSPACE: session = getattr(self, 'session', None) diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 5f1cfa5beb..c2993eaa3f 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -44,6 +44,7 @@ from cassandra.pool import HostDistance from cassandra.connection import EndPoint from cassandra.compat import Mapping +from cassandra.tablets import Tablets log = logging.getLogger(__name__) @@ -126,6 +127,7 @@ def __init__(self): self._hosts = {} self._host_id_by_endpoint = {} self._hosts_lock = RLock() + self._tablets = Tablets({}) def export_schema_as_string(self): """ diff --git a/cassandra/policies.py b/cassandra/policies.py index fa1e8cf385..cfacb16d81 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -335,6 +335,7 @@ class TokenAwarePolicy(LoadBalancingPolicy): _child_policy = None _cluster_metadata = None + _tablets_routing_v1 = False shuffle_replicas = False """ Yield local replicas in a random order. 
@@ -346,6 +347,7 @@ def __init__(self, child_policy, shuffle_replicas=False): def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata + self._tablets_routing_v1 = cluster.control_connection._tablets_routing_v1 self._child_policy.populate(cluster, hosts) def check_supported(self): @@ -376,7 +378,19 @@ def make_query_plan(self, working_keyspace=None, query=None): for host in child.make_query_plan(keyspace, query): yield host else: - replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) + replicas = [] + if self._tablets_routing_v1: + tablet = self._cluster_metadata._tablets.get_tablet_for_key(keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(routing_key)) + + if tablet is not None: + replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) + child_plan = child.make_query_plan(keyspace, query) + + replicas = [host for host in child_plan if host.host_id in replicas_mapped] + + if replicas == []: + replicas = self._cluster_metadata.get_replicas(keyspace, routing_key) + if self.shuffle_replicas: shuffle(replicas) for replica in replicas: diff --git a/cassandra/pool.py b/cassandra/pool.py index 110b682c72..bb176b2ee7 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -392,6 +392,8 @@ class HostConnection(object): # the number below, all excess connections will be closed. max_excess_connections_per_shard_multiplier = 3 + tablets_routing_v1 = False + def __init__(self, host, host_distance, session): self.host = host self.host_distance = host_distance @@ -436,10 +438,11 @@ def __init__(self, host, host_distance, session): if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) + self.tablets_routing_v1 = first_connection.features.tablets_routing_v1 log.debug("Finished initializing connection for host %s", self.host) - def _get_connection_for_routing_key(self, routing_key=None): + def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None): if self.is_shutdown: raise ConnectionException( "Pool for %s is shutdown" % (self.host,), self.host) @@ -450,7 +453,22 @@ def _get_connection_for_routing_key(self, routing_key=None): shard_id = None if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key: t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key) - shard_id = self.host.sharding_info.shard_id_from_token(t.value) + + shard_id = None + if self.tablets_routing_v1 and table is not None: + if keyspace is None: + keyspace = self._keyspace + + tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t) + + if tablet is not None: + for replica in tablet.replicas: + if replica[0] == self.host.host_id: + shard_id = replica[1] + break + + if shard_id is None: + shard_id = self.host.sharding_info.shard_id_from_token(t.value) conn = self._connections.get(shard_id) @@ -496,15 +514,15 @@ def _get_connection_for_routing_key(self, routing_key=None): return random.choice(active_connections) return random.choice(list(self._connections.values())) - def borrow_connection(self, timeout, routing_key=None): - conn = self._get_connection_for_routing_key(routing_key) + def borrow_connection(self, timeout, routing_key=None, keyspace=None, table=None): + conn = self._get_connection_for_routing_key(routing_key, keyspace, table) start = time.time() 
         remaining = timeout
         last_retry = False
         while True:
             if conn.is_closed:
                 # The connection might have been closed in the meantime - if so, try again
-                conn = self._get_connection_for_routing_key(routing_key)
+                conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
             with conn.lock:
                 if (not conn.is_closed or last_retry) and conn.in_flight < conn.max_request_id:
                     # On last retry we ignore connection status, since it is better to return closed connection than
diff --git a/cassandra/query.py b/cassandra/query.py
index f7a5b8fdf5..e0d6f87fd6 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -253,6 +253,13 @@ class Statement(object):
     .. versionadded:: 2.1.3
     """
 
+    table = None
+    """
+    The string name of the table this query acts on. This is used when the tablets
+    experimental feature is enabled and, at the same time, :class:`~.TokenAwarePolicy`
+    is configured in the profile's load balancing policy.
+    """
+
     custom_payload = None
     """
     :ref:`custom_payload` to be passed to the server.
@@ -272,7 +279,7 @@ class Statement(object):
 
     def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
                  serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
-                 is_idempotent=False):
+                 is_idempotent=False, table=None):
         if retry_policy and not hasattr(retry_policy, 'on_read_timeout'):  # just checking one method to detect positional parameter errors
             raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
         if retry_policy is not None:
@@ -286,6 +293,8 @@ def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
         self.fetch_size = fetch_size
         if keyspace is not None:
             self.keyspace = keyspace
+        if table is not None:
+            self.table = table
         if custom_payload is not None:
             self.custom_payload = custom_payload
         self.is_idempotent = is_idempotent
@@ -548,6 +557,7 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None
         meta = prepared_statement.column_metadata
         if meta:
             self.keyspace = meta[0].keyspace_name
+            self.table = meta[0].table_name
 
         Statement.__init__(self, retry_policy, consistency_level, routing_key,
                            serial_consistency_level, fetch_size, keyspace, custom_payload,
diff --git a/cassandra/tablets.py b/cassandra/tablets.py
new file mode 100644
index 0000000000..aeba7fa8ad
--- /dev/null
+++ b/cassandra/tablets.py
@@ -0,0 +1,107 @@
+# Experimental, this interface and use may change
+from threading import Lock
+
+class Tablet(object):
+    """
+    Represents a single ScyllaDB tablet.
+    It stores information about each replica, its host and shard,
+    and the token interval in the format (first_token, last_token].
+    """
+    first_token = 0
+    last_token = 0
+    replicas = None
+
+    def __init__(self, first_token = 0, last_token = 0, replicas = None):
+        self.first_token = first_token
+        self.last_token = last_token
+        self.replicas = replicas
+
+    def __str__(self):
+        return "<Tablet: first_token=%s last_token=%s replicas=%s>" \
+            % (self.first_token, self.last_token, self.replicas)
+    __repr__ = __str__
+
+    @staticmethod
+    def _is_valid_tablet(replicas):
+        return replicas is not None and len(replicas) != 0
+
+    @staticmethod
+    def from_row(first_token, last_token, replicas):
+        if Tablet._is_valid_tablet(replicas):
+            tablet = Tablet(first_token, last_token, replicas)
+            return tablet
+        return None
+
+# Experimental, this interface and use may change
+class Tablets(object):
+    _lock = None
+    _tablets = {}
+
+    def __init__(self, tablets):
+        self._tablets = tablets
+        self._lock = Lock()
+
+    def get_tablet_for_key(self, keyspace, table, t):
+        tablet = self._tablets.get((keyspace, table), [])
+        if tablet == []:
+            return None
+
+        id = bisect_left(tablet, t.value, key = lambda tablet: tablet.last_token)
+        if id < len(tablet) and t.value > tablet[id].first_token:
+            return tablet[id]
+        return None
+
+    def add_tablet(self, keyspace, table, tablet):
+        with self._lock:
+            tablets_for_table = self._tablets.setdefault((keyspace, table), [])
+
+            # find first overlapping range
+            start = bisect_left(tablets_for_table, tablet.first_token, key = lambda t: t.first_token)
+            if start > 0 and tablets_for_table[start - 1].last_token > tablet.first_token:
+                start = start - 1
+
+            # find last overlapping range
+            end = bisect_left(tablets_for_table, tablet.last_token, key = lambda t: t.last_token)
+            if end < len(tablets_for_table) and tablets_for_table[end].first_token >= tablet.last_token:
+                end = end - 1
+
+            if start <= end:
+                del tablets_for_table[start:end + 1]
+
+            tablets_for_table.insert(start, tablet)
+
+# bisect.bisect_left implementation from Python 3.11, needed until support for
+# Python < 3.10 is dropped; `key` is used to extract last_token from the Tablet
+# list - a better solution performance-wise than materializing a list of last_tokens
+def bisect_left(a, x, lo=0, hi=None, *, key=None):
+    """Return the index where to insert item x in list a, assuming a is sorted.
+
+    The return value i is such that all e in a[:i] have e < x, and all e in
+    a[i:] have e >= x. So if x already appears in the list, a.insert(i, x) will
+    insert just before the leftmost x already there.
+
+    Optional args lo (default 0) and hi (default len(a)) bound the
+    slice of a to be searched.
+    """
+
+    if lo < 0:
+        raise ValueError('lo must be non-negative')
+    if hi is None:
+        hi = len(a)
+    # Note, the comparison uses "<" to match the
+    # __lt__() logic in list.sort() and in heapq.
+ if key is None: + while lo < hi: + mid = (lo + hi) // 2 + if a[mid] < x: + lo = mid + 1 + else: + hi = mid + else: + while lo < hi: + mid = (lo + hi) // 2 + if key(a[mid]) < x: + lo = mid + 1 + else: + hi = mid + return lo diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index a6c63dcfdc..d9ff59fd7a 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -526,6 +526,7 @@ class TokenAwarePolicyTest(unittest.TestCase): def test_wrap_round_robin(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) + cluster.control_connection._tablets_routing_v1 = False hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -557,6 +558,7 @@ def get_replicas(keyspace, packed_key): def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) + cluster.control_connection._tablets_routing_v1 = False hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] for host in hosts: host.set_up() @@ -685,6 +687,7 @@ def test_statement_keyspace(self): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) + cluster.control_connection._tablets_routing_v1 = False replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas @@ -775,6 +778,7 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): cluster = Mock(spec=Cluster) cluster.metadata = Mock(spec=Metadata) + cluster.control_connection._tablets_routing_v1 = False replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas @@ -1448,6 +1452,7 @@ def test_query_plan_deferred_to_child(self): def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) + cluster.control_connection._tablets_routing_v1 = False hosts = [Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in range(1, 6)] for host in hosts: host.set_up() diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 4e212a0355..29cddec7a8 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -75,7 +75,7 @@ def test_result_message(self): rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') - pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY) + pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) @@ -257,7 +257,7 @@ def test_retry_policy_says_retry(self): rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') - pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY) + pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) result = Mock(spec=UnavailableErrorMessage, info={}) @@ -276,7 +276,7 @@ def test_retry_policy_says_retry(self): # it should try again with the same host since this was # an UnavailableException rf.session._pools.get.assert_called_with(host) - pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY) + pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, 
cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) def test_retry_with_different_host(self): @@ -291,7 +291,7 @@ def test_retry_with_different_host(self): rf.send_request() rf.session._pools.get.assert_called_once_with('ip1') - pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY) + pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) @@ -310,7 +310,7 @@ def test_retry_with_different_host(self): # it should try with a different host rf.session._pools.get.assert_called_with('ip2') - pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY) + pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) # the consistency level should be the same From c3f194b4508b82a8d4e46ddbec5008d5b0d05c2f Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Tue, 1 Aug 2023 09:27:46 +0200 Subject: [PATCH 3/4] Add integration and unit tests --- .github/workflows/integration-tests.yml | 7 + ci/run_integration_test.sh | 5 +- tests/integration/__init__.py | 10 +- tests/integration/experiments/test_tablets.py | 156 ++++++++++++++++++ tests/unit/test_policies.py | 3 +- tests/unit/test_response_future.py | 1 + tests/unit/test_tablets.py | 88 ++++++++++ 7 files changed, 262 insertions(+), 8 deletions(-) create mode 100644 tests/integration/experiments/test_tablets.py create mode 100644 tests/unit/test_tablets.py diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a8ee628a8d..d263b52057 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -32,4 +32,11 @@ jobs: - name: Test with pytest run: | export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }} + export SCYLLA_VERSION='release:5.1' ./ci/run_integration_test.sh tests/integration/standard/ tests/integration/cqlengine/ + + - name: Test tablets + run: | + export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }} + export SCYLLA_VERSION='unstable/master:2024-01-03T08:06:57Z' + ./ci/run_integration_test.sh tests/integration/experiments/ diff --git a/ci/run_integration_test.sh b/ci/run_integration_test.sh index b064b45399..2796a33e61 100755 --- a/ci/run_integration_test.sh +++ b/ci/run_integration_test.sh @@ -15,8 +15,6 @@ if (( aio_max_nr != aio_max_nr_recommended_value )); then fi fi -SCYLLA_RELEASE='release:5.1' - python3 -m venv .test-venv source .test-venv/bin/activate pip install -U pip wheel setuptools @@ -33,12 +31,11 @@ pip install https://github.com/scylladb/scylla-ccm/archive/master.zip # download version -ccm create scylla-driver-temp -n 1 --scylla --version ${SCYLLA_RELEASE} +ccm create scylla-driver-temp -n 1 --scylla --version ${SCYLLA_VERSION} ccm remove # run test -export SCYLLA_VERSION=${SCYLLA_RELEASE} export MAPPED_SCYLLA_VERSION=3.11.4 PROTOCOL_VERSION=4 pytest -rf --import-mode append $* diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e728bc7740..52e8b5dad4 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -372,7 
+372,8 @@ def _id_and_mark(f): # 1. unittest doesn't skip setUpClass when used on class and we need it sometimes # 2. unittest doesn't have conditional xfail, and I prefer to use pytest than custom decorator # 3. unittest doesn't have a reason argument, so you don't see the reason in pytest report -requires_collection_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None and Version(SCYLLA_VERSION.split(':')[1]) < Version('5.2'), +# TODO remove second check when we stop using unstable version in CI for tablets +requires_collection_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None and (len(SCYLLA_VERSION.split('/')) != 0 or Version(SCYLLA_VERSION.split(':')[1]) < Version('5.2')), reason='Scylla supports collection indexes from 5.2 onwards') requires_custom_indexes = pytest.mark.skipif(SCYLLA_VERSION is not None, reason='Scylla does not support SASI or any other CUSTOM INDEX class') @@ -501,7 +502,7 @@ def start_cluster_wait_for_up(cluster): def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, set_keyspace=True, ccm_options=None, - configuration_options=None, dse_options=None, use_single_interface=USE_SINGLE_INTERFACE): + configuration_options=None, dse_options=None, use_single_interface=USE_SINGLE_INTERFACE, use_tablets=False): configuration_options = configuration_options or {} dse_options = dse_options or {} workloads = workloads or [] @@ -611,7 +612,10 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, # CDC is causing an issue (can't start cluster with multiple seeds) # Selecting only features we need for tests, i.e. anything but CDC. CCM_CLUSTER = CCMScyllaCluster(path, cluster_name, **ccm_options) - CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf'], 'start_native_transport': True}) + if use_tablets: + CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf', 'consistent-topology-changes', 'tablets'], 'start_native_transport': True}) + else: + CCM_CLUSTER.set_configuration_options({'experimental_features': ['lwt', 'udf'], 'start_native_transport': True}) # Permit IS NOT NULL restriction on non-primary key columns of a materialized view # This allows `test_metadata_with_quoted_identifiers` to run diff --git a/tests/integration/experiments/test_tablets.py b/tests/integration/experiments/test_tablets.py new file mode 100644 index 0000000000..c9e5c3ea3c --- /dev/null +++ b/tests/integration/experiments/test_tablets.py @@ -0,0 +1,156 @@ +import time +import unittest +import pytest +import os +from cassandra.cluster import Cluster +from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy + +from tests.integration import PROTOCOL_VERSION, use_cluster +from tests.unit.test_host_connection_pool import LOGGER + +def setup_module(): + use_cluster('tablets', [3], start=True, use_tablets=True) + +class TestTabletsIntegration(unittest.TestCase): + @classmethod + def setup_class(cls): + cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + reconnection_policy=ConstantReconnectionPolicy(1)) + cls.session = cls.cluster.connect() + cls.create_ks_and_cf(cls) + cls.create_data(cls.session) + + @classmethod + def teardown_class(cls): + cls.cluster.shutdown() + + def verify_same_host_in_tracing(self, results): + traces = results.get_query_trace() + events = traces.events + host_set = set() + for event in events: + LOGGER.info("TRACE 
EVENT: %s %s %s", event.source, event.thread_name, event.description) + host_set.add(event.source) + + self.assertEqual(len(host_set), 1) + self.assertIn('locally', "\n".join([event.description for event in events])) + + trace_id = results.response_future.get_query_trace_ids()[0] + traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,)) + events = [event for event in traces] + host_set = set() + for event in events: + LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity) + host_set.add(event.source) + + self.assertEqual(len(host_set), 1) + self.assertIn('locally', "\n".join([event.activity for event in events])) + + def verify_same_shard_in_tracing(self, results): + traces = results.get_query_trace() + events = traces.events + shard_set = set() + for event in events: + LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description) + shard_set.add(event.thread_name) + + self.assertEqual(len(shard_set), 1) + self.assertIn('locally', "\n".join([event.description for event in events])) + + trace_id = results.response_future.get_query_trace_ids()[0] + traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,)) + events = [event for event in traces] + shard_set = set() + for event in events: + LOGGER.info("TRACE EVENT: %s %s", event.thread, event.activity) + shard_set.add(event.thread) + + self.assertEqual(len(shard_set), 1) + self.assertIn('locally', "\n".join([event.activity for event in events])) + + def create_ks_and_cf(self): + self.session.execute( + """ + DROP KEYSPACE IF EXISTS test1 + """ + ) + self.session.execute( + """ + CREATE KEYSPACE test1 + WITH replication = { + 'class': 'NetworkTopologyStrategy', + 'replication_factor': 1, + 'initial_tablets': 8 + } + """) + + self.session.execute( + """ + CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck)); + """) + + @staticmethod + def create_data(session): + prepared = session.prepare( + """ + INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?) + """) + + for i in range(50): + bound = prepared.bind((i, i%5, i%2)) + session.execute(bound) + + def query_data_shard_select(self, session, verify_in_tracing=True): + prepared = session.prepare( + """ + SELECT pk, ck, v FROM test1.table1 WHERE pk = ? + """) + + bound = prepared.bind([(2)]) + results = session.execute(bound, trace=True) + self.assertEqual(results, [(2, 2, 0)]) + if verify_in_tracing: + self.verify_same_shard_in_tracing(results) + + def query_data_host_select(self, session, verify_in_tracing=True): + prepared = session.prepare( + """ + SELECT pk, ck, v FROM test1.table1 WHERE pk = ? + """) + + bound = prepared.bind([(2)]) + results = session.execute(bound, trace=True) + self.assertEqual(results, [(2, 2, 0)]) + if verify_in_tracing: + self.verify_same_host_in_tracing(results) + + def query_data_shard_insert(self, session, verify_in_tracing=True): + prepared = session.prepare( + """ + INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?) + """) + + bound = prepared.bind([(51), (1), (2)]) + results = session.execute(bound, trace=True) + if verify_in_tracing: + self.verify_same_shard_in_tracing(results) + + def query_data_host_insert(self, session, verify_in_tracing=True): + prepared = session.prepare( + """ + INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?) 
+ """) + + bound = prepared.bind([(52), (1), (2)]) + results = session.execute(bound, trace=True) + if verify_in_tracing: + self.verify_same_host_in_tracing(results) + + def test_tablets(self): + self.query_data_host_select(self.session) + self.query_data_host_insert(self.session) + + def test_tablets_shard_awareness(self): + self.query_data_shard_select(self.session) + self.query_data_shard_insert(self.session) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index d9ff59fd7a..e60940afac 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -24,7 +24,7 @@ from threading import Thread from cassandra import ConsistencyLevel -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, ControlConnection from cassandra.metadata import Metadata from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, @@ -601,6 +601,7 @@ def get_replicas(keyspace, packed_key): class FakeCluster: def __init__(self): self.metadata = Mock(spec=Metadata) + self.control_connection = Mock(spec=ControlConnection) def test_get_distance(self): """ diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 29cddec7a8..d1a7ce4a9f 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -40,6 +40,7 @@ class ResponseFutureTests(unittest.TestCase): def make_basic_session(self): s = Mock(spec=Session) s.row_factory = lambda col_names, rows: [(col_names, rows)] + s.cluster.control_connection._tablets_routing_v1 = False return s def make_pool(self): diff --git a/tests/unit/test_tablets.py b/tests/unit/test_tablets.py new file mode 100644 index 0000000000..3bbba06918 --- /dev/null +++ b/tests/unit/test_tablets.py @@ -0,0 +1,88 @@ +import unittest + +from cassandra.tablets import Tablets, Tablet + +class TabletsTest(unittest.TestCase): + def compare_ranges(self, tablets, ranges): + self.assertEqual(len(tablets), len(ranges)) + + for idx, tablet in enumerate(tablets): + self.assertEqual(tablet.first_token, ranges[idx][0], "First token is not correct in tablet: {}".format(tablet)) + self.assertEqual(tablet.last_token, ranges[idx][1], "Last token is not correct in tablet: {}".format(tablet)) + + def test_add_tablet_to_empty_tablets(self): + tablets = Tablets({("test_ks", "test_tb"): []}) + + tablets.add_tablet("test_ks", "test_tb", Tablet(-6917529027641081857, -4611686018427387905, None)) + + tablets_list = tablets._tablets.get(("test_ks", "test_tb")) + + self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905)]) + + def test_add_tablet_at_the_beggining(self): + tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None)]}) + + tablets.add_tablet("test_ks", "test_tb", Tablet(-8611686018427387905, -7917529027641081857, None)) + + tablets_list = tablets._tablets.get(("test_ks", "test_tb")) + + self.compare_ranges(tablets_list, [(-8611686018427387905, -7917529027641081857), + (-6917529027641081857, -4611686018427387905)]) + + def test_add_tablet_at_the_end(self): + tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None)]}) + + tablets.add_tablet("test_ks", "test_tb", Tablet(-1, 2305843009213693951, None)) + + tablets_list = tablets._tablets.get(("test_ks", "test_tb")) + + self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905), + (-1, 2305843009213693951)]) + + def 
test_add_tablet_in_the_middle(self):
+        tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None),
+                                                    Tablet(-1, 2305843009213693951, None)]},)
+
+        tablets.add_tablet("test_ks", "test_tb", Tablet(-4611686018427387905, -2305843009213693953, None))
+
+        tablets_list = tablets._tablets.get(("test_ks", "test_tb"))
+
+        self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905),
+                                           (-4611686018427387905, -2305843009213693953),
+                                           (-1, 2305843009213693951)])
+
+    def test_add_tablet_intersecting(self):
+        tablets = Tablets({("test_ks", "test_tb"): [Tablet(-6917529027641081857, -4611686018427387905, None),
+                                                    Tablet(-4611686018427387905, -2305843009213693953, None),
+                                                    Tablet(-2305843009213693953, -1, None),
+                                                    Tablet(-1, 2305843009213693951, None)]})
+
+        tablets.add_tablet("test_ks", "test_tb", Tablet(-3611686018427387905, -6, None))
+
+        tablets_list = tablets._tablets.get(("test_ks", "test_tb"))
+
+        self.compare_ranges(tablets_list, [(-6917529027641081857, -4611686018427387905),
+                                           (-3611686018427387905, -6),
+                                           (-1, 2305843009213693951)])
+
+    def test_add_tablet_intersecting_with_first(self):
+        tablets = Tablets({("test_ks", "test_tb"): [Tablet(-8611686018427387905, -7917529027641081857, None),
+                                                    Tablet(-6917529027641081857, -4611686018427387905, None)]})
+
+        tablets.add_tablet("test_ks", "test_tb", Tablet(-8011686018427387905, -7987529027641081857, None))
+
+        tablets_list = tablets._tablets.get(("test_ks", "test_tb"))
+
+        self.compare_ranges(tablets_list, [(-8011686018427387905, -7987529027641081857),
+                                           (-6917529027641081857, -4611686018427387905)])
+
+    def test_add_tablet_intersecting_with_last(self):
+        tablets = Tablets({("test_ks", "test_tb"): [Tablet(-8611686018427387905, -7917529027641081857, None),
+                                                    Tablet(-6917529027641081857, -4611686018427387905, None)]})
+
+        tablets.add_tablet("test_ks", "test_tb", Tablet(-5011686018427387905, -2987529027641081857, None))
+
+        tablets_list = tablets._tablets.get(("test_ks", "test_tb"))
+
+        self.compare_ranges(tablets_list, [(-8611686018427387905, -7917529027641081857),
+                                           (-5011686018427387905, -2987529027641081857)])

From eaa9eb1f9d2ffbc8e0d007643013091e0301c902 Mon Sep 17 00:00:00 2001
From: Sylwia Szunejko
Date: Thu, 11 Jan 2024 18:36:02 +0100
Subject: [PATCH 4/4] Add documentation of tablet awareness

---
 README.rst               |  1 +
 docs/scylla-specific.rst | 13 +++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/README.rst b/README.rst
index b1833a8fc5..2a3dc73f33 100644
--- a/README.rst
+++ b/README.rst
@@ -26,6 +26,7 @@ Features
 * `Concurrent execution utilities `_
 * `Object mapper `_
 * `Shard awareness `_
+* `Tablet awareness `_
 
 Installation
 ------------
diff --git a/docs/scylla-specific.rst b/docs/scylla-specific.rst
index f830235088..87fcf01aa3 100644
--- a/docs/scylla-specific.rst
+++ b/docs/scylla-specific.rst
@@ -109,3 +109,16 @@ New Error Types
             self.session.execute(prepared.bind((123, 456)))
         except RateLimitReached:
             raise
+
+
+Tablet Awareness
+----------------
+
+**scylla-driver** is tablet aware, which means that it is able to parse the `TABLETS_ROUTING_V1` extension to ProtocolFeatures, receive tablet information sent by Scylla in the `custom_payload` part of the `RESULT` message, and utilize it.
+Thanks to that, queries to tablet-based tables are still shard aware.
+
+Details on the Scylla CQL protocol extensions:
+https://github.com/scylladb/scylladb/blob/master/docs/dev/protocol-extensions.md#negotiate-sending-tablets-info-to-the-drivers
+
+Details on sending tablet information to the drivers:
+https://github.com/scylladb/scylladb/blob/master/docs/dev/protocol-extensions.md#sending-tablet-info-to-the-drivers
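For illustration, a minimal usage sketch modeled on the integration test added
in PATCH 3/4 (not part of any patch above). The contact point, keyspace and
table names are assumptions, and the cluster must be a ScyllaDB build started
with the tablets experimental feature enabled; tablet-aware routing needs no
API changes beyond configuring TokenAwarePolicy:

    from cassandra.cluster import Cluster
    from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy

    # TABLETS_ROUTING_V1 is negotiated automatically during the connection
    # handshake; token- and shard-aware routing only needs a TokenAwarePolicy.
    cluster = Cluster(contact_points=["127.0.0.1"],
                      load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
    session = cluster.connect()

    session.execute("""
        CREATE KEYSPACE IF NOT EXISTS test1
        WITH replication = {'class': 'NetworkTopologyStrategy',
                            'replication_factor': 1,
                            'initial_tablets': 8}
    """)
    session.execute("CREATE TABLE IF NOT EXISTS test1.table1 "
                    "(pk int, ck int, v int, PRIMARY KEY (pk, ck))")

    # Prepared statements carry the keyspace and table name, so once tablet
    # info for table1 has arrived in a RESULT custom_payload, later requests
    # for the same partitions go to a replica (and shard) that owns them.
    insert = session.prepare("INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)")
    for pk in range(10):
        session.execute(insert, (pk, pk % 5, pk % 2))

    cluster.shutdown()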