Skip to content

Commit

Permalink
Add fetchPartitionCount and fetchStreamPartitionOffset implementation…
Browse files Browse the repository at this point in the history
… api for pinot-kinesis
  • Loading branch information
xiangfu0 committed Dec 24, 2024
1 parent 383bbce commit 78dfda4
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class KinesisConsumerFactory extends StreamConsumerFactory {

@Override
public StreamMetadataProvider createPartitionMetadataProvider(String clientId, int partition) {
return new KinesisStreamMetadataProvider(clientId, _streamConfig);
return new KinesisStreamMetadataProvider(clientId, _streamConfig, String.valueOf(partition));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,44 +43,73 @@
import org.apache.pinot.spi.stream.StreamPartitionMsgOffset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.kinesis.model.SequenceNumberRange;
import software.amazon.awssdk.services.kinesis.model.Shard;


/**
* A {@link StreamMetadataProvider} implementation for the Kinesis stream
*/
public class KinesisStreamMetadataProvider implements StreamMetadataProvider {
private static final String SHARD_ID_PREFIX = "shardId-";
public static final String SHARD_ID_PREFIX = "shardId-";
private final KinesisConnectionHandler _kinesisConnectionHandler;
private final StreamConsumerFactory _kinesisStreamConsumerFactory;
private final String _clientId;
private final int _fetchTimeoutMs;
private final String _partitionId;
private static final Logger LOGGER = LoggerFactory.getLogger(KinesisStreamMetadataProvider.class);

public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig) {
this(clientId, streamConfig, String.valueOf(Integer.MIN_VALUE));
}

public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig, String partitionId) {
KinesisConfig kinesisConfig = new KinesisConfig(streamConfig);
_kinesisConnectionHandler = new KinesisConnectionHandler(kinesisConfig);
_kinesisStreamConsumerFactory = StreamConsumerFactoryProvider.create(streamConfig);
_clientId = clientId;
_partitionId = partitionId;
_fetchTimeoutMs = streamConfig.getFetchTimeoutMillis();
}

public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig,
KinesisConnectionHandler kinesisConnectionHandler, StreamConsumerFactory streamConsumerFactory) {
this(clientId, streamConfig, String.valueOf(Integer.MIN_VALUE), kinesisConnectionHandler, streamConsumerFactory);
}

public KinesisStreamMetadataProvider(String clientId, StreamConfig streamConfig, String partitionId,
KinesisConnectionHandler kinesisConnectionHandler, StreamConsumerFactory streamConsumerFactory) {
_kinesisConnectionHandler = kinesisConnectionHandler;
_kinesisStreamConsumerFactory = streamConsumerFactory;
_clientId = clientId;
_partitionId = partitionId;
_fetchTimeoutMs = streamConfig.getFetchTimeoutMillis();
}

@Override
public int fetchPartitionCount(long timeoutMillis) {
throw new UnsupportedOperationException();
try {
List<Shard> shards = _kinesisConnectionHandler.getShards();
return shards.size();
} catch (Exception e) {
LOGGER.error("Failed to fetch partition count", e);
throw new RuntimeException("Failed to fetch partition count", e);
}
}

@Override
public StreamPartitionMsgOffset fetchStreamPartitionOffset(OffsetCriteria offsetCriteria, long timeoutMillis) {
throw new UnsupportedOperationException();
// fetch offset for _partitionId
Shard foundShard = _kinesisConnectionHandler.getShards().stream()
.filter(shard -> shard.shardId().equals(SHARD_ID_PREFIX + _partitionId))
.findFirst().orElseThrow(() -> new RuntimeException("Failed to find shard for partitionId: " + _partitionId));
SequenceNumberRange sequenceNumberRange = foundShard.sequenceNumberRange();
if (offsetCriteria.isSmallest()) {
return new KinesisPartitionGroupOffset(foundShard.shardId(), sequenceNumberRange.startingSequenceNumber());
} else if (offsetCriteria.isLargest()) {
return new KinesisPartitionGroupOffset(foundShard.shardId(), sequenceNumberRange.endingSequenceNumber());
}
throw new IllegalArgumentException("Unsupported offset criteria: " + offsetCriteria);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;

import static org.apache.pinot.plugin.stream.kinesis.KinesisStreamMetadataProvider.SHARD_ID_PREFIX;
import static org.apache.pinot.spi.stream.OffsetCriteria.LARGEST_OFFSET_CRITERIA;
import static org.apache.pinot.spi.stream.OffsetCriteria.SMALLEST_OFFSET_CRITERIA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -98,6 +101,51 @@ public void getPartitionsGroupInfoListTest()
Assert.assertEquals(result.get(1).getPartitionGroupId(), 1);
}

@Test
public void fetchStreamPartitionOffsetTest() {
Shard shard0 = Shard.builder().shardId(SHARD_ID_PREFIX + SHARD_ID_0)
.sequenceNumberRange(
SequenceNumberRange.builder().startingSequenceNumber("1").endingSequenceNumber("100").build()).build();
Shard shard1 = Shard.builder().shardId(SHARD_ID_PREFIX + SHARD_ID_1)
.sequenceNumberRange(
SequenceNumberRange.builder().startingSequenceNumber("2").endingSequenceNumber("200").build()).build();
when(_kinesisConnectionHandler.getShards()).thenReturn(ImmutableList.of(shard0, shard1));

KinesisStreamMetadataProvider kinesisStreamMetadataProviderShard0 =
new KinesisStreamMetadataProvider(CLIENT_ID, getStreamConfig(), SHARD_ID_0, _kinesisConnectionHandler,
_streamConsumerFactory);
Assert.assertEquals(kinesisStreamMetadataProviderShard0.fetchPartitionCount(TIMEOUT), 2);

KinesisPartitionGroupOffset kinesisPartitionGroupOffset =
(KinesisPartitionGroupOffset) kinesisStreamMetadataProviderShard0.fetchStreamPartitionOffset(
SMALLEST_OFFSET_CRITERIA, TIMEOUT);
Assert.assertEquals(kinesisPartitionGroupOffset.getShardId(), SHARD_ID_PREFIX + SHARD_ID_0);
Assert.assertEquals(kinesisPartitionGroupOffset.getSequenceNumber(), "1");

kinesisPartitionGroupOffset =
(KinesisPartitionGroupOffset) kinesisStreamMetadataProviderShard0.fetchStreamPartitionOffset(
LARGEST_OFFSET_CRITERIA, TIMEOUT);
Assert.assertEquals(kinesisPartitionGroupOffset.getShardId(), SHARD_ID_PREFIX + SHARD_ID_0);
Assert.assertEquals(kinesisPartitionGroupOffset.getSequenceNumber(), "100");

KinesisStreamMetadataProvider kinesisStreamMetadataProviderShard1 =
new KinesisStreamMetadataProvider(CLIENT_ID, getStreamConfig(), SHARD_ID_1, _kinesisConnectionHandler,
_streamConsumerFactory);
Assert.assertEquals(kinesisStreamMetadataProviderShard1.fetchPartitionCount(TIMEOUT), 2);

kinesisPartitionGroupOffset =
(KinesisPartitionGroupOffset) kinesisStreamMetadataProviderShard1.fetchStreamPartitionOffset(
SMALLEST_OFFSET_CRITERIA, TIMEOUT);
Assert.assertEquals(kinesisPartitionGroupOffset.getShardId(), SHARD_ID_PREFIX + SHARD_ID_1);
Assert.assertEquals(kinesisPartitionGroupOffset.getSequenceNumber(), "2");

kinesisPartitionGroupOffset =
(KinesisPartitionGroupOffset) kinesisStreamMetadataProviderShard1.fetchStreamPartitionOffset(
LARGEST_OFFSET_CRITERIA, TIMEOUT);
Assert.assertEquals(kinesisPartitionGroupOffset.getShardId(), SHARD_ID_PREFIX + SHARD_ID_1);
Assert.assertEquals(kinesisPartitionGroupOffset.getSequenceNumber(), "200");
}

@Test
public void getPartitionsGroupInfoEndOfShardTest()
throws Exception {
Expand Down

0 comments on commit 78dfda4

Please sign in to comment.