Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tiered Caching] Moving query recomputation logic outside of write lock #14187

Merged
merged 13 commits into from
Jun 25, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Bump `com.gradle.develocity` from 3.17.4 to 3.17.5 ([#14397](https://github.com/opensearch-project/OpenSearch/pull/14397))

### Changed
- [Tiered Caching] Move query recomputation logic outside write lock ([#14187](https://github.com/opensearch-project/OpenSearch/pull/14187))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package org.opensearch.cache.common.tier;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cache.common.policy.TookTimePolicy;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.cache.CacheType;
Expand Down Expand Up @@ -35,9 +37,13 @@
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.ToLongBiFunction;
Expand All @@ -61,6 +67,7 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {

// Used to avoid caching stale entries in lower tiers.
private static final List<RemovalReason> SPILLOVER_REMOVAL_REASONS = List.of(RemovalReason.EVICTED, RemovalReason.CAPACITY);
private static final Logger logger = LogManager.getLogger(TieredSpilloverCache.class);

private final ICache<K, V> diskCache;
private final ICache<K, V> onHeapCache;
Expand All @@ -86,6 +93,12 @@ public class TieredSpilloverCache<K, V> implements ICache<K, V> {
private final Map<ICache<K, V>, TierInfo> caches;
private final List<Predicate<V>> policies;

/**
* This map is used to handle concurrent requests for same key in computeIfAbsent() to ensure we load the value
* only once.
*/
Map<ICacheKey<K>, CompletableFuture<Tuple<ICacheKey<K>, V>>> completableFutureMap = new ConcurrentHashMap<>();

TieredSpilloverCache(Builder<K, V> builder) {
Objects.requireNonNull(builder.onHeapCacheFactory, "onHeap cache builder can't be null");
Objects.requireNonNull(builder.diskCacheFactory, "disk cache builder can't be null");
Expand Down Expand Up @@ -190,10 +203,7 @@ public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V>
// Add the value to the onHeap cache. We are calling computeIfAbsent which does another get inside.
// This is needed as there can be many requests for the same key at the same time and we only want to load
// the value once.
V value = null;
try (ReleasableLock ignore = writeLock.acquire()) {
value = onHeapCache.computeIfAbsent(key, loader);
}
V value = compute(key, loader);
// Handle stats
if (loader.isLoaded()) {
// The value was just computed and added to the cache by this thread. Register a miss for the heap cache, and the disk cache
Expand Down Expand Up @@ -222,6 +232,55 @@ public V computeIfAbsent(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V>
return cacheValueTuple.v1();
}

private V compute(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V> loader) throws Exception {
// Only one of the threads will succeed putting a future into map for the same key.
// Rest will fetch existing future and wait on that to complete.
CompletableFuture<Tuple<ICacheKey<K>, V>> future = completableFutureMap.putIfAbsent(key, new CompletableFuture<>());
// Handler to handle results post processing. Takes a tuple<key, value> or exception as an input and returns
// the value. Also before returning value, puts the value in cache.
BiFunction<Tuple<ICacheKey<K>, V>, Throwable, Void> handler = (pair, ex) -> {
if (pair != null) {
try (ReleasableLock ignore = writeLock.acquire()) {
onHeapCache.put(pair.v1(), pair.v2());
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
if (ex != null) {
logger.warn("Exception occurred while trying to compute the value", ex);
}
}
completableFutureMap.remove(key); // Remove key from map as not needed anymore.
return null;
};
if (future == null) {
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
V value = null;
future = completableFutureMap.get(key);
future.handle(handler);
try {
value = loader.load(key);
} catch (Exception ex) {
future.completeExceptionally(ex);
throw new ExecutionException(ex);
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
}
if (value == null) {
NullPointerException npe = new NullPointerException("Loader returned a null value");
future.completeExceptionally(npe);
throw new ExecutionException(npe);
} else {
future.complete(new Tuple<>(key, value));
}
}
V value;
try {
value = future.get().v2();
if (future.isCompletedExceptionally()) {
throw new IllegalStateException("Future completed exceptionally but no error thrown");
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
}
} catch (InterruptedException ex) {
throw new IllegalStateException(ex);
}
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
return value;
}

@Override
public void invalidate(ICacheKey<K> key) {
// We are trying to invalidate the key from all caches though it would be present in only of them.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Predicate;
Expand Down Expand Up @@ -408,6 +412,7 @@ public void testComputeIfAbsentWithEvictionsFromOnHeapCache() throws Exception {
assertEquals(onHeapCacheHit, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(cacheMiss + numOfItems1, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_DISK));
assertEquals(diskCacheHit, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_DISK));
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testComputeIfAbsentWithEvictionsFromTieredCache() throws Exception {
Expand Down Expand Up @@ -802,7 +807,7 @@ public String load(ICacheKey<String> key) {
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
phaser.arriveAndAwaitAdvance();
tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader);
assertEquals(value, tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader));
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -811,7 +816,7 @@ public String load(ICacheKey<String> key) {
threads[i].start();
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await(); // Wait for rest of tasks to be cancelled.
countDownLatch.await();
int numberOfTimesKeyLoaded = 0;
assertEquals(numberOfSameKeys, loadAwareCacheLoaderList.size());
for (int i = 0; i < loadAwareCacheLoaderList.size(); i++) {
Expand All @@ -824,6 +829,231 @@ public String load(ICacheKey<String> key) {
// We should see only one heap miss, and the rest hits
assertEquals(1, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(numberOfSameKeys - 1, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testComputIfAbsentConcurrentlyWithMultipleKeys() throws Exception {
int onHeapCacheSize = randomIntBetween(300, 500);
int diskCacheSize = randomIntBetween(600, 700);
int keyValueSize = 50;

MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
Settings settings = Settings.builder()
.put(
OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
.get(MAXIMUM_SIZE_IN_BYTES_KEY)
.getKey(),
onHeapCacheSize * keyValueSize + "b"
)
.build();

TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
keyValueSize,
diskCacheSize,
removalListener,
settings,
0
);

int iterations = 10;
int numberOfKeys = 20;
List<ICacheKey<String>> iCacheKeyList = new ArrayList<>();
for (int i = 0; i < numberOfKeys; i++) {
ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
iCacheKeyList.add(key);
}
ExecutorService executorService = Executors.newFixedThreadPool(8);
CountDownLatch countDownLatch = new CountDownLatch(iterations * numberOfKeys); // To wait for all threads to finish.

List<LoadAwareCacheLoader<ICacheKey<String>, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>();

for (int i = 0; i < iterations; i++) {
for (int j = 0; j < numberOfKeys; j++) {
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
int finalJ = j;
executorService.submit(() -> {
try {
LoadAwareCacheLoader<ICacheKey<String>, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() {
boolean isLoaded = false;

@Override
public boolean isLoaded() {
return isLoaded;
}

@Override
public String load(ICacheKey<String> key) {
isLoaded = true;
return iCacheKeyList.get(finalJ).key;
}
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
tieredSpilloverCache.computeIfAbsent(iCacheKeyList.get(finalJ), loadAwareCacheLoader);
} catch (Exception e) {
throw new RuntimeException(e);
}
countDownLatch.countDown();
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
});
}
}
countDownLatch.await();
int numberOfTimesKeyLoaded = 0;
assertEquals(iterations * numberOfKeys, loadAwareCacheLoaderList.size());
for (int i = 0; i < loadAwareCacheLoaderList.size(); i++) {
LoadAwareCacheLoader<ICacheKey<String>, String> loader = loadAwareCacheLoaderList.get(i);
if (loader.isLoaded()) {
numberOfTimesKeyLoaded++;
}
}
assertEquals(numberOfKeys, numberOfTimesKeyLoaded); // It should be loaded only once.
// We should see only one heap miss, and the rest hits
assertEquals(numberOfKeys, getMissesForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals((iterations * numberOfKeys) - numberOfKeys, getHitsForTier(tieredSpilloverCache, TIER_DIMENSION_VALUE_ON_HEAP));
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
executorService.shutdownNow();
}

public void testComputeIfAbsentConcurrentlyAndThrowsException() throws Exception {
int onHeapCacheSize = randomIntBetween(100, 300);
int diskCacheSize = randomIntBetween(200, 400);
int keyValueSize = 50;

MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
Settings settings = Settings.builder()
.put(
OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
.get(MAXIMUM_SIZE_IN_BYTES_KEY)
.getKey(),
onHeapCacheSize * keyValueSize + "b"
)
.build();

TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
keyValueSize,
diskCacheSize,
removalListener,
settings,
0
);

int numberOfSameKeys = randomIntBetween(10, onHeapCacheSize - 1);
ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
String value = UUID.randomUUID().toString();
AtomicInteger exceptionCount = new AtomicInteger();

Thread[] threads = new Thread[numberOfSameKeys];
Phaser phaser = new Phaser(numberOfSameKeys + 1);
CountDownLatch countDownLatch = new CountDownLatch(numberOfSameKeys); // To wait for all threads to finish.

List<LoadAwareCacheLoader<ICacheKey<String>, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>();

for (int i = 0; i < numberOfSameKeys; i++) {
threads[i] = new Thread(() -> {
try {
LoadAwareCacheLoader<ICacheKey<String>, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() {
boolean isLoaded = false;

@Override
public boolean isLoaded() {
return isLoaded;
}

@Override
public String load(ICacheKey<String> key) {
throw new RuntimeException("Testing");
}
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
phaser.arriveAndAwaitAdvance();
tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader);
} catch (Exception e) {
exceptionCount.incrementAndGet();
assertEquals(ExecutionException.class, e.getClass());
assertEquals(RuntimeException.class, e.getCause().getClass());
assertEquals("Testing", e.getCause().getMessage());
} finally {
countDownLatch.countDown();
}
});
threads[i].start();
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await(); // Wait for rest of tasks to be cancelled.

// Verify exception count was equal to number of requests
assertEquals(numberOfSameKeys, exceptionCount.get());
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testComputeIfAbsentConcurrentlyWithLoaderReturningNull() throws Exception {
sgup432 marked this conversation as resolved.
Show resolved Hide resolved
int onHeapCacheSize = randomIntBetween(100, 300);
int diskCacheSize = randomIntBetween(200, 400);
int keyValueSize = 50;

MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
Settings settings = Settings.builder()
.put(
OpenSearchOnHeapCacheSettings.getSettingListForCacheType(CacheType.INDICES_REQUEST_CACHE)
.get(MAXIMUM_SIZE_IN_BYTES_KEY)
.getKey(),
onHeapCacheSize * keyValueSize + "b"
)
.build();

TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
keyValueSize,
diskCacheSize,
removalListener,
settings,
0
);

int numberOfSameKeys = randomIntBetween(10, onHeapCacheSize - 1);
ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
String value = UUID.randomUUID().toString();
AtomicInteger exceptionCount = new AtomicInteger();

Thread[] threads = new Thread[numberOfSameKeys];
Phaser phaser = new Phaser(numberOfSameKeys + 1);
CountDownLatch countDownLatch = new CountDownLatch(numberOfSameKeys); // To wait for all threads to finish.

List<LoadAwareCacheLoader<ICacheKey<String>, String>> loadAwareCacheLoaderList = new CopyOnWriteArrayList<>();

for (int i = 0; i < numberOfSameKeys; i++) {
threads[i] = new Thread(() -> {
try {
LoadAwareCacheLoader<ICacheKey<String>, String> loadAwareCacheLoader = new LoadAwareCacheLoader<>() {
boolean isLoaded = false;

@Override
public boolean isLoaded() {
return isLoaded;
}

@Override
public String load(ICacheKey<String> key) {
return null;
}
};
loadAwareCacheLoaderList.add(loadAwareCacheLoader);
phaser.arriveAndAwaitAdvance();
tieredSpilloverCache.computeIfAbsent(key, loadAwareCacheLoader);
} catch (Exception e) {
exceptionCount.incrementAndGet();
assertEquals(ExecutionException.class, e.getClass());
assertEquals(NullPointerException.class, e.getCause().getClass());
assertEquals("Loader returned a null value", e.getCause().getMessage());
} finally {
countDownLatch.countDown();
}
});
threads[i].start();
}
phaser.arriveAndAwaitAdvance();
countDownLatch.await(); // Wait for rest of tasks to be cancelled.

// Verify exception count was equal to number of requests
assertEquals(numberOfSameKeys, exceptionCount.get());
assertEquals(0, tieredSpilloverCache.completableFutureMap.size());
}

public void testConcurrencyForEvictionFlowFromOnHeapToDiskTier() throws Exception {
Expand Down
Loading