diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/BundleDataComparator.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/BundleDataComparator.java new file mode 100644 index 0000000000000..d4d3bcba3d7ee --- /dev/null +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/BundleDataComparator.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.broker.loadbalance.extensions.models; + +import java.util.Comparator; +import java.util.Map; +import org.apache.pulsar.policies.data.loadbalancer.BundleData; +import org.apache.pulsar.policies.data.loadbalancer.TimeAverageMessageData; + +/** + * A strict comparator for BundleData that provides deterministic ordering + * without threshold-based comparisons to avoid transitivity violations. + */ +public class BundleDataComparator implements Comparator { + + /** + * Compares two bundle entries using BundleData + * 1. Short-term data comparison (inbound bandwidth, outbound bandwidth, message rate) + * 2. Long-term data comparison (same hierarchy as short-term) + * + * @param bundleA first bundle entry + * @param bundleB second bundle entry + * @return negative if a < b, positive if a > b, zero if equal + */ + @Override + public int compare(BundleData bundleA, BundleData bundleB) { + + // First compare short-term data (same hierarchy as TimeAverageMessageData.compareTo() but strict) + int result = compareShortTermData(bundleA, bundleB); + if (result != 0) { + return result; + } + + // If short-term data is equal, compare long-term data + result = compareLongTermData(bundleA, bundleB); + if (result != 0) { + return result; + } + + // If all metrics are equal + return 0; + } + + /** + * Compare short-term data using the same hierarchy as TimeAverageMessageData.compareTo() but with strict comparisons. + */ + private int compareShortTermData(BundleData bundleA, BundleData bundleB) { + TimeAverageMessageData shortTermA = bundleA.getShortTermData(); + TimeAverageMessageData shortTermB = bundleB.getShortTermData(); + + // 1. Inbound bandwidth (strict comparison) + int result = Double.compare(shortTermA.getMsgThroughputIn(), shortTermB.getMsgThroughputIn()); + if (result != 0) { + return result; + } + + // 2. Outbound bandwidth (strict comparison) + result = Double.compare(shortTermA.getMsgThroughputOut(), shortTermB.getMsgThroughputOut()); + if (result != 0) { + return result; + } + + // 3. Total message rate (strict comparison) + double totalMsgRateA = shortTermA.getMsgRateIn() + shortTermA.getMsgRateOut(); + double totalMsgRateB = shortTermB.getMsgRateIn() + shortTermB.getMsgRateOut(); + return Double.compare(totalMsgRateA, totalMsgRateB); + } + + /** + * Compare long-term data using the same hierarchy as TimeAverageMessageData.compareTo() but with strict comparisons. + */ + private int compareLongTermData(BundleData bundleA, BundleData bundleB) { + TimeAverageMessageData longTermA = bundleA.getLongTermData(); + TimeAverageMessageData longTermB = bundleB.getLongTermData(); + + // 1. Inbound bandwidth (strict comparison) + int result = Double.compare(longTermA.getMsgThroughputIn(), longTermB.getMsgThroughputIn()); + if (result != 0) { + return result; + } + + // 2. Outbound bandwidth (strict comparison) + result = Double.compare(longTermA.getMsgThroughputOut(), longTermB.getMsgThroughputOut()); + if (result != 0) { + return result; + } + + // 3. Total message rate (strict comparison) + double totalMsgRateA = longTermA.getMsgRateIn() + longTermA.getMsgRateOut(); + double totalMsgRateB = longTermB.getMsgRateIn() + longTermB.getMsgRateOut(); + return Double.compare(totalMsgRateA, totalMsgRateB); + } +} diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/StrictNamespaceBundleStatsComparator.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/StrictNamespaceBundleStatsComparator.java new file mode 100644 index 0000000000000..2d6da4c3a3c65 --- /dev/null +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/StrictNamespaceBundleStatsComparator.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pulsar.broker.loadbalance.extensions.models; + +import java.util.Comparator; +import java.util.Map; +import org.apache.pulsar.policies.data.loadbalancer.NamespaceBundleStats; + +/** + * A strict comparator for NamespaceBundleStats that provides deterministic ordering + * without threshold-based comparisons to avoid transitivity violations. + */ +public class StrictNamespaceBundleStatsComparator implements Comparator { + + /** + * Compares two bundle entries + * Comparison hierarchy: + * 1. Inbound bandwidth (msgThroughputIn) + * 2. Outbound bandwidth (msgThroughputOut) + * 3. Total message rate (msgRateIn + msgRateOut) + * 4. Total topics and connections (topics + consumerCount + producerCount) + * 5. Cache size (cacheSize) + * + * @param bundleA first bundle entry + * @param bundleB second bundle entry + * @return negative if a < b, positive if a > b, zero if equal + */ + @Override + public int compare(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) { + int result = compareByBandwidthIn(bundleA, bundleB); + if (result != 0) { + return result; + } + + result = compareByBandwidthOut(bundleA, bundleB); + if (result != 0) { + return result; + } + + result = compareByMsgRate(bundleA, bundleB); + if (result != 0) { + return result; + } + + result = compareByTopicConnections(bundleA, bundleB); + if (result != 0) { + return result; + } + + result = compareByCacheSize(bundleA, bundleB); + if (result != 0) { + return result; + } + + // If all metrics are equal + return 0; + } + + private int compareByBandwidthIn(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) { + return Double.compare(bundleA.msgThroughputIn, bundleB.msgThroughputIn); + } + + private int compareByBandwidthOut(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) { + return Double.compare(bundleA.msgThroughputOut, bundleB.msgThroughputOut); + } + + private int compareByMsgRate(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) { + double totalMsgRateA = bundleA.msgRateIn + bundleA.msgRateOut; + double totalMsgRateB = bundleB.msgRateIn + bundleB.msgRateOut; + return Double.compare(totalMsgRateA, totalMsgRateB); + } + + private int compareByTopicConnections(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) { + long totalConnectionsA = bundleA.topics + bundleA.consumerCount + bundleA.producerCount; + long totalConnectionsB = bundleB.topics + bundleB.consumerCount + bundleB.producerCount; + return Long.compare(totalConnectionsA, totalConnectionsB); + } + + private int compareByCacheSize(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) { + return Long.compare(bundleA.cacheSize, bundleB.cacheSize); + } +} diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundles.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundles.java index 481e907d04439..d3765f9f31d2b 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundles.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundles.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Set; @@ -47,7 +48,8 @@ public class TopKBundles { // temp array for sorting - private final List> arr = new ArrayList<>(); + private final List> arr = new ArrayList<>(); + public static final StrictNamespaceBundleStatsComparator strictNamespaceBundleStatsComparator = new StrictNamespaceBundleStatsComparator(); private final TopBundlesLoadData loadData = new TopBundlesLoadData(); @@ -100,7 +102,7 @@ public void update(Map bundleStats, int topk) { return; } topk = Math.min(topk, arr.size()); - partitionSort(arr, topk); + TopKBundles.partitionSort(arr, topk, strictNamespaceBundleStatsComparator); for (int i = topk - 1; i >= 0; i--) { var etr = arr.get(i); @@ -112,7 +114,7 @@ public void update(Map bundleStats, int topk) { } } - public static void partitionSort(List> arr, int k) { + public static void partitionSort(List> arr, int k, Comparator comparator) { int start = 0; int end = arr.size() - 1; int target = k - 1; @@ -120,9 +122,9 @@ public static void partitionSort(List> a int lo = start; int hi = end; int mid = lo; - var pivot = arr.get(hi).getValue(); + var pivot = arr.get(hi); while (mid <= hi) { - int cmp = pivot.compareTo(arr.get(mid).getValue()); + int cmp = comparator.compare(pivot.getValue(), arr.get(mid).getValue()); if (cmp < 0) { var tmp = arr.get(lo); arr.set(lo++, arr.get(mid)); @@ -145,7 +147,7 @@ public static void partitionSort(List> a start = mid; } } - Collections.sort(arr.subList(0, end), (a, b) -> b.getValue().compareTo(a.getValue())); + Collections.sort(arr.subList(0, end),(a,b)-> comparator.compare(b.getValue(), a.getValue())); } private boolean hasPolicies(String bundle) { diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/impl/ModularLoadManagerImpl.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/impl/ModularLoadManagerImpl.java index 75c60e2687942..129852a663a0d 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/impl/ModularLoadManagerImpl.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/loadbalance/impl/ModularLoadManagerImpl.java @@ -59,6 +59,7 @@ import org.apache.pulsar.broker.loadbalance.LoadSheddingStrategy; import org.apache.pulsar.broker.loadbalance.ModularLoadManager; import org.apache.pulsar.broker.loadbalance.ModularLoadManagerStrategy; +import org.apache.pulsar.broker.loadbalance.extensions.models.BundleDataComparator; import org.apache.pulsar.broker.loadbalance.extensions.models.TopKBundles; import org.apache.pulsar.broker.loadbalance.impl.LoadManagerShared.BrokerTopicLoadingPredicate; import org.apache.pulsar.broker.resources.PulsarResources; @@ -189,7 +190,9 @@ public class ModularLoadManagerImpl implements ModularLoadManager { private final Set knownBrokers = new HashSet<>(); private Map bundleBrokerAffinityMap; // array used for sorting and select topK bundles - private final List> bundleArr = new ArrayList<>(); + private final List> bundleArr = new ArrayList<>(); + public static final BundleDataComparator bundleDataComparator = new BundleDataComparator(); + /** @@ -1170,7 +1173,7 @@ private int selectTopKBundle() { // no bundle to update return 0; } - TopKBundles.partitionSort(bundleArr, updateBundleCount); + TopKBundles.partitionSort(bundleArr, updateBundleCount, bundleDataComparator); return updateBundleCount; } } diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundlesTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundlesTest.java index 42c1516557954..ff8c91254b834 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundlesTest.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/loadbalance/extensions/models/TopKBundlesTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.mock; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -288,8 +289,8 @@ public void testLoadBalancerSheddingBundlesWithPoliciesEnabledConfig() throws Me public void testPartitionSort() { Random rand = new Random(); - List> actual = new ArrayList<>(); - List> expected = new ArrayList<>(); + List> actual = new ArrayList<>(); + List> expected = new ArrayList<>(); for (int j = 0; j < 100; j++) { Map map = new HashMap<>(); @@ -305,8 +306,8 @@ public void testPartitionSort() { expected.add(etr); } int topk = rand.nextInt(max) + 1; - TopKBundles.partitionSort(actual, topk); - Collections.sort(expected, (a, b) -> b.getValue().compareTo(a.getValue())); + TopKBundles.partitionSort(actual, topk, (a, b) -> b.compareTo(a)); + Collections.sort(expected, (a, b) -> a.getValue().compareTo(b.getValue())); String errorMsg = null; for (int i = 0; i < topk; i++) { Integer l = (Integer) actual.get(i).getValue(); @@ -319,4 +320,57 @@ public void testPartitionSort() { } } } + + @Test + public void testCollectionsSortFailsWithTransitivityViolation() { + // Create many elements with values that are more likely to trigger transitivity violation detection + // Using the same approach as testPartitionSortTransitivityIssue but with Collections.sort + Random rnd = new Random(0); + ArrayList stats = new ArrayList<>(); + + // Create 1000 elements with values around the threshold boundary + for (int i = 0; i < 1000; ++i) { + NamespaceBundleStats statsA = new NamespaceBundleStats(); + statsA.msgThroughputIn = 4 * 75000 * rnd.nextDouble(); // Values around threshold (1e5) + statsA.msgThroughputOut = 75000000 - (4 * (75000 * rnd.nextDouble())); + statsA.msgRateIn = 4 * 75 * rnd.nextDouble(); + statsA.msgRateOut = 75000 - (4 * 75 * rnd.nextDouble()); + statsA.topics = i; + statsA.consumerCount = i; + statsA.producerCount = 4 * rnd.nextInt(375); + statsA.cacheSize = 75000000 - (rnd.nextInt(4 * 75000)); + stats.add(statsA); + } + + List> bundleEntries = new ArrayList<>(); + for (NamespaceBundleStats s : stats) { + bundleEntries.add(new HashMap.SimpleEntry<>("bundle-" + s.msgThroughputIn, s)); + } + + // This should throw IllegalArgumentException due to transitivity violation + try { + Collections.sort(bundleEntries, (a, b) -> a.getValue().compareTo(b.getValue())); + System.out.println("SUCCESS: Collections.sort completed without throwing exception!"); + } catch (IllegalArgumentException e) { + System.out.println("ERROR: Collections.sort detected transitivity violation!"); + System.out.println("Exception message: " + e.getMessage()); + + // Verify the exception message contains the expected text + assertTrue(e.getMessage().contains("Comparison method violates its general contract") || + e.getMessage().contains("transitivity") || + e.getMessage().contains("comparison"), + "Expected IllegalArgumentException about comparison contract violation, got: " + e.getMessage()); + } + + // Now test that TopKBundles.partitionSort works with the same data + // Create bundle entries for partitionSort + List> bundleEntriesForPartitionSort = new ArrayList<>(); + for (NamespaceBundleStats s : stats) { + bundleEntriesForPartitionSort.add(new HashMap.SimpleEntry<>("bundle-" + s.msgThroughputIn, s)); + } + + // This should work without throwing an exception + TopKBundles.partitionSort(bundleEntriesForPartitionSort, 10, TopKBundles.strictNamespaceBundleStatsComparator.reversed()); + + } }