Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invalidate tablets when table or keyspace is deleted #399

Merged
merged 2 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import struct
import random
import itertools
from typing import Optional

murmur3 = None
try:
Expand Down Expand Up @@ -168,10 +169,13 @@ def _rebuild_all(self, parser):
current_keyspaces = set()
for keyspace_meta in parser.get_all_keyspaces():
current_keyspaces.add(keyspace_meta.name)
old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
old_keyspace_meta: Optional[KeyspaceMetadata] = self.keyspaces.get(keyspace_meta.name, None)
self.keyspaces[keyspace_meta.name] = keyspace_meta
if old_keyspace_meta:
self._keyspace_updated(keyspace_meta.name)
for table_name in old_keyspace_meta.tables.keys():
if table_name not in keyspace_meta.tables:
self._table_removed(keyspace_meta.name, table_name)
else:
self._keyspace_added(keyspace_meta.name)

Expand Down Expand Up @@ -265,17 +269,22 @@ def _drop_aggregate(self, keyspace, aggregate):
except KeyError:
pass

def _table_removed(self, keyspace, table):
self._tablets.drop_tablets(keyspace, table)

def _keyspace_added(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)

def _keyspace_updated(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
self._tablets.drop_tablets(ksname)

def _keyspace_removed(self, ksname):
if self.token_map:
self.token_map.remove_keyspace(ksname)
self._tablets.drop_tablets(ksname)

def rebuild_token_map(self, partitioner, token_map):
"""
Expand Down Expand Up @@ -340,11 +349,13 @@ def add_or_return_host(self, host):
return host, True

def remove_host(self, host):
    """Forget ``host`` and every cached tablet that lists it as a replica.

    Returns the truthiness of the entry removed from ``_hosts`` (``False``
    when the host id was not registered).
    """
    host_id = host.host_id

    # Tablets are purged before the host maps are touched, outside the lock.
    self._tablets.drop_tablets_by_host_id(host_id)

    with self._hosts_lock:
        self._host_id_by_endpoint.pop(host.endpoint, False)
        removed_entry = self._hosts.pop(host_id, False)
        return bool(removed_entry)

def remove_host_by_host_id(self, host_id, endpoint=None):
self._tablets.drop_tablets_by_host_id(host_id)
with self._hosts_lock:
if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
self._host_id_by_endpoint.pop(endpoint, False)
Expand Down
35 changes: 35 additions & 0 deletions cassandra/tablets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from threading import Lock
from typing import Optional
from uuid import UUID


class Tablet(object):
Expand Down Expand Up @@ -32,6 +34,12 @@ def from_row(first_token, last_token, replicas):
return tablet
return None

def replica_contains_host_id(self, uuid: UUID) -> bool:
    """Return True when any replica of this tablet is hosted on ``uuid``.

    Each replica is indexed as a sequence whose first element is compared
    against ``uuid``. A tablet without replicas (``None`` or empty) never
    matches instead of raising ``TypeError``.
    """
    return any(replica[0] == uuid for replica in (self.replicas or ()))


class Tablets(object):
_lock = None
Expand All @@ -51,6 +59,33 @@ def get_tablet_for_key(self, keyspace, table, t):
return tablet[id]
return None

def drop_tablets(self, keyspace: str, table: Optional[str] = None):
    """Drop cached tablets for one table, or for a whole keyspace.

    :param keyspace: keyspace whose tablets should be forgotten.
    :param table: when given, only the ``(keyspace, table)`` entry is
        removed; when ``None``, every entry of ``keyspace`` is removed.

    Unknown keyspaces or tables are ignored.
    """
    with self._lock:
        if table is not None:
            self._tablets.pop((keyspace, table), None)
            return

        # Collect first: a dict must not be resized while it is iterated.
        stale_keys = [key for key in self._tablets if key[0] == keyspace]
        for key in stale_keys:
            del self._tablets[key]

def drop_tablets_by_host_id(self, host_id: Optional[UUID]):
    """Drop every cached tablet that has a replica on ``host_id``.

    :param host_id: id of the host that went away; ``None`` is a no-op.

    Each per-table list is filtered in place, so the list objects stored
    in ``_tablets`` keep their identity. Tables left with no tablets keep
    an empty list.
    """
    if host_id is None:
        return
    with self._lock:
        for tablets in self._tablets.values():
            # Slice assignment mutates the existing list in one pass
            # instead of popping matching indices one by one.
            tablets[:] = [
                tablet for tablet in tablets
                if not tablet.replica_contains_host_id(host_id)
            ]

def add_tablet(self, keyspace, table, tablet):
with self._lock:
tablets_for_table = self._tablets.setdefault((keyspace, table), [])
Expand Down
123 changes: 100 additions & 23 deletions tests/integration/experiments/test_tablets.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
import time
import unittest

import pytest
import os

from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy

from tests.integration import PROTOCOL_VERSION, use_cluster
from tests.unit.test_host_connection_pool import LOGGER

CCM_CLUSTER = None

def setup_module():
use_cluster('tablets', [3], start=True)
global CCM_CLUSTER

CCM_CLUSTER = use_cluster('tablets', [3], start=True)

class TestTabletsIntegration(unittest.TestCase):

class TestTabletsIntegration:
@classmethod
def setup_class(cls):
cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_ks_and_cf(cls.session)
cls.create_data(cls.session)

@classmethod
def teardown_class(cls):
cls.cluster.shutdown()

def verify_same_host_in_tracing(self, results):
def verify_hosts_in_tracing(self, results, expected):
traces = results.get_query_trace()
events = traces.events
host_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
host_set.add(event.source)

self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.description for event in events]))
assert len(host_set) == expected
assert 'locally' in "\n".join([event.description for event in events])

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
Expand All @@ -44,8 +49,12 @@ def verify_same_host_in_tracing(self, results):
LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
host_set.add(event.source)

self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))
assert len(host_set) == expected
assert 'locally' in "\n".join([event.activity for event in events])

def get_tablet_record(self, query):
    """Look up the tablet the driver has cached for ``query``'s routing key.

    Returns whatever the driver's tablet cache yields for the query's
    keyspace, table and routing token.
    """
    cluster_meta = self.session.cluster.metadata
    token = cluster_meta.token_map.token_class.from_key(query.routing_key)
    tablet_cache = cluster_meta._tablets
    return tablet_cache.get_tablet_for_key(query.keyspace, query.table, token)

def verify_same_shard_in_tracing(self, results):
traces = results.get_query_trace()
Expand All @@ -55,8 +64,8 @@ def verify_same_shard_in_tracing(self, results):
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
shard_set.add(event.thread_name)

self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.description for event in events]))
assert len(shard_set) == 1
assert 'locally' in "\n".join([event.description for event in events])

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
Expand All @@ -66,27 +75,28 @@ def verify_same_shard_in_tracing(self, results):
LOGGER.info("TRACE EVENT: %s %s", event.thread, event.activity)
shard_set.add(event.thread)

self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))
assert len(shard_set) == 1
assert 'locally' in "\n".join([event.activity for event in events])

def create_ks_and_cf(self):
self.session.execute(
@classmethod
def create_ks_and_cf(cls, session):
session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
'class': 'NetworkTopologyStrategy',
'replication_factor': 1
'replication_factor': 2
} AND tablets = {
'initial': 8
}
""")

self.session.execute(
session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")
Expand All @@ -110,7 +120,7 @@ def query_data_shard_select(self, session, verify_in_tracing=True):

bound = prepared.bind([(2)])
results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
assert results == [(2, 2, 0)]
if verify_in_tracing:
self.verify_same_shard_in_tracing(results)

Expand All @@ -122,9 +132,9 @@ def query_data_host_select(self, session, verify_in_tracing=True):

bound = prepared.bind([(2)])
results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
assert results == [(2, 2, 0)]
if verify_in_tracing:
self.verify_same_host_in_tracing(results)
self.verify_hosts_in_tracing(results, 1)

def query_data_shard_insert(self, session, verify_in_tracing=True):
prepared = session.prepare(
Expand All @@ -146,7 +156,7 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
bound = prepared.bind([(52), (1), (2)])
results = session.execute(bound, trace=True)
if verify_in_tracing:
self.verify_same_host_in_tracing(results)
self.verify_hosts_in_tracing(results, 2)

def test_tablets(self):
self.query_data_host_select(self.session)
Expand All @@ -155,3 +165,70 @@ def test_tablets(self):
def test_tablets_shard_awareness(self):
    """Run the shard-level SELECT check, then the shard-level INSERT check."""
    session = self.session
    self.query_data_shard_select(session)
    self.query_data_shard_insert(session)

def test_tablets_invalidation_drop_ks_while_reconnecting(self):
    """Tablets must be invalidated when the schema changes while the
    control connection is down and is then re-established."""
    def recreate_while_reconnecting(_tablet):
        control = self.session.cluster.control_connection

        # Kill the control connection: detach it first, then close it.
        live_connection = control._connection
        control._connection = None
        live_connection.close()

        # Drop and recreate the keyspace and table through a new session
        # to trigger tablets invalidation.
        fresh_session = self.cluster.connect()
        self.create_ks_and_cf(fresh_session)

        # Bring the control connection back.
        control._reconnect()

    self.run_tablets_invalidation_test(recreate_while_reconnecting)

def test_tablets_invalidation_drop_ks(self):
    """Tablets must be invalidated after the keyspace is dropped and recreated."""
    def recreate_schema(_tablet):
        # Drop and recreate the keyspace and table through a new session
        # to trigger tablets invalidation.
        fresh_session = self.cluster.connect()
        self.create_ks_and_cf(fresh_session)
        # NOTE(review): presumably leaves time for the schema change to
        # reach the driver's metadata - confirm.
        time.sleep(3)

    self.run_tablets_invalidation_test(recreate_schema)

@pytest.mark.last
def test_tablets_invalidation_decommission_non_cc_node(self):
    """Tablets must be invalidated when one of their replica nodes is
    decommissioned.

    The node serving the control connection is never picked.
    """
    def decommission_non_cc_node(rec):
        # Decommission the first node that holds a replica of the tablet
        # and is not the control-connection node.
        cc_address = self.cluster.control_connection._connection.endpoint.address
        replica_host_ids = {str(replica[0]) for replica in rec.replicas}

        for node in CCM_CLUSTER.nodes.values():
            if cc_address == node.network_interfaces["storage"][0]:
                # Ignore node that control connection is connected to
                continue
            if str(node.node_hostid) in replica_host_ids:
                node.decommission()
                break
        else:
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError("failed to find node to decommission")

        # NOTE(review): presumably leaves time for the topology change to
        # reach the driver - confirm.
        time.sleep(10)

    self.run_tablets_invalidation_test(decommission_non_cc_node)


def run_tablets_invalidation_test(self, invalidate):
    """Shared body of the tablet-invalidation tests.

    Runs a prepared SELECT against each host in turn until the driver has
    a tablet record for it, calls ``invalidate`` with that record, then
    asserts the record is gone.

    :param invalidate: callable taking the tablet record; it performs the
        action that is expected to purge the driver's tablet cache.
    """
    prepared = self.session.prepare(
        "\nSELECT pk, ck, v FROM test1.table1 WHERE pk = ?\n")
    statement = prepared.bind([2])

    # Query host by host; stop at the first one after which a tablet
    # record is available.
    tablet = None
    for target in self.cluster.metadata.all_hosts():
        self.session.execute(statement, host=target)
        tablet = self.get_tablet_record(statement)
        if tablet is not None:
            break

    assert tablet is not None, "failed to find tablet record"

    invalidate(tablet)

    # The record must have been purged by the invalidation.
    remaining = self.get_tablet_record(statement)
    assert remaining is None, "tablet was not deleted, invalidation did not work"
Loading