diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/NegativeAcksTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/NegativeAcksTest.java index 382e7e16f018e..cb3fc115b3060 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/NegativeAcksTest.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/client/impl/NegativeAcksTest.java @@ -320,7 +320,7 @@ public void testNegativeAcksDeleteFromUnackedTracker() throws Exception { consumer.negativeAcknowledge(batchMessageId); consumer.negativeAcknowledge(batchMessageId2); consumer.negativeAcknowledge(batchMessageId3); - assertEquals(negativeAcksTracker.getNackedMessagesCount().orElse((long) -1).longValue(), 1L); + assertEquals(negativeAcksTracker.getNackedMessagesCount().orElse((long) -1).longValue(), 3L); assertEquals(unAckedMessageTracker.size(), 0); negativeAcksTracker.close(); } diff --git a/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ConsumerImpl.java b/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ConsumerImpl.java index 004adab56f529..844d69efeac23 100644 --- a/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ConsumerImpl.java +++ b/pulsar-client/src/main/java/org/apache/pulsar/client/impl/ConsumerImpl.java @@ -2187,6 +2187,7 @@ private CompletableFuture> getRedeliveryMessageIdData(List messageIds) { int messagesFromQueue = 0; Message peek = incomingMessages.peek(); if (peek != null) { - MessageIdAdv messageId = MessageIdAdvUtils.discardBatch(peek.getMessageId()); + MessageId messageId = NegativeAcksTracker.discardPartitionIndex(peek.getMessageId()); if (!messageIds.contains(messageId)) { // first message is not expired, then no message is expired in queue. return 0; @@ -2751,7 +2752,7 @@ private int removeExpiredMessagesFromQueue(Set messageIds) { while (message != null) { decreaseIncomingMessageSize(message); messagesFromQueue++; - MessageIdAdv id = MessageIdAdvUtils.discardBatch(message.getMessageId()); + MessageId id = NegativeAcksTracker.discardPartitionIndex(message.getMessageId()); if (!messageIds.contains(id)) { messageIds.add(id); break; diff --git a/pulsar-client/src/main/java/org/apache/pulsar/client/impl/NegativeAcksTracker.java b/pulsar-client/src/main/java/org/apache/pulsar/client/impl/NegativeAcksTracker.java index e1724ebb85cda..52bc4b399d4f6 100644 --- a/pulsar-client/src/main/java/org/apache/pulsar/client/impl/NegativeAcksTracker.java +++ b/pulsar-client/src/main/java/org/apache/pulsar/client/impl/NegativeAcksTracker.java @@ -32,14 +32,15 @@ import org.apache.pulsar.client.api.MessageIdAdv; import org.apache.pulsar.client.api.RedeliveryBackoff; import org.apache.pulsar.client.impl.conf.ConsumerConfigurationData; -import org.apache.pulsar.common.util.collections.ConcurrentLongLongPairHashMap; +import org.apache.pulsar.common.util.collections.ConcurrentTripleLong2LongHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; class NegativeAcksTracker implements Closeable { private static final Logger log = LoggerFactory.getLogger(NegativeAcksTracker.class); - private ConcurrentLongLongPairHashMap nackedMessages = null; + // map (ledgerId, entryId, batchIndex) -> timestamp + private ConcurrentTripleLong2LongHashMap nackedMessages = null; private final ConsumerBase consumer; private final Timer timer; @@ -51,7 +52,7 @@ class NegativeAcksTracker implements Closeable { // Set a min delay to allow for grouping nacks within a single batch private static final long MIN_NACK_DELAY_NANOS = TimeUnit.MILLISECONDS.toNanos(100); - private static final long NON_PARTITIONED_TOPIC_PARTITION_INDEX = Long.MAX_VALUE; + private static final int DUMMY_PARTITION_INDEX = -2; public NegativeAcksTracker(ConsumerBase consumer, ConsumerConfigurationData conf) { this.consumer = consumer; @@ -77,11 +78,14 @@ private synchronized void triggerRedelivery(Timeout t) { // Group all the nacked messages into one single re-delivery request Set messagesToRedeliver = new HashSet<>(); long now = System.nanoTime(); - nackedMessages.forEach((ledgerId, entryId, partitionIndex, timestamp) -> { + nackedMessages.forEach((ledgerId, entryId, batchIndex, timestamp) -> { if (timestamp < now) { - MessageId msgId = new MessageIdImpl(ledgerId, entryId, - // need to covert non-partitioned topic partition index to -1 - (int) (partitionIndex == NON_PARTITIONED_TOPIC_PARTITION_INDEX ? -1 : partitionIndex)); + MessageId msgId; + if (batchIndex == -1) { + msgId = new MessageIdImpl(ledgerId, entryId, -1); + } else { + msgId = new BatchMessageIdImpl(ledgerId, entryId, -1, (int) batchIndex); + } addChunkedMessageIdsAndRemoveFromSequenceMap(msgId, messagesToRedeliver, this.consumer); messagesToRedeliver.add(msgId); } @@ -89,8 +93,9 @@ private synchronized void triggerRedelivery(Timeout t) { if (!messagesToRedeliver.isEmpty()) { for (MessageId messageId : messagesToRedeliver) { - nackedMessages.remove(((MessageIdImpl) messageId).getLedgerId(), - ((MessageIdImpl) messageId).getEntryId()); + MessageIdAdv messageIdAdv = (MessageIdAdv) messageId; + nackedMessages.remove(messageIdAdv.getLedgerId(), messageIdAdv.getEntryId(), + messageIdAdv.getBatchIndex()); } consumer.onNegativeAcksSend(messagesToRedeliver); log.info("[{}] {} messages will be re-delivered", consumer, messagesToRedeliver.size()); @@ -110,10 +115,7 @@ public synchronized void add(Message message) { private synchronized void add(MessageId messageId, int redeliveryCount) { if (nackedMessages == null) { - nackedMessages = ConcurrentLongLongPairHashMap.newBuilder() - .autoShrink(true) - .concurrencyLevel(1) - .build(); + nackedMessages = new ConcurrentTripleLong2LongHashMap(); } long backoffNs; @@ -122,14 +124,9 @@ private synchronized void add(MessageId messageId, int redeliveryCount) { } else { backoffNs = nackDelayNanos; } - MessageIdAdv messageIdAdv = MessageIdAdvUtils.discardBatch(messageId); - // ConcurrentLongLongPairHashMap requires the key and value >=0. - // partitionIndex is -1 if the message is from a non-partitioned topic, but we don't use - // partitionIndex actually, so we can set it to Long.MAX_VALUE in the case of non-partitioned topic to - // avoid exception from ConcurrentLongLongPairHashMap. + MessageIdAdv messageIdAdv = (MessageIdAdv) messageId; nackedMessages.put(messageIdAdv.getLedgerId(), messageIdAdv.getEntryId(), - messageIdAdv.getPartitionIndex() >= 0 ? messageIdAdv.getPartitionIndex() : - NON_PARTITIONED_TOPIC_PARTITION_INDEX, System.nanoTime() + backoffNs); + messageIdAdv.getBatchIndex(), System.nanoTime() + backoffNs); if (this.timeout == null) { // Schedule a task and group all the redeliveries for same period. Leave a small buffer to allow for @@ -138,9 +135,28 @@ private synchronized void add(MessageId messageId, int redeliveryCount) { } } + /** + * Discard the partition index from the message id. + * @param messageId + * @return + */ + static public MessageId discardPartitionIndex(MessageId messageId) { + if (messageId instanceof BatchMessageIdImpl) { + BatchMessageIdImpl batchMessageId = (BatchMessageIdImpl) messageId; + return new BatchMessageIdImpl(batchMessageId.getLedgerId(), batchMessageId.getEntryId(), + DUMMY_PARTITION_INDEX, batchMessageId.getBatchIndex(), batchMessageId.getBatchSize(), + batchMessageId.getAckSet()); + } else if (messageId instanceof MessageIdImpl) { + MessageIdImpl messageID = (MessageIdImpl) messageId; + return new MessageIdImpl(messageID.getLedgerId(), messageID.getEntryId(), DUMMY_PARTITION_INDEX); + } else { + return messageId; + } + } + @VisibleForTesting Optional getNackedMessagesCount() { - return Optional.ofNullable(nackedMessages).map(ConcurrentLongLongPairHashMap::size); + return Optional.ofNullable(nackedMessages).map(ConcurrentTripleLong2LongHashMap::size); } @Override diff --git a/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/ConcurrentTripleLong2LongHashMap.java b/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/ConcurrentTripleLong2LongHashMap.java new file mode 100644 index 0000000000000..f4ba09113ca7a --- /dev/null +++ b/pulsar-common/src/main/java/org/apache/pulsar/common/util/collections/ConcurrentTripleLong2LongHashMap.java @@ -0,0 +1,76 @@ +package org.apache.pulsar.common.util.collections; + +import java.util.HashMap; + +public class ConcurrentTripleLong2LongHashMap { + public class TripleLong{ + public long first; + public long second; + public long third; + @Override + public int hashCode() { + return Long.hashCode(first) ^ Long.hashCode(second) ^ Long.hashCode(third); + } + @Override + public boolean equals(Object obj) { + if(obj instanceof TripleLong){ + TripleLong other = (TripleLong) obj; + return first == other.first && second == other.second && third == other.third; + } + return false; + } + } + + private HashMap map; + public ConcurrentTripleLong2LongHashMap(){ + // TODO: use hashmap for now + map = new HashMap<>(); + } + public void put(long first, long second, long third, long value){ + TripleLong key = new TripleLong(); + key.first = first; + key.second = second; + key.third = third; + map.put(key, value); + } + public long get(long first, long second, long third){ + TripleLong key = new TripleLong(); + key.first = first; + key.second = second; + key.third = third; + return map.get(key); + } + public long remove(long first, long second, long third){ + TripleLong key = new TripleLong(); + key.first = first; + key.second = second; + key.third = third; + return map.remove(key); + } + public boolean containsKey(long first, long second, long third){ + TripleLong key = new TripleLong(); + key.first = first; + key.second = second; + key.third = third; + return map.containsKey(key); + } + public void clear(){ + map.clear(); + } + public long size(){ + return map.size(); + } + public boolean isEmpty() { + return map.isEmpty(); + } + + public interface TripleLongConsumer { + void call(long first, long second, long third, long value); + } + public void forEach(TripleLongConsumer consumer){ + for(TripleLong key : map.keySet()){ + consumer.call(key.first, key.second, key.third, map.get(key)); + } + } + +}