Skip to content

Commit

Permalink
Merge pull request #249 from sylwiaszunejko/introduce_tablets
Browse files Browse the repository at this point in the history
Introduce support for tablets
  • Loading branch information
avelanarius authored Jan 16, 2024
2 parents a27751b + eaa9eb1 commit dfccfff
Show file tree
Hide file tree
Showing 16 changed files with 486 additions and 25 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,11 @@ jobs:
- name: Test with pytest
run: |
export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }}
export SCYLLA_VERSION='release:5.1'
./ci/run_integration_test.sh tests/integration/standard/ tests/integration/cqlengine/
- name: Test tablets
run: |
export EVENT_LOOP_MANAGER=${{ matrix.event_loop_manager }}
export SCYLLA_VERSION='unstable/master:2024-01-03T08:06:57Z'
./ci/run_integration_test.sh tests/integration/experiments/
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Features
* `Concurrent execution utilities <http://python-driver.docs.scylladb.com/stable/api/cassandra/concurrent.html>`_
* `Object mapper <http://python-driver.docs.scylladb.com/stable/object-mapper.html>`_
* `Shard awareness <http://python-driver.docs.scylladb.com/stable/scylla-specific.html#shard-awareness>`_
* `Tablet awareness <http://python-driver.docs.scylladb.com/stable/scylla-specific.html#tablet-awareness>`_

Installation
------------
Expand Down
34 changes: 31 additions & 3 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import weakref
from weakref import WeakValueDictionary

from cassandra import (ConsistencyLevel, AuthenticationFailed,
from cassandra import (ConsistencyLevel, AuthenticationFailed, InvalidRequest,
OperationTimedOut, UnsupportedOperation,
SchemaTargetType, DriverException, ProtocolVersion,
UnresolvableContactPoints)
Expand All @@ -51,6 +51,7 @@
EndPoint, DefaultEndPoint, DefaultEndPointFactory,
ContinuousPagingState, SniEndPointFactory, ConnectionBusy)
from cassandra.cqltypes import UserType
import cassandra.cqltypes as types
from cassandra.encoder import Encoder
from cassandra.protocol import (QueryMessage, ResultMessage,
ErrorMessage, ReadTimeoutErrorMessage,
Expand Down Expand Up @@ -79,6 +80,7 @@
named_tuple_factory, dict_factory, tuple_factory, FETCH_SIZE_UNSET,
HostTargetingStatement)
from cassandra.marshal import int64_pack
from cassandra.tablets import Tablet, Tablets
from cassandra.timestamps import MonotonicTimestampGenerator
from cassandra.compat import Mapping
from cassandra.util import _resolve_contact_points_to_string_map, Version
Expand Down Expand Up @@ -1784,6 +1786,14 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self.shutdown()
raise

# Update the information about tablet support after connection handshake.
# The flag recorded on the control connection is copied to the top-level
# load balancing policy and then to every wrapped child policy that
# carries the same attribute.
self.load_balancing_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1
child_policy = getattr(self.load_balancing_policy, 'child_policy', None)
while child_policy is not None:
    # The attribute is spelled `_tablets_routing_v1` (plural "tablets");
    # checking the singular spelling would never match, leaving nested
    # policies with a stale value.
    if hasattr(child_policy, '_tablets_routing_v1'):
        child_policy._tablets_routing_v1 = self.control_connection._tablets_routing_v1
    child_policy = getattr(child_policy, 'child_policy', None)

self.profile_manager.check_supported() # todo: rename this method

if self.idle_heartbeat_interval:
Expand Down Expand Up @@ -2398,7 +2408,6 @@ def add_prepared(self, query_id, prepared_statement):
with self._prepared_statement_lock:
self._prepared_statements[query_id] = prepared_statement


class Session(object):
"""
A collection of connection pools for each host in the cluster.
Expand Down Expand Up @@ -3550,6 +3559,7 @@ class PeersQueryType(object):
_schema_meta_page_size = 1000

_uses_peers_v2 = True
_tablets_routing_v1 = False

# for testing purposes
_time = time
Expand Down Expand Up @@ -3683,6 +3693,8 @@ def _try_connect(self, host):
# If sharding information is available, it's a ScyllaDB cluster, so do not use peers_v2 table.
if connection.features.sharding_info is not None:
self._uses_peers_v2 = False

self._tablets_routing_v1 = connection.features.tablets_routing_v1

# use weak references in both directions
# _clear_watcher will be called when this ControlConnection is about to be finalized
Expand Down Expand Up @@ -4609,7 +4621,10 @@ def _query(self, host, message=None, cb=None):
connection = None
try:
# TODO get connectTimeout from cluster settings
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key if self.query else None)
if self.query:
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key, keyspace=self.query.keyspace, table=self.query.table)
else:
connection, request_id = pool.borrow_connection(timeout=2.0)
self._connection = connection
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []

Expand Down Expand Up @@ -4728,6 +4743,19 @@ def _set_result(self, host, connection, pool, response):
self._warnings = getattr(response, 'warnings', None)
self._custom_payload = getattr(response, 'custom_payload', None)

if self._custom_payload and self.session.cluster.control_connection._tablets_routing_v1 and 'tablets-routing-v1' in self._custom_payload:
protocol = self.session.cluster.protocol_version
info = self._custom_payload.get('tablets-routing-v1')
ctype = types.lookup_casstype('TupleType(LongType, LongType, ListType(TupleType(UUIDType, Int32Type)))')
tablet_routing_info = ctype.from_binary(info, protocol)
first_token = tablet_routing_info[0]
last_token = tablet_routing_info[1]
tablet_replicas = tablet_routing_info[2]
tablet = Tablet.from_row(first_token, last_token, tablet_replicas)
keyspace = self.query.keyspace
table = self.query.table
self.session.cluster.metadata._tablets.add_tablet(keyspace, table, tablet)

if isinstance(response, ResultMessage):
if response.kind == RESULT_KIND_SET_KEYSPACE:
session = getattr(self, 'session', None)
Expand Down
2 changes: 2 additions & 0 deletions cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from cassandra.pool import HostDistance
from cassandra.connection import EndPoint
from cassandra.compat import Mapping
from cassandra.tablets import Tablets

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(self):
self._hosts = {}
self._host_id_by_endpoint = {}
self._hosts_lock = RLock()
self._tablets = Tablets({})

def export_schema_as_string(self):
"""
Expand Down
16 changes: 15 additions & 1 deletion cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):

_child_policy = None
_cluster_metadata = None
_tablets_routing_v1 = False
shuffle_replicas = False
"""
Yield local replicas in a random order.
Expand All @@ -346,6 +347,7 @@ def __init__(self, child_policy, shuffle_replicas=False):

def populate(self, cluster, hosts):
self._cluster_metadata = cluster.metadata
self._tablets_routing_v1 = cluster.control_connection._tablets_routing_v1
self._child_policy.populate(cluster, hosts)

def check_supported(self):
Expand Down Expand Up @@ -376,7 +378,19 @@ def make_query_plan(self, working_keyspace=None, query=None):
for host in child.make_query_plan(keyspace, query):
yield host
else:
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)
replicas = []
if self._tablets_routing_v1:
tablet = self._cluster_metadata._tablets.get_tablet_for_key(keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(routing_key))

if tablet is not None:
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
child_plan = child.make_query_plan(keyspace, query)

replicas = [host for host in child_plan if host.host_id in replicas_mapped]

if replicas == []:
replicas = self._cluster_metadata.get_replicas(keyspace, routing_key)

if self.shuffle_replicas:
shuffle(replicas)
for replica in replicas:
Expand Down
28 changes: 23 additions & 5 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,8 @@ class HostConnection(object):
# the number below, all excess connections will be closed.
max_excess_connections_per_shard_multiplier = 3

tablets_routing_v1 = False

def __init__(self, host, host_distance, session):
self.host = host
self.host_distance = host_distance
Expand Down Expand Up @@ -436,10 +438,11 @@ def __init__(self, host, host_distance, session):
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
self.host.sharding_info = first_connection.features.sharding_info
self._open_connections_for_all_shards(first_connection.features.shard_id)
self.tablets_routing_v1 = first_connection.features.tablets_routing_v1

log.debug("Finished initializing connection for host %s", self.host)

def _get_connection_for_routing_key(self, routing_key=None):
def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)
Expand All @@ -450,7 +453,22 @@ def _get_connection_for_routing_key(self, routing_key=None):
shard_id = None
if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key:
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)
shard_id = self.host.sharding_info.shard_id_from_token(t.value)

shard_id = None
if self.tablets_routing_v1 and table is not None:
if keyspace is None:
keyspace = self._keyspace

tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t)

if tablet is not None:
for replica in tablet.replicas:
if replica[0] == self.host.host_id:
shard_id = replica[1]
break

if shard_id is None:
shard_id = self.host.sharding_info.shard_id_from_token(t.value)

conn = self._connections.get(shard_id)

Expand Down Expand Up @@ -496,15 +514,15 @@ def _get_connection_for_routing_key(self, routing_key=None):
return random.choice(active_connections)
return random.choice(list(self._connections.values()))

def borrow_connection(self, timeout, routing_key=None):
conn = self._get_connection_for_routing_key(routing_key)
def borrow_connection(self, timeout, routing_key=None, keyspace=None, table=None):
conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
start = time.time()
remaining = timeout
last_retry = False
while True:
if conn.is_closed:
# The connection might have been closed in the meantime - if so, try again
conn = self._get_connection_for_routing_key(routing_key)
conn = self._get_connection_for_routing_key(routing_key, keyspace, table)
with conn.lock:
if (not conn.is_closed or last_retry) and conn.in_flight < conn.max_request_id:
# On last retry we ignore connection status, since it is better to return closed connection than
Expand Down
13 changes: 11 additions & 2 deletions cassandra/protocol_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,26 @@


RATE_LIMIT_ERROR_EXTENSION = "SCYLLA_RATE_LIMIT_ERROR"
TABLETS_ROUTING_V1 = "TABLETS_ROUTING_V1"

class ProtocolFeatures(object):
rate_limit_error = None
shard_id = 0
sharding_info = None
tablets_routing_v1 = False

def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None):
def __init__(self, rate_limit_error=None, shard_id=0, sharding_info=None, tablets_routing_v1=False):
self.rate_limit_error = rate_limit_error
self.shard_id = shard_id
self.sharding_info = sharding_info
self.tablets_routing_v1 = tablets_routing_v1

@staticmethod
def parse_from_supported(supported):
rate_limit_error = ProtocolFeatures.maybe_parse_rate_limit_error(supported)
shard_id, sharding_info = ProtocolFeatures.parse_sharding_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info)
tablets_routing_v1 = ProtocolFeatures.parse_tablets_info(supported)
return ProtocolFeatures(rate_limit_error, shard_id, sharding_info, tablets_routing_v1)

@staticmethod
def maybe_parse_rate_limit_error(supported):
Expand All @@ -43,6 +47,8 @@ def get_cql_extension_field(vals, key):
def add_startup_options(self, options):
if self.rate_limit_error is not None:
options[RATE_LIMIT_ERROR_EXTENSION] = ""
if self.tablets_routing_v1:
options[TABLETS_ROUTING_V1] = ""

@staticmethod
def parse_sharding_info(options):
Expand All @@ -63,3 +69,6 @@ def parse_sharding_info(options):
shard_aware_port, shard_aware_port_ssl)


@staticmethod
def parse_tablets_info(options):
return TABLETS_ROUTING_V1 in options
12 changes: 11 additions & 1 deletion cassandra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ class Statement(object):
.. versionadded:: 2.1.3
"""

table = None
"""
The string name of the table this query acts on. This is used when the
experimental tablets feature is enabled and, at the same time,
:class:`~.TokenAwarePolicy` is configured in the profile load balancing policy.
"""

custom_payload = None
"""
:ref:`custom_payload` to be passed to the server.
Expand All @@ -272,7 +279,7 @@ class Statement(object):

def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
serial_consistency_level=None, fetch_size=FETCH_SIZE_UNSET, keyspace=None, custom_payload=None,
is_idempotent=False):
is_idempotent=False, table=None):
if retry_policy and not hasattr(retry_policy, 'on_read_timeout'): # just checking one method to detect positional parameter errors
raise ValueError('retry_policy should implement cassandra.policies.RetryPolicy')
if retry_policy is not None:
Expand All @@ -286,6 +293,8 @@ def __init__(self, retry_policy=None, consistency_level=None, routing_key=None,
self.fetch_size = fetch_size
if keyspace is not None:
self.keyspace = keyspace
if table is not None:
self.table = table
if custom_payload is not None:
self.custom_payload = custom_payload
self.is_idempotent = is_idempotent
Expand Down Expand Up @@ -548,6 +557,7 @@ def __init__(self, prepared_statement, retry_policy=None, consistency_level=None
meta = prepared_statement.column_metadata
if meta:
self.keyspace = meta[0].keyspace_name
self.table = meta[0].table_name

Statement.__init__(self, retry_policy, consistency_level, routing_key,
serial_consistency_level, fetch_size, keyspace, custom_payload,
Expand Down
Loading

0 comments on commit dfccfff

Please sign in to comment.