diff --git a/hathor/p2p/netfilter/utils.py b/hathor/p2p/netfilter/utils.py index 5b8c6f4f1..a232e1b8a 100644 --- a/hathor/p2p/netfilter/utils.py +++ b/hathor/p2p/netfilter/utils.py @@ -17,15 +17,84 @@ from hathor.p2p.netfilter.rule import NetfilterRule from hathor.p2p.netfilter.targets import NetfilterReject +# Global mapping to track peer_id -> rule UUID for blacklist management +_peer_id_to_rule_uuid: dict[str, str] = {} -def add_peer_id_blacklist(peer_id_blacklist: list[str]) -> None: - """ Add a list of peer ids to a blacklist using netfilter reject + +def add_blacklist_peers(peer_ids: str | list[str]) -> list[str]: + """Add peer(s) to the blacklist. + + Args: + peer_ids: A single peer_id string or a list of peer_id strings + + Returns: + List of peer_ids that were successfully added (not already blacklisted) """ + if isinstance(peer_ids, str): + peer_ids = [peer_ids] + post_peerid = get_table('filter').get_chain('post_peerid') + added_peers: list[str] = [] - for peer_id in peer_id_blacklist: + for peer_id in peer_ids: if not peer_id: continue + + # Skip if already blacklisted + if peer_id in _peer_id_to_rule_uuid: + continue + match = NetfilterMatchPeerId(peer_id) rule = NetfilterRule(match, NetfilterReject()) post_peerid.add_rule(rule) + _peer_id_to_rule_uuid[peer_id] = rule.uuid + added_peers.append(peer_id) + + return added_peers + + +def remove_blacklist_peers(peer_ids: str | list[str]) -> list[str]: + """Remove peer(s) from the blacklist. + + Args: + peer_ids: A single peer_id string or a list of peer_id strings + + Returns: + List of peer_ids that were successfully removed + """ + if isinstance(peer_ids, str): + peer_ids = [peer_ids] + + post_peerid = get_table('filter').get_chain('post_peerid') + removed_peers: list[str] = [] + + for peer_id in peer_ids: + if not peer_id: + continue + + rule_uuid = _peer_id_to_rule_uuid.get(peer_id) + if rule_uuid is None: + continue + + if post_peerid.delete_rule(rule_uuid): + del _peer_id_to_rule_uuid[peer_id] + removed_peers.append(peer_id) + + return removed_peers + + +def list_blacklist_peers() -> list[str]: + """List all currently blacklisted peer_ids. + + Returns: + List of blacklisted peer_id strings + """ + return list(_peer_id_to_rule_uuid.keys()) + + +def add_peer_id_blacklist(peer_id_blacklist: list[str]) -> None: + """Add a list of peer ids to a blacklist using netfilter reject. + + This is a legacy function that wraps add_blacklist_peers for backward compatibility. + """ + add_blacklist_peers(peer_id_blacklist) diff --git a/hathor/sysctl/p2p/manager.py b/hathor/sysctl/p2p/manager.py index 9f9856a42..56ced7ee6 100644 --- a/hathor/sysctl/p2p/manager.py +++ b/hathor/sysctl/p2p/manager.py @@ -15,6 +15,7 @@ import os from hathor.p2p.manager import ConnectionsManager +from hathor.p2p.netfilter.utils import add_blacklist_peers, list_blacklist_peers, remove_blacklist_peers from hathor.p2p.peer_id import PeerId from hathor.p2p.sync_version import SyncVersion from hathor.p2p.utils import discover_hostname @@ -122,6 +123,21 @@ def __init__(self, connections: ConnectionsManager) -> None: None, self.reload_entrypoints_and_connections, ) + self.register( + 'blacklist.add_peers', + None, + self.set_blacklist_add_peers, + ) + self.register( + 'blacklist.remove_peers', + None, + self.set_blacklist_remove_peers, + ) + self.register( + 'blacklist.list_peers', + self.get_blacklist_list_peers, + None, + ) def set_force_sync_rotate(self) -> None: """Force a sync rotate.""" @@ -269,3 +285,43 @@ def refresh_auto_hostname(self) -> None: def reload_entrypoints_and_connections(self) -> None: """Kill all connections and reload entrypoints from the peer config file.""" self.connections.reload_entrypoints_and_connections() + + @signal_handler_safe + def set_blacklist_add_peers(self, peer_ids: str | list[str]) -> None: + """Add peer(s) to the blacklist. Accepts a single peer-id string or a list of peer-ids.""" + # Validate peer IDs + peer_id_list = [peer_ids] if isinstance(peer_ids, str) else peer_ids + try: + for peer_id in peer_id_list: + if peer_id: # Skip empty strings + PeerId(peer_id) # Validate format + except ValueError as e: + raise SysctlException(f'Invalid peer-id format: {e}') + + added_peers = add_blacklist_peers(peer_ids) + if added_peers: + self.log.info('Added peers to blacklist', peer_ids=added_peers) + else: + self.log.info('No new peers added to blacklist (already blacklisted or empty)') + + @signal_handler_safe + def set_blacklist_remove_peers(self, peer_ids: str | list[str]) -> None: + """Remove peer(s) from the blacklist. Accepts a single peer-id string or a list of peer-ids.""" + # Validate peer IDs + peer_id_list = [peer_ids] if isinstance(peer_ids, str) else peer_ids + try: + for peer_id in peer_id_list: + if peer_id: # Skip empty strings + PeerId(peer_id) # Validate format + except ValueError as e: + raise SysctlException(f'Invalid peer-id format: {e}') + + removed_peers = remove_blacklist_peers(peer_ids) + if removed_peers: + self.log.info('Removed peers from blacklist', peer_ids=removed_peers) + else: + self.log.info('No peers removed from blacklist (not found or empty)') + + def get_blacklist_list_peers(self) -> list[str]: + """List all currently blacklisted peer_ids.""" + return list_blacklist_peers() diff --git a/hathor_tests/p2p/netfilter/test_utils.py b/hathor_tests/p2p/netfilter/test_utils.py index c8901a454..7341f24f4 100644 --- a/hathor_tests/p2p/netfilter/test_utils.py +++ b/hathor_tests/p2p/netfilter/test_utils.py @@ -1,9 +1,23 @@ from hathor.p2p.netfilter import get_table -from hathor.p2p.netfilter.utils import add_peer_id_blacklist +from hathor.p2p.netfilter.utils import ( + add_blacklist_peers, + add_peer_id_blacklist, + list_blacklist_peers, + remove_blacklist_peers, +) from hathor_tests import unittest class NetfilterUtilsTest(unittest.TestCase): + def setUp(self) -> None: + """Clean up rules and tracking before each test.""" + super().setUp() + post_peerid = get_table('filter').get_chain('post_peerid') + post_peerid.rules = [] + # Clear the global tracking dictionary + from hathor.p2p.netfilter import utils + utils._peer_id_to_rule_uuid.clear() + def test_peer_id_blacklist(self) -> None: post_peerid = get_table('filter').get_chain('post_peerid') @@ -24,3 +38,128 @@ def test_peer_id_blacklist(self) -> None: self.assertEqual(data['match']['type'], 'NetfilterMatchPeerId') self.assertIn(data['match']['match_params']['peer_id'], blacklist) self.assertEqual(data['target']['type'], 'NetfilterReject') + + def test_add_blacklist_peers_with_list(self) -> None: + """Test adding multiple peers with a list.""" + post_peerid = get_table('filter').get_chain('post_peerid') + + # Initially empty + self.assertEqual(len(post_peerid.rules), 0) + self.assertEqual(list_blacklist_peers(), []) + + # Add peers + peer_ids = ['peer1', 'peer2', 'peer3'] + added = add_blacklist_peers(peer_ids) + + # All peers should be added + self.assertEqual(sorted(added), sorted(peer_ids)) + self.assertEqual(len(post_peerid.rules), 3) + self.assertEqual(sorted(list_blacklist_peers()), sorted(peer_ids)) + + def test_add_blacklist_peers_with_string(self) -> None: + """Test adding a single peer with a string.""" + post_peerid = get_table('filter').get_chain('post_peerid') + + # Add single peer + peer_id = 'single_peer' + added = add_blacklist_peers(peer_id) + + self.assertEqual(added, [peer_id]) + self.assertEqual(len(post_peerid.rules), 1) + self.assertEqual(list_blacklist_peers(), [peer_id]) + + def test_add_blacklist_peers_skip_duplicates(self) -> None: + """Test that adding duplicate peers is skipped.""" + post_peerid = get_table('filter').get_chain('post_peerid') + + # Add peers first time + peer_ids = ['peer1', 'peer2'] + added1 = add_blacklist_peers(peer_ids) + self.assertEqual(sorted(added1), sorted(peer_ids)) + self.assertEqual(len(post_peerid.rules), 2) + + # Try to add same peers again + added2 = add_blacklist_peers(peer_ids) + self.assertEqual(added2, []) # Nothing added + self.assertEqual(len(post_peerid.rules), 2) # Still 2 rules + + # Add mix of new and existing + added3 = add_blacklist_peers(['peer1', 'peer3']) + self.assertEqual(added3, ['peer3']) # Only new peer added + self.assertEqual(len(post_peerid.rules), 3) + + def test_add_blacklist_peers_skip_empty(self) -> None: + """Test that empty strings are skipped.""" + peer_ids = ['peer1', '', 'peer2', ''] + added = add_blacklist_peers(peer_ids) + + self.assertEqual(sorted(added), ['peer1', 'peer2']) + self.assertEqual(sorted(list_blacklist_peers()), ['peer1', 'peer2']) + + def test_remove_blacklist_peers_with_list(self) -> None: + """Test removing multiple peers with a list.""" + # Add peers first + peer_ids = ['peer1', 'peer2', 'peer3'] + add_blacklist_peers(peer_ids) + self.assertEqual(sorted(list_blacklist_peers()), sorted(peer_ids)) + + # Remove some peers + to_remove = ['peer1', 'peer3'] + removed = remove_blacklist_peers(to_remove) + + self.assertEqual(sorted(removed), sorted(to_remove)) + self.assertEqual(list_blacklist_peers(), ['peer2']) + + def test_remove_blacklist_peers_with_string(self) -> None: + """Test removing a single peer with a string.""" + # Add peers first + add_blacklist_peers(['peer1', 'peer2']) + + # Remove one peer + removed = remove_blacklist_peers('peer1') + + self.assertEqual(removed, ['peer1']) + self.assertEqual(list_blacklist_peers(), ['peer2']) + + def test_remove_blacklist_peers_nonexistent(self) -> None: + """Test removing peers that don't exist.""" + # Add one peer + add_blacklist_peers('peer1') + + # Try to remove nonexistent peers + removed = remove_blacklist_peers(['peer2', 'peer3']) + + self.assertEqual(removed, []) + self.assertEqual(list_blacklist_peers(), ['peer1']) + + # Remove mix of existing and nonexistent + removed2 = remove_blacklist_peers(['peer1', 'peer2']) + self.assertEqual(removed2, ['peer1']) + self.assertEqual(list_blacklist_peers(), []) + + def test_remove_blacklist_peers_skip_empty(self) -> None: + """Test that empty strings are skipped during removal.""" + add_blacklist_peers(['peer1', 'peer2']) + + removed = remove_blacklist_peers(['peer1', '', 'peer2']) + + self.assertEqual(sorted(removed), ['peer1', 'peer2']) + self.assertEqual(list_blacklist_peers(), []) + + def test_list_blacklist_peers(self) -> None: + """Test listing blacklisted peers.""" + # Initially empty + self.assertEqual(list_blacklist_peers(), []) + + # Add some peers + peer_ids = ['peer1', 'peer2', 'peer3'] + add_blacklist_peers(peer_ids) + self.assertEqual(sorted(list_blacklist_peers()), sorted(peer_ids)) + + # Remove one + remove_blacklist_peers('peer2') + self.assertEqual(sorted(list_blacklist_peers()), ['peer1', 'peer3']) + + # Remove all + remove_blacklist_peers(['peer1', 'peer3']) + self.assertEqual(list_blacklist_peers(), []) diff --git a/hathor_tests/sysctl/test_p2p.py b/hathor_tests/sysctl/test_p2p.py index 460ea884c..1fa9a7805 100644 --- a/hathor_tests/sysctl/test_p2p.py +++ b/hathor_tests/sysctl/test_p2p.py @@ -180,3 +180,156 @@ def test_kill_connection_unknown_peer_id(self): with self.assertRaises(SysctlException): sysctl.unsafe_set('kill_connection', 'unknown-peer-id') + + def test_blacklist_add_peers_with_list(self): + """Test adding peers to blacklist with a list.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Initially empty + self.assertEqual(sysctl.get('blacklist.list_peers'), []) + + # Add peers (using valid 64-char hex peer IDs) + peer_ids = [ + '0000000000000000000000000000000000000000000000000000000000000001', + '0000000000000000000000000000000000000000000000000000000000000002', + '0000000000000000000000000000000000000000000000000000000000000003', + ] + sysctl.unsafe_set('blacklist.add_peers', peer_ids) + + # Check they were added + blacklisted = sysctl.get('blacklist.list_peers') + self.assertEqual(sorted(blacklisted), sorted(peer_ids)) + + def test_blacklist_add_peers_with_string(self): + """Test adding a single peer to blacklist with a string.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Add single peer (using valid 64-char hex peer ID) + peer_id = '0000000000000000000000000000000000000000000000000000000000000001' + sysctl.unsafe_set('blacklist.add_peers', peer_id) + + # Check it was added + blacklisted = sysctl.get('blacklist.list_peers') + self.assertEqual(blacklisted, [peer_id]) + + def test_blacklist_remove_peers_with_list(self): + """Test removing peers from blacklist with a list.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Add peers first (using valid 64-char hex peer IDs) + peer_ids = [ + '0000000000000000000000000000000000000000000000000000000000000001', + '0000000000000000000000000000000000000000000000000000000000000002', + '0000000000000000000000000000000000000000000000000000000000000003', + ] + sysctl.unsafe_set('blacklist.add_peers', peer_ids) + self.assertEqual(sorted(sysctl.get('blacklist.list_peers')), sorted(peer_ids)) + + # Remove some + to_remove = [peer_ids[0], peer_ids[2]] + sysctl.unsafe_set('blacklist.remove_peers', to_remove) + + # Check they were removed + blacklisted = sysctl.get('blacklist.list_peers') + self.assertEqual(blacklisted, [peer_ids[1]]) + + def test_blacklist_remove_peers_with_string(self): + """Test removing a single peer from blacklist with a string.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Add peers first (using valid 64-char hex peer IDs) + peer1 = '0000000000000000000000000000000000000000000000000000000000000001' + peer2 = '0000000000000000000000000000000000000000000000000000000000000002' + sysctl.unsafe_set('blacklist.add_peers', [peer1, peer2]) + + # Remove one + sysctl.unsafe_set('blacklist.remove_peers', peer1) + + # Check it was removed + blacklisted = sysctl.get('blacklist.list_peers') + self.assertEqual(blacklisted, [peer2]) + + def test_blacklist_list_peers(self): + """Test listing blacklisted peers.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Initially empty + self.assertEqual(sysctl.get('blacklist.list_peers'), []) + + # Add some peers (using valid 64-char hex peer IDs) + peer1 = '0000000000000000000000000000000000000000000000000000000000000001' + peer2 = '0000000000000000000000000000000000000000000000000000000000000002' + peer3 = '0000000000000000000000000000000000000000000000000000000000000003' + sysctl.unsafe_set('blacklist.add_peers', [peer1, peer2]) + self.assertEqual(sorted(sysctl.get('blacklist.list_peers')), sorted([peer1, peer2])) + + # Add more + sysctl.unsafe_set('blacklist.add_peers', peer3) + self.assertEqual(sorted(sysctl.get('blacklist.list_peers')), sorted([peer1, peer2, peer3])) + + # Remove all + sysctl.unsafe_set('blacklist.remove_peers', [peer1, peer2, peer3]) + self.assertEqual(sysctl.get('blacklist.list_peers'), []) + + def test_blacklist_add_peers_invalid_peer_id(self): + """Test that adding invalid peer IDs raises an exception.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Too short + with self.assertRaises(SysctlException) as cm: + sysctl.unsafe_set('blacklist.add_peers', 'invalid') + self.assertIn('Invalid peer-id format', str(cm.exception)) + + # Too long (more than 64 hex chars) + with self.assertRaises(SysctlException) as cm: + sysctl.unsafe_set('blacklist.add_peers', 'a' * 65) + self.assertIn('Invalid peer-id format', str(cm.exception)) + + # Invalid hex characters + with self.assertRaises(SysctlException) as cm: + sysctl.unsafe_set('blacklist.add_peers', 'g' * 64) + self.assertIn('Invalid peer-id format', str(cm.exception)) + + # Valid length but odd number of hex chars + with self.assertRaises(SysctlException) as cm: + sysctl.unsafe_set('blacklist.add_peers', 'a' * 63) + self.assertIn('Invalid peer-id format', str(cm.exception)) + + # List with some invalid + with self.assertRaises(SysctlException) as cm: + sysctl.unsafe_set('blacklist.add_peers', ['0' * 64, 'invalid']) + self.assertIn('Invalid peer-id format', str(cm.exception)) + + # Ensure nothing was added + self.assertEqual(sysctl.get('blacklist.list_peers'), []) + + def test_blacklist_remove_peers_invalid_peer_id(self): + """Test that removing invalid peer IDs raises an exception.""" + manager = self.create_peer() + connections = manager.connections + sysctl = ConnectionsManagerSysctl(connections) + + # Add a valid peer first + valid_peer_id = '0' * 64 + sysctl.unsafe_set('blacklist.add_peers', valid_peer_id) + self.assertEqual(sysctl.get('blacklist.list_peers'), [valid_peer_id]) + + # Try to remove with invalid peer ID + with self.assertRaises(SysctlException) as cm: + sysctl.unsafe_set('blacklist.remove_peers', 'invalid') + self.assertIn('Invalid peer-id format', str(cm.exception)) + + # Ensure the valid peer is still there + self.assertEqual(sysctl.get('blacklist.list_peers'), [valid_peer_id])