Merge pull request #332 from sylwiaszunejko/rack_aware_policy
Add `RackAwareRoundRobinPolicy` for host selection
sylwiaszunejko authored Aug 5, 2024
2 parents 1c3cff8 + c62665f commit dbb4552
Showing 6 changed files with 369 additions and 84 deletions.
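
For reference, a minimal usage sketch of the new policy (not part of this commit; the contact point, datacenter, and rack names are placeholders):

from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import RackAwareRoundRobinPolicy

# Prefer nodes in DC1/RC1, then other racks in DC1; remote DCs are ignored by default.
profile = ExecutionProfile(
    load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0))
cluster = Cluster(contact_points=["127.0.0.1"],
                  execution_profiles={EXEC_PROFILE_DEFAULT: profile})
session = cluster.connect()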
9 changes: 7 additions & 2 deletions cassandra/cluster.py
@@ -492,7 +492,8 @@ def _profiles_without_explicit_lbps(self):

def distance(self, host):
distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
HostDistance.LOCAL if HostDistance.LOCAL in distances else \
HostDistance.REMOTE if HostDistance.REMOTE in distances else \
HostDistance.IGNORED

@@ -609,7 +610,7 @@ class Cluster(object):
Defaults to loopback interface.
Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
local_dc set (as is the default), the DC is chosen from an arbitrary
host in contact_points. In this case, contact_points should contain
only nodes from a single, local DC.
@@ -1369,21 +1370,25 @@ def __init__(self,
self._user_types = defaultdict(dict)

self._min_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
}

self._max_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
}

self._core_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
}

self._max_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
}
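
The new ``LOCAL_RACK`` entries mean same-rack hosts pick up the same pool defaults as other local-DC hosts. A small sketch reading them back (illustration only, not part of this diff; the matching ``set_*`` methods can override them, subject to the usual protocol-version restrictions):

from cassandra.cluster import Cluster
from cassandra.policies import HostDistance

cluster = Cluster(["127.0.0.1"])  # no connections are opened until connect() is called
# Same-rack hosts start with the local-host pool defaults.
print(cluster.get_core_connections_per_host(HostDistance.LOCAL_RACK))
print(cluster.get_max_connections_per_host(HostDistance.LOCAL_RACK))
print(cluster.get_max_requests_per_connection(HostDistance.LOCAL_RACK))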
2 changes: 1 addition & 1 deletion cassandra/metadata.py
@@ -3436,7 +3436,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
# First check if there are local replicas
valid_replicas = [host for host in all_replicas if
host.is_up and distance(host) == HostDistance.LOCAL]
host.is_up and distance(host) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]]
if not valid_replicas:
valid_replicas = [host for host in all_replicas if host.is_up]
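
With this change, replicas in the local rack also count as valid local replicas when grouping keys. A hedged usage sketch (not part of this diff; assumes a reachable cluster and the test1.table1 schema used in the test below):

from cassandra.cluster import Cluster
from cassandra.metadata import group_keys_by_replica

session = Cluster(["127.0.0.1"]).connect()
# Group partition keys by the up, local (or same-rack) replica that owns them.
keys_by_host = group_keys_by_replica(session, "test1", "table1", [(1,), (2,), (3,)])
for replica, keys in keys_by_host.items():
    print(replica, keys)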

152 changes: 146 additions & 6 deletions cassandra/policies.py
@@ -46,7 +46,18 @@ class HostDistance(object):
connections opened to it.
"""

LOCAL = 0
LOCAL_RACK = 0
"""
Nodes with ``LOCAL_RACK`` distance will be preferred for operations
under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
and will have a greater number of connections opened against
them by default.
This distance is typically used for nodes within the same
datacenter and the same rack as the client.
"""

LOCAL = 1
"""
Nodes with ``LOCAL`` distance will be preferred for operations
under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
@@ -57,12 +68,12 @@ class HostDistance(object):
datacenter as the client.
"""

REMOTE = 1
REMOTE = 2
"""
Nodes with ``REMOTE`` distance will be treated as a last resort
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
and will have a smaller number of connections opened against
them by default.
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
connections opened against them by default.
This distance is typically used for nodes outside of the
datacenter that the client is running in.
@@ -102,6 +113,11 @@ class LoadBalancingPolicy(HostStateListener):
You may also use subclasses of :class:`.LoadBalancingPolicy` for
custom behavior.
You should always use immutable collections (e.g., tuples or
frozensets) to store information about hosts to prevent accidental
modification. When there are changes to the hosts (e.g., a host is
down or up), the old collection should be replaced with a new one.
"""

_hosts_lock = None
@@ -316,6 +332,130 @@ def on_add(self, host):
def on_remove(self, host):
self.on_down(host)

class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
"""
Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
in the local rack, then hosts in the local datacenter but a
different rack, and finally hosts in all other datacenters.
"""

local_dc = None
local_rack = None
used_hosts_per_remote_dc = 0

def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
"""
The `local_dc` and `local_rack` parameters should be the names of the
datacenter and rack (as reported by ``nodetool ring``) that
should be considered local.
`used_hosts_per_remote_dc` controls how many nodes in
each remote datacenter will have connections opened
against them. In other words, `used_hosts_per_remote_dc` hosts
will be considered :attr:`~.HostDistance.REMOTE` and the
rest will be considered :attr:`~.HostDistance.IGNORED`.
By default, all remote hosts are ignored.
"""
self.local_rack = local_rack
self.local_dc = local_dc
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
self._live_hosts = {}
self._dc_live_hosts = {}
self._endpoints = []
self._position = 0
LoadBalancingPolicy.__init__(self)

def _rack(self, host):
return host.rack or self.local_rack

def _dc(self, host):
return host.datacenter or self.local_dc

def populate(self, cluster, hosts):
for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
self._live_hosts[(dc, rack)] = tuple(set(rack_hosts))
for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
self._dc_live_hosts[dc] = tuple(set(dc_hosts))

self._position = randint(0, len(hosts) - 1) if hosts else 0

def distance(self, host):
rack = self._rack(host)
dc = self._dc(host)
if rack == self.local_rack and dc == self.local_dc:
return HostDistance.LOCAL_RACK

if dc == self.local_dc:
return HostDistance.LOCAL

if not self.used_hosts_per_remote_dc:
return HostDistance.IGNORED

dc_hosts = self._dc_live_hosts.get(dc, ())
if not dc_hosts:
return HostDistance.IGNORED
if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc:
return HostDistance.REMOTE
else:
return HostDistance.IGNORED
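
To illustrate the classification, a small sketch with stand-in host objects (not part of this diff; real ``Host`` instances come from cluster metadata, and ``populate`` as written above does not use its ``cluster`` argument, so ``None`` stands in here):

from types import SimpleNamespace
from cassandra.policies import HostDistance, RackAwareRoundRobinPolicy

same_rack = SimpleNamespace(datacenter="DC1", rack="RC1")
other_rack = SimpleNamespace(datacenter="DC1", rack="RC2")
remote = SimpleNamespace(datacenter="DC2", rack="RC1")

policy = RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=1)
policy.populate(None, [same_rack, other_rack, remote])

assert policy.distance(same_rack) == HostDistance.LOCAL_RACK
assert policy.distance(other_rack) == HostDistance.LOCAL
assert policy.distance(remote) == HostDistance.REMOTE  # within the per-remote-DC budget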

def make_query_plan(self, working_keyspace=None, query=None):
pos = self._position
self._position += 1

local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
pos = (pos % len(local_rack_live)) if local_rack_live else 0
# Slice the cyclic iterator to start from pos and include the next len(local_rack_live) elements
# This ensures we get exactly one full cycle starting from pos
for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
yield host

local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
pos = (pos % len(local_live)) if local_live else 0
for host in islice(cycle(local_live), pos, pos + len(local_live)):
yield host

# the dict can change, so get candidate DCs iterating over keys of a copy
for dc, remote_live in self._dc_live_hosts.copy().items():
if dc != self.local_dc:
for host in remote_live[:self.used_hosts_per_remote_dc]:
yield host
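
The resulting plan visits the local rack first, then the rest of the local datacenter, then at most `used_hosts_per_remote_dc` hosts per remote datacenter. Continuing the stand-in hosts from the sketch above (illustration only):

# Local rack first, then other racks in the local DC, then one remote host.
plan = list(policy.make_query_plan())
assert plan == [same_rack, other_rack, remote]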

def on_up(self, host):
dc = self._dc(host)
rack = self._rack(host)
with self._hosts_lock:
current_rack_hosts = self._live_hosts.get((dc, rack), ())
if host not in current_rack_hosts:
self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
current_dc_hosts = self._dc_live_hosts.get(dc, ())
if host not in current_dc_hosts:
self._dc_live_hosts[dc] = current_dc_hosts + (host, )

def on_down(self, host):
dc = self._dc(host)
rack = self._rack(host)
with self._hosts_lock:
current_rack_hosts = self._live_hosts.get((dc, rack), ())
if host in current_rack_hosts:
hosts = tuple(h for h in current_rack_hosts if h != host)
if hosts:
self._live_hosts[(dc, rack)] = hosts
else:
del self._live_hosts[(dc, rack)]
current_dc_hosts = self._dc_live_hosts.get(dc, ())
if host in current_dc_hosts:
hosts = tuple(h for h in current_dc_hosts if h != host)
if hosts:
self._dc_live_hosts[dc] = hosts
else:
del self._dc_live_hosts[dc]

def on_add(self, host):
self.on_up(host)

def on_remove(self, host):
self.on_down(host)

class TokenAwarePolicy(LoadBalancingPolicy):
"""
@@ -390,7 +530,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
shuffle(replicas)

for replica in replicas:
if replica.is_up and child.distance(replica) == HostDistance.LOCAL:
if replica.is_up and child.distance(replica) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]:
yield replica

for host in child.make_query_plan(keyspace, query):
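
With this change, :class:`.TokenAwarePolicy` also yields same-rack replicas up front. A typical configuration wraps the rack-aware policy (illustration only, not part of this diff; names and addresses are placeholders):

from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import RackAwareRoundRobinPolicy, TokenAwarePolicy

# Route each statement to a replica of its partition, preferring same-rack replicas.
policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("DC1", "RC1"))
profile = ExecutionProfile(load_balancing_policy=policy)
cluster = Cluster(["127.0.0.1"], execution_profiles={EXEC_PROFILE_DEFAULT: profile})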
3 changes: 3 additions & 0 deletions docs/api/cassandra/policies.rst
@@ -18,6 +18,9 @@ Load Balancing
.. autoclass:: DCAwareRoundRobinPolicy
:members:

.. autoclass:: RackAwareRoundRobinPolicy
:members:

.. autoclass:: WhiteListRoundRobinPolicy
:members:

89 changes: 89 additions & 0 deletions tests/integration/standard/test_rack_aware_policy.py
@@ -0,0 +1,89 @@
import logging
import unittest

from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RackAwareRoundRobinPolicy

from tests.integration import PROTOCOL_VERSION, get_cluster, use_multidc

LOGGER = logging.getLogger(__name__)

def setup_module():
use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 3}})

class RackAwareRoundRobinPolicyTests(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.cluster = Cluster(contact_points=[node.address() for node in get_cluster().nodelist()], protocol_version=PROTOCOL_VERSION,
load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_data(cls.session)
cls.node1, cls.node2, cls.node3, cls.node4, cls.node5, cls.node6, cls.node7 = get_cluster().nodes.values()

@classmethod
def teardown_class(cls):
cls.cluster.shutdown()

def create_ks_and_cf(self):
self.session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
'class': 'NetworkTopologyStrategy',
'replication_factor': 3
}
""")

self.session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")

@staticmethod
def create_data(session):
prepared = session.prepare(
"""
INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
""")

for i in range(50):
bound = prepared.bind((i, i%5, i%2))
session.execute(bound)

def test_rack_aware(self):
prepared = self.session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""")

for i in range(10):
bound = prepared.bind([i])
results = self.session.execute(bound)
self.assertEqual(results, [(i, i%5, i%2)])
coordinator = str(results.response_future.coordinator_host.endpoint)
self.assertTrue(coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"]))

self.node2.stop(wait_other_notice=True, gently=True)

for i in range(10):
bound = prepared.bind([i])
results = self.session.execute(bound)
self.assertEqual(results, [(i, i%5, i%2)])
coordinator = str(results.response_future.coordinator_host.endpoint)
self.assertEqual(coordinator, "127.0.0.1:9042")

self.node1.stop(wait_other_notice=True, gently=True)

for i in range(10):
bound = prepared.bind([i])
results = self.session.execute(bound)
self.assertEqual(results, [(i, i%5, i%2)])
coordinator = str(results.response_future.coordinator_host.endpoint)
self.assertTrue(coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"]))