diff --git a/graphiti_core/utils/maintenance/community_operations.py b/graphiti_core/utils/maintenance/community_operations.py index af1e1f10f..0c9ea0a90 100644 --- a/graphiti_core/utils/maintenance/community_operations.py +++ b/graphiti_core/utils/maintenance/community_operations.py @@ -84,50 +84,108 @@ async def get_community_clusters( def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]: - # Implement the label propagation community detection algorithm. - # 1. Start with each node being assigned its own community - # 2. Each node will take on the community of the plurality of its neighbors - # 3. Ties are broken by going to the largest community - # 4. Continue until no communities change during propagation + """ + Implement the label propagation community detection algorithm. + + Algorithm: + 1. Start with each node being assigned its own community + 2. Each node will take on the community of the plurality of its neighbors + 3. Ties are broken by going to the largest community + 4. Continue until no communities change during propagation + + Oscillation prevention: + - Uses asynchronous updates (randomized node order) + - Maximum iteration limit to prevent infinite loops + - Early stopping if oscillation is detected + """ + import random + + MAX_ITERATIONS = 100 + OSCILLATION_CHECK_WINDOW = 5 community_map = {uuid: i for i, uuid in enumerate(projection.keys())} + node_uuids = list(projection.keys()) + + # Track history to detect oscillations + history: list[dict[str, int]] = [] - while True: - no_change = True - new_community_map: dict[str, int] = {} + for iteration in range(MAX_ITERATIONS): + # Asynchronous update: randomize node processing order to prevent oscillation + random.shuffle(node_uuids) - for uuid, neighbors in projection.items(): + changed_count = 0 + + for uuid in node_uuids: + neighbors = projection[uuid] curr_community = community_map[uuid] + # Count votes from neighbors community_candidates: dict[int, int] = defaultdict(int) for neighbor in neighbors: community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count + + if not community_candidates: + continue + + # Sort by count (descending), then by community ID for deterministic tie-breaking community_lst = [ (count, community) for community, count in community_candidates.items() ] + community_lst.sort(key=lambda x: (-x[0], x[1])) - community_lst.sort(reverse=True) - candidate_rank, community_candidate = community_lst[0] if community_lst else (0, -1) - if community_candidate != -1 and candidate_rank > 1: + candidate_rank, community_candidate = community_lst[0] + + # Determine new community: + # - If strong signal (edge count > 1), adopt the neighbor's community + # - Otherwise, prefer the larger community ID (original behavior) + if candidate_rank > 1: new_community = community_candidate else: new_community = max(community_candidate, curr_community) - new_community_map[uuid] = new_community - if new_community != curr_community: - no_change = False + community_map[uuid] = new_community + changed_count += 1 - if no_change: + # Check for convergence + if changed_count == 0: + logger.debug(f'Label propagation converged after {iteration + 1} iterations') break - community_map = new_community_map + # Check for oscillation by comparing with recent history + current_state = community_map.copy() + history.append(current_state) + + # Keep only recent history + if len(history) > OSCILLATION_CHECK_WINDOW: + history.pop(0) + + # Detect oscillation: if current state matches any recent state + if len(history) >= 2: + for past_state in history[:-1]: + if past_state == current_state: + logger.warning( + f'Label propagation oscillation detected at iteration {iteration + 1}, ' + 'stopping early' + ) + # Break out of the for loop + break + else: + # No oscillation detected, continue to next iteration + continue + # Oscillation detected, break out of the main loop + break + else: + logger.warning( + f'Label propagation reached maximum iterations ({MAX_ITERATIONS}) without converging' + ) - community_cluster_map = defaultdict(list) + # Group nodes by community + community_cluster_map: dict[int, list[str]] = defaultdict(list) for uuid, community in community_map.items(): community_cluster_map[community].append(uuid) - clusters = [cluster for cluster in community_cluster_map.values()] + clusters = list(community_cluster_map.values()) return clusters