diff --git a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerFactory.java b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerFactory.java index bd7d9ad8c09f..62ad62b25b63 100644 --- a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerFactory.java +++ b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisConsumerFactory.java @@ -32,7 +32,7 @@ public class KinesisConsumerFactory extends StreamConsumerFactory { @Override public StreamMetadataProvider createPartitionMetadataProvider(String clientId, int partition) { - return new KinesisStreamMetadataProvider(clientId, _streamConfig); + return new KinesisStreamMetadataProvider(clientId, _streamConfig, partition); } @Override diff --git a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisStreamMetadataProvider.java b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisStreamMetadataProvider.java index bcf644fd2dfc..8194ebef3327 100644 --- a/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisStreamMetadataProvider.java +++ b/pinot-plugins/pinot-stream-ingestion/pinot-kinesis/src/main/java/org/apache/pinot/plugin/stream/kinesis/KinesisStreamMetadataProvider.java @@ -43,6 +43,7 @@ import org.apache.pinot.spi.stream.StreamPartitionMsgOffset; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.kinesis.model.SequenceNumberRange; import software.amazon.awssdk.services.kinesis.model.Shard; @@ -55,32 +56,60 @@ public class KinesisStreamMetadataProvider implements StreamMetadataProvider { private final StreamConsumerFactory _kinesisStreamConsumerFactory; private final String _clientId; private final int _fetchTimeoutMs; + private final int _partitionId; private static final Logger LOGGER = LoggerFactory.getLogger(KinesisStreamMetadataProvider.class); public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig) { + this(clientId, streamConfig, Integer.MIN_VALUE); + } + + public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig, int partitionId) { KinesisConfig kinesisConfig = new KinesisConfig(streamConfig); _kinesisConnectionHandler = new KinesisConnectionHandler(kinesisConfig); _kinesisStreamConsumerFactory = StreamConsumerFactoryProvider.create(streamConfig); _clientId = clientId; + _partitionId = partitionId; _fetchTimeoutMs = streamConfig.getFetchTimeoutMillis(); } public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig, KinesisConnectionHandler kinesisConnectionHandler, StreamConsumerFactory streamConsumerFactory) { + this(clientId, streamConfig, Integer.MIN_VALUE, kinesisConnectionHandler, streamConsumerFactory); + } + + public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig, int partitionId, + KinesisConnectionHandler kinesisConnectionHandler, StreamConsumerFactory streamConsumerFactory) { _kinesisConnectionHandler = kinesisConnectionHandler; _kinesisStreamConsumerFactory = streamConsumerFactory; _clientId = clientId; + _partitionId = partitionId; _fetchTimeoutMs = streamConfig.getFetchTimeoutMillis(); } @Override public int fetchPartitionCount(long timeoutMillis) { - throw new UnsupportedOperationException(); + try { + List shards = _kinesisConnectionHandler.getShards(); + return shards.size(); + } catch (Exception e) { + LOGGER.error("Failed to fetch partition count", e); + throw new RuntimeException("Failed to fetch partition count", e); + } } @Override public StreamPartitionMsgOffset fetchStreamPartitionOffset(OffsetCriteria offsetCriteria, long timeoutMillis) { - throw new UnsupportedOperationException(); + // fetch offset for _partitionId + Shard foundShard = _kinesisConnectionHandler.getShards().stream() + .filter(shard -> shard.shardId().equals(SHARD_ID_PREFIX + _partitionId)) + .findFirst().orElseThrow(() -> new RuntimeException("Failed to find shard for partitionId: " + _partitionId)); + SequenceNumberRange sequenceNumberRange = foundShard.sequenceNumberRange(); + if (offsetCriteria.isSmallest()) { + return new KinesisPartitionGroupOffset(foundShard.shardId(), sequenceNumberRange.startingSequenceNumber()); + } else if (offsetCriteria.isLargest()) { + return new KinesisPartitionGroupOffset(foundShard.shardId(), sequenceNumberRange.endingSequenceNumber()); + } + throw new IllegalArgumentException("Unsupported offset criteria: " + offsetCriteria); } /**