Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pulsar.broker.loadbalance.extensions.models;

import java.util.Comparator;
import java.util.Map;
import org.apache.pulsar.policies.data.loadbalancer.BundleData;
import org.apache.pulsar.policies.data.loadbalancer.TimeAverageMessageData;

/**
* A strict comparator for BundleData that provides deterministic ordering
* without threshold-based comparisons to avoid transitivity violations.
*/
public class BundleDataComparator implements Comparator<BundleData> {

/**
* Compares two bundle entries using BundleData
* 1. Short-term data comparison (inbound bandwidth, outbound bandwidth, message rate)
* 2. Long-term data comparison (same hierarchy as short-term)
*
* @param bundleA first bundle entry
* @param bundleB second bundle entry
* @return negative if a < b, positive if a > b, zero if equal
*/
@Override
public int compare(BundleData bundleA, BundleData bundleB) {

// First compare short-term data (same hierarchy as TimeAverageMessageData.compareTo() but strict)
int result = compareShortTermData(bundleA, bundleB);
if (result != 0) {
return result;
}

// If short-term data is equal, compare long-term data
result = compareLongTermData(bundleA, bundleB);
if (result != 0) {
return result;
}

// If all metrics are equal
return 0;
}

/**
* Compare short-term data using the same hierarchy as TimeAverageMessageData.compareTo() but with strict comparisons.
*/
private int compareShortTermData(BundleData bundleA, BundleData bundleB) {
TimeAverageMessageData shortTermA = bundleA.getShortTermData();
TimeAverageMessageData shortTermB = bundleB.getShortTermData();

// 1. Inbound bandwidth (strict comparison)
int result = Double.compare(shortTermA.getMsgThroughputIn(), shortTermB.getMsgThroughputIn());
if (result != 0) {
return result;
}

// 2. Outbound bandwidth (strict comparison)
result = Double.compare(shortTermA.getMsgThroughputOut(), shortTermB.getMsgThroughputOut());
if (result != 0) {
return result;
}

// 3. Total message rate (strict comparison)
double totalMsgRateA = shortTermA.getMsgRateIn() + shortTermA.getMsgRateOut();
double totalMsgRateB = shortTermB.getMsgRateIn() + shortTermB.getMsgRateOut();
return Double.compare(totalMsgRateA, totalMsgRateB);
}

/**
* Compare long-term data using the same hierarchy as TimeAverageMessageData.compareTo() but with strict comparisons.
*/
private int compareLongTermData(BundleData bundleA, BundleData bundleB) {
TimeAverageMessageData longTermA = bundleA.getLongTermData();
TimeAverageMessageData longTermB = bundleB.getLongTermData();

// 1. Inbound bandwidth (strict comparison)
int result = Double.compare(longTermA.getMsgThroughputIn(), longTermB.getMsgThroughputIn());
if (result != 0) {
return result;
}

// 2. Outbound bandwidth (strict comparison)
result = Double.compare(longTermA.getMsgThroughputOut(), longTermB.getMsgThroughputOut());
if (result != 0) {
return result;
}

// 3. Total message rate (strict comparison)
double totalMsgRateA = longTermA.getMsgRateIn() + longTermA.getMsgRateOut();
double totalMsgRateB = longTermB.getMsgRateIn() + longTermB.getMsgRateOut();
return Double.compare(totalMsgRateA, totalMsgRateB);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pulsar.broker.loadbalance.extensions.models;

import java.util.Comparator;
import java.util.Map;
import org.apache.pulsar.policies.data.loadbalancer.NamespaceBundleStats;

/**
* A strict comparator for NamespaceBundleStats that provides deterministic ordering
* without threshold-based comparisons to avoid transitivity violations.
*/
public class StrictNamespaceBundleStatsComparator implements Comparator<NamespaceBundleStats> {

/**
* Compares two bundle entries
* Comparison hierarchy:
* 1. Inbound bandwidth (msgThroughputIn)
* 2. Outbound bandwidth (msgThroughputOut)
* 3. Total message rate (msgRateIn + msgRateOut)
* 4. Total topics and connections (topics + consumerCount + producerCount)
* 5. Cache size (cacheSize)
*
* @param bundleA first bundle entry
* @param bundleB second bundle entry
* @return negative if a < b, positive if a > b, zero if equal
*/
@Override
public int compare(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) {
int result = compareByBandwidthIn(bundleA, bundleB);
if (result != 0) {
return result;
}

result = compareByBandwidthOut(bundleA, bundleB);
if (result != 0) {
return result;
}

result = compareByMsgRate(bundleA, bundleB);
if (result != 0) {
return result;
}

result = compareByTopicConnections(bundleA, bundleB);
if (result != 0) {
return result;
}

result = compareByCacheSize(bundleA, bundleB);
if (result != 0) {
return result;
}

// If all metrics are equal
return 0;
}

private int compareByBandwidthIn(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) {
return Double.compare(bundleA.msgThroughputIn, bundleB.msgThroughputIn);
}

private int compareByBandwidthOut(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) {
return Double.compare(bundleA.msgThroughputOut, bundleB.msgThroughputOut);
}

private int compareByMsgRate(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) {
double totalMsgRateA = bundleA.msgRateIn + bundleA.msgRateOut;
double totalMsgRateB = bundleB.msgRateIn + bundleB.msgRateOut;
return Double.compare(totalMsgRateA, totalMsgRateB);
}

private int compareByTopicConnections(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) {
long totalConnectionsA = bundleA.topics + bundleA.consumerCount + bundleA.producerCount;
long totalConnectionsB = bundleB.topics + bundleB.consumerCount + bundleB.producerCount;
return Long.compare(totalConnectionsA, totalConnectionsB);
}

private int compareByCacheSize(NamespaceBundleStats bundleA, NamespaceBundleStats bundleB) {
return Long.compare(bundleA.cacheSize, bundleB.cacheSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -47,7 +48,8 @@
public class TopKBundles {

// temp array for sorting
private final List<Map.Entry<String, ? extends Comparable>> arr = new ArrayList<>();
private final List<Map.Entry<String, NamespaceBundleStats>> arr = new ArrayList<>();
public static final StrictNamespaceBundleStatsComparator strictNamespaceBundleStatsComparator = new StrictNamespaceBundleStatsComparator();

private final TopBundlesLoadData loadData = new TopBundlesLoadData();

Expand Down Expand Up @@ -100,7 +102,7 @@ public void update(Map<String, NamespaceBundleStats> bundleStats, int topk) {
return;
}
topk = Math.min(topk, arr.size());
partitionSort(arr, topk);
TopKBundles.partitionSort(arr, topk, strictNamespaceBundleStatsComparator);

for (int i = topk - 1; i >= 0; i--) {
var etr = arr.get(i);
Expand All @@ -112,17 +114,17 @@ public void update(Map<String, NamespaceBundleStats> bundleStats, int topk) {
}
}

public static void partitionSort(List<Map.Entry<String, ? extends Comparable>> arr, int k) {
public static <T> void partitionSort(List<Map.Entry<String, T>> arr, int k, Comparator<T> comparator) {
int start = 0;
int end = arr.size() - 1;
int target = k - 1;
while (start < end) {
int lo = start;
int hi = end;
int mid = lo;
var pivot = arr.get(hi).getValue();
var pivot = arr.get(hi);
while (mid <= hi) {
int cmp = pivot.compareTo(arr.get(mid).getValue());
int cmp = comparator.compare(pivot.getValue(), arr.get(mid).getValue());
if (cmp < 0) {
var tmp = arr.get(lo);
arr.set(lo++, arr.get(mid));
Expand All @@ -145,7 +147,7 @@ public static void partitionSort(List<Map.Entry<String, ? extends Comparable>> a
start = mid;
}
}
Collections.sort(arr.subList(0, end), (a, b) -> b.getValue().compareTo(a.getValue()));
Collections.sort(arr.subList(0, end),(a,b)-> comparator.compare(b.getValue(), a.getValue()));
}

private boolean hasPolicies(String bundle) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.apache.pulsar.broker.loadbalance.LoadSheddingStrategy;
import org.apache.pulsar.broker.loadbalance.ModularLoadManager;
import org.apache.pulsar.broker.loadbalance.ModularLoadManagerStrategy;
import org.apache.pulsar.broker.loadbalance.extensions.models.BundleDataComparator;
import org.apache.pulsar.broker.loadbalance.extensions.models.TopKBundles;
import org.apache.pulsar.broker.loadbalance.impl.LoadManagerShared.BrokerTopicLoadingPredicate;
import org.apache.pulsar.broker.resources.PulsarResources;
Expand Down Expand Up @@ -189,7 +190,9 @@ public class ModularLoadManagerImpl implements ModularLoadManager {
private final Set<String> knownBrokers = new HashSet<>();
private Map<String, String> bundleBrokerAffinityMap;
// array used for sorting and select topK bundles
private final List<Map.Entry<String, ? extends Comparable>> bundleArr = new ArrayList<>();
private final List<Map.Entry<String, BundleData>> bundleArr = new ArrayList<>();
public static final BundleDataComparator bundleDataComparator = new BundleDataComparator();



/**
Expand Down Expand Up @@ -1170,7 +1173,7 @@ private int selectTopKBundle() {
// no bundle to update
return 0;
}
TopKBundles.partitionSort(bundleArr, updateBundleCount);
TopKBundles.partitionSort(bundleArr, updateBundleCount, bundleDataComparator);
return updateBundleCount;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.mockito.Mockito.mock;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -288,8 +289,8 @@ public void testLoadBalancerSheddingBundlesWithPoliciesEnabledConfig() throws Me
public void testPartitionSort() {

Random rand = new Random();
List<Map.Entry<String, ? extends Comparable>> actual = new ArrayList<>();
List<Map.Entry<String, ? extends Comparable>> expected = new ArrayList<>();
List<Map.Entry<String, Integer>> actual = new ArrayList<>();
List<Map.Entry<String, Integer>> expected = new ArrayList<>();

for (int j = 0; j < 100; j++) {
Map<String, Integer> map = new HashMap<>();
Expand All @@ -305,8 +306,8 @@ public void testPartitionSort() {
expected.add(etr);
}
int topk = rand.nextInt(max) + 1;
TopKBundles.partitionSort(actual, topk);
Collections.sort(expected, (a, b) -> b.getValue().compareTo(a.getValue()));
TopKBundles.partitionSort(actual, topk, (a, b) -> b.compareTo(a));
Collections.sort(expected, (a, b) -> a.getValue().compareTo(b.getValue()));
String errorMsg = null;
for (int i = 0; i < topk; i++) {
Integer l = (Integer) actual.get(i).getValue();
Expand All @@ -319,4 +320,57 @@ public void testPartitionSort() {
}
}
}

@Test
public void testCollectionsSortFailsWithTransitivityViolation() {
// Create many elements with values that are more likely to trigger transitivity violation detection
// Using the same approach as testPartitionSortTransitivityIssue but with Collections.sort
Random rnd = new Random(0);
ArrayList<NamespaceBundleStats> stats = new ArrayList<>();

// Create 1000 elements with values around the threshold boundary
for (int i = 0; i < 1000; ++i) {
NamespaceBundleStats statsA = new NamespaceBundleStats();
statsA.msgThroughputIn = 4 * 75000 * rnd.nextDouble(); // Values around threshold (1e5)
statsA.msgThroughputOut = 75000000 - (4 * (75000 * rnd.nextDouble()));
statsA.msgRateIn = 4 * 75 * rnd.nextDouble();
statsA.msgRateOut = 75000 - (4 * 75 * rnd.nextDouble());
statsA.topics = i;
statsA.consumerCount = i;
statsA.producerCount = 4 * rnd.nextInt(375);
statsA.cacheSize = 75000000 - (rnd.nextInt(4 * 75000));
stats.add(statsA);
}

List<Map.Entry<String, NamespaceBundleStats>> bundleEntries = new ArrayList<>();
for (NamespaceBundleStats s : stats) {
bundleEntries.add(new HashMap.SimpleEntry<>("bundle-" + s.msgThroughputIn, s));
}

// This should throw IllegalArgumentException due to transitivity violation
try {
Collections.sort(bundleEntries, (a, b) -> a.getValue().compareTo(b.getValue()));
System.out.println("SUCCESS: Collections.sort completed without throwing exception!");
} catch (IllegalArgumentException e) {
System.out.println("ERROR: Collections.sort detected transitivity violation!");
System.out.println("Exception message: " + e.getMessage());

// Verify the exception message contains the expected text
assertTrue(e.getMessage().contains("Comparison method violates its general contract") ||
e.getMessage().contains("transitivity") ||
e.getMessage().contains("comparison"),
"Expected IllegalArgumentException about comparison contract violation, got: " + e.getMessage());
}

// Now test that TopKBundles.partitionSort works with the same data
// Create bundle entries for partitionSort
List<Map.Entry<String, NamespaceBundleStats>> bundleEntriesForPartitionSort = new ArrayList<>();
for (NamespaceBundleStats s : stats) {
bundleEntriesForPartitionSort.add(new HashMap.SimpleEntry<>("bundle-" + s.msgThroughputIn, s));
}

// This should work without throwing an exception
TopKBundles.partitionSort(bundleEntriesForPartitionSort, 10, TopKBundles.strictNamespaceBundleStatsComparator.reversed());

}
}