Skip to content

Commit

Permalink
get or create for lru cache
Browse files Browse the repository at this point in the history
Signed-off-by: dentiny <[email protected]>
  • Loading branch information
dentiny committed Feb 8, 2025
1 parent 9cbff17 commit 839a964
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 8 deletions.
2 changes: 1 addition & 1 deletion release/benchmarks/distributed/test_many_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_max_running_tasks(num_tasks):
cpus_per_task = 0.25

@ray.remote(num_cpus=cpus_per_task)
@ray.remote(num_cpus=cpus_per_task, runtime_env={"env_vars": {"FOO": "bar"}})
def task():
time.sleep(sleep_time)

Expand Down
17 changes: 10 additions & 7 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2308,19 +2308,22 @@ json CoreWorker::OverrideRuntimeEnv(const json &child,
return result_runtime_env;
}

// Resolves the runtime env info for `serialized_runtime_env_info` through
// `runtime_env_json_serialization_cache_`: the cache's `GetOrCreate` is handed a factory
// that forwards the key to `OverrideTaskOrActorRuntimeEnvInfoImpl`.
std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvInfo(
    const std::string &serialized_runtime_env_info) const {
  return runtime_env_json_serialization_cache_.GetOrCreate(
      serialized_runtime_env_info, [this](const std::string &cache_key) {
        return OverrideTaskOrActorRuntimeEnvInfoImpl(cache_key);
      });
}

// TODO(hjiang): The current implementation is not ideal, since it acquires a global
// lock for all operations; this is acceptable for now because no heavyweight operation
// is involved (the overall scheduling overhead is on the order of single-digit
// milliseconds). A better solution is for the LRU cache to natively support sharding
// and a `GetOrCreate` API.
std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvInfo(
std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvInfoImpl(
const std::string &serialized_runtime_env_info) const {
if (auto cached_runtime_env_info =
runtime_env_json_serialization_cache_.Get(serialized_runtime_env_info);
cached_runtime_env_info != nullptr) {
return cached_runtime_env_info;
}

// TODO(Catch-Bull,SongGuyang): task runtime env does not support the field
// eager_install yet; we will overwrite the field eager_install once it does.
std::shared_ptr<json> parent = nullptr;
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,9 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
std::shared_ptr<rpc::RuntimeEnvInfo> OverrideTaskOrActorRuntimeEnvInfo(
const std::string &serialized_runtime_env_info) const;

// Invoked by `OverrideTaskOrActorRuntimeEnvInfo` as the factory it passes to the
// runtime env cache's `GetOrCreate`, with the serialized runtime env info as the key.
std::shared_ptr<rpc::RuntimeEnvInfo> OverrideTaskOrActorRuntimeEnvInfoImpl(
const std::string &serialized_runtime_env_info) const;

void BuildCommonTaskSpec(
TaskSpecBuilder &builder,
const JobID &job_id,
Expand Down
75 changes: 75 additions & 0 deletions src/ray/util/shared_lru.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#pragma once

#include <condition_variable>
#include <cstdint>
#include <list>
#include <memory>
Expand Down Expand Up @@ -194,6 +195,69 @@ class ThreadSafeSharedLruCache final {
return cache_.Get(std::forward<KeyLike>(key));
}

// Returns the value cached under [key], creating it with [factory] on a miss.
//
// When several threads miss on the same key concurrently, only the first one runs
// [factory]; the others block on the in-flight creation and receive the same value.
// [factory] runs outside of the critical section.
//
// WARNING: [factory] must not throw and must not return nullptr. Waiters block until
// the published value is non-null, and nothing wakes them if [factory] exits via an
// exception, so either case leaves them blocked.
// TODO(hjiang): [factory] should support template.
template <typename KeyLike>
std::shared_ptr<Val> GetOrCreate(KeyLike &&key,
                                 std::function<std::shared_ptr<Val>(Key)> factory) {
  std::shared_ptr<CreationToken> pending;

  {
    std::unique_lock guard(mu_);

    // Fast path: the value is already cached.
    if (auto hit = cache_.Get(key); hit != nullptr) {
      return hit;
    }

    // Another thread is already creating this key: register as a waiter and block
    // until the creator publishes the value.
    if (auto in_flight = ongoing_creation_.find(key);
        in_flight != ongoing_creation_.end()) {
      pending = in_flight->second;
      pending->count += 1;
      while (pending->val == nullptr) {
        pending->cv.wait(guard);
      }

      std::shared_ptr<Val> produced = pending->val;
      pending->count -= 1;
      if (pending->count == 0) {
        // Erase by key rather than through the iterator found above, which could
        // have been invalidated by insertions/deletions while waiting.
        ongoing_creation_.erase(key);
      }
      return produced;
    }

    // This thread is the first to ask for the key: register the in-flight creation
    // so that later callers wait instead of invoking the factory again.
    pending = std::make_shared<CreationToken>();
    pending->count = 1;
    ongoing_creation_.emplace(key, pending);
  }

  // Run the factory without holding the lock.
  std::shared_ptr<Val> created = factory(key);

  {
    std::lock_guard guard(mu_);
    cache_.Put(key, created);
    pending->val = created;
    pending->cv.notify_all();
    pending->count -= 1;
    if (pending->count == 0) {
      // No waiter is left; otherwise the last waiter removes the entry.
      ongoing_creation_.erase(key);
    }
  }

  return created;
}

// Clear the cache.
void Clear() {
std::lock_guard lck(mu_);
Expand All @@ -204,8 +268,19 @@ class ThreadSafeSharedLruCache final {
// Returns `cache_.max_entries()`. Note: this read does not take `mu_`.
size_t max_entries() const { return cache_.max_entries(); }

private:
// Bookkeeping for one in-flight `GetOrCreate` on a single key, shared between the
// creating thread and the threads waiting for its result.
struct CreationToken {
  // Waiters block on this until [val] has been published.
  std::condition_variable cv;
  // Created value; stays nullptr until the creation has finished.
  std::shared_ptr<Val> val;
  // Number of threads (creator and waiters) still referring to this creation.
  int count{0};
};

std::mutex mu_;
SharedLruCache<Key, Val> cache_;

// In-flight creations keyed by cache key; accessed by `GetOrCreate` while holding `mu_`.
absl::flat_hash_map<Key, std::shared_ptr<CreationToken>> ongoing_creation_;
};

// Same interfaces as `SharedLruCache`, but all cached values are
Expand Down
37 changes: 37 additions & 0 deletions src/ray/util/tests/shared_lru_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <gtest/gtest.h>

#include <future>
#include <string>
#include <thread>
#include <type_traits>

namespace ray::utils::container {
Expand Down Expand Up @@ -80,6 +82,41 @@ TEST(SharedLruCache, SameKeyTest) {
EXPECT_EQ(2, *val);
}

// Checks that concurrent `GetOrCreate` calls on a single key run the factory only once
// and that every caller gets the created value.
TEST(SharedLruCache, FactoryTest) {
  using StringCache = ThreadSafeSharedLruCache<std::string, std::string>;
  using ValuePtr = std::shared_ptr<std::string>;

  // Set by the first factory call; any further call trips the expectation below.
  std::atomic<bool> factory_ran{false};
  auto make_value = [&factory_ran](const std::string &k) -> ValuePtr {
    const bool already_ran = factory_ran.exchange(true);
    EXPECT_FALSE(already_ran);
    // Keep the factory busy for a while so the other callers can arrive and block.
    std::this_thread::sleep_for(std::chrono::seconds(3));
    return std::make_shared<std::string>(k);
  };

  StringCache lru{1};
  const std::string key = "key";

  constexpr size_t kCallerCount = 100;
  std::vector<std::future<ValuePtr>> pending;
  pending.reserve(kCallerCount);
  for (size_t i = 0; i < kCallerCount; ++i) {
    pending.push_back(std::async(std::launch::async, [&lru, &key, &make_value]() {
      return lru.GetOrCreate(key, make_value);
    }));
  }
  for (auto &result : pending) {
    const ValuePtr got = result.get();
    ASSERT_NE(got, nullptr);
    ASSERT_EQ(*got, key);
  }

  // The key-value pair is known to be cached now; one more call must not reach the
  // factory again (its expectation would fail otherwise).
  const ValuePtr again = lru.GetOrCreate(key, make_value);
  ASSERT_NE(again, nullptr);
  ASSERT_EQ(*again, key);
}

TEST(SharedLruConstCache, TypeAliasAssertion) {
static_assert(
std::is_same_v<SharedLruConstCache<int, int>, SharedLruCache<int, const int>>);
Expand Down

0 comments on commit 839a964

Please sign in to comment.