
[core] Implement get or create for lru cache #50347

Merged 4 commits on Feb 10, 2025
18 changes: 8 additions & 10 deletions src/ray/core_worker/core_worker.cc
@@ -2308,19 +2308,17 @@ json CoreWorker::OverrideRuntimeEnv(const json &child,
return result_runtime_env;
}

// TODO(hjiang): The current implementation is not ideal, since it acquires a
// global lock for all operations; this is acceptable for now since no
// heavyweight operation is involved (the overall scheduling overhead is of
// single-digit millisecond magnitude). A better solution is for the LRU cache
// to natively support sharding and a `GetOrCreate` API.
std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvInfo(
const std::string &serialized_runtime_env_info) const {
if (auto cached_runtime_env_info =
runtime_env_json_serialization_cache_.Get(serialized_runtime_env_info);
cached_runtime_env_info != nullptr) {
return cached_runtime_env_info;
}
auto factory = [this](const std::string &serialized_runtime_env_info) {
return OverrideTaskOrActorRuntimeEnvInfoImpl(serialized_runtime_env_info);
};
return runtime_env_json_serialization_cache_.GetOrCreate(serialized_runtime_env_info,
std::move(factory));
}

std::shared_ptr<rpc::RuntimeEnvInfo> CoreWorker::OverrideTaskOrActorRuntimeEnvInfoImpl(
const std::string &serialized_runtime_env_info) const {
// TODO(Catch-Bull,SongGuyang): task runtime env does not support the field
// eager_install yet; we will overwrite the field eager_install when it does.
std::shared_ptr<json> parent = nullptr;
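For context, the calling pattern this change adopts can be shown with a minimal, self-contained sketch. This `SimpleCache` is purely illustrative, not Ray's `SharedLruCache`: it has no LRU eviction and, unlike the PR's implementation, it holds the lock while the factory runs.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Illustrative sketch only: a map-based cache exposing a GetOrCreate that
// takes a factory, mirroring the API shape this PR introduces.
template <typename Key, typename Val>
class SimpleCache {
 public:
  std::shared_ptr<Val> GetOrCreate(
      const Key &key, std::function<std::shared_ptr<Val>(const Key &)> factory) {
    std::lock_guard<std::mutex> lck(mu_);
    auto iter = map_.find(key);
    if (iter != map_.end()) {
      return iter->second;  // Cache hit: the factory is not invoked.
    }
    auto val = factory(key);  // Cache miss: create, insert, and return.
    map_.emplace(key, val);
    return val;
  }

 private:
  std::mutex mu_;
  std::unordered_map<Key, std::shared_ptr<Val>> map_;
};
```

With `GetOrCreate`, the caller no longer needs the racy check-then-create-then-insert sequence that `OverrideTaskOrActorRuntimeEnvInfo` previously performed by hand.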
5 changes: 5 additions & 0 deletions src/ray/core_worker/core_worker.h
@@ -1384,6 +1384,11 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
std::shared_ptr<rpc::RuntimeEnvInfo> OverrideTaskOrActorRuntimeEnvInfo(
const std::string &serialized_runtime_env_info) const;

// Used as the factory function for [OverrideTaskOrActorRuntimeEnvInfo] to
// create values in the LRU cache.
std::shared_ptr<rpc::RuntimeEnvInfo> OverrideTaskOrActorRuntimeEnvInfoImpl(
const std::string &serialized_runtime_env_info) const;

void BuildCommonTaskSpec(
TaskSpecBuilder &builder,
const JobID &job_id,
77 changes: 75 additions & 2 deletions src/ray/util/shared_lru.h
@@ -27,11 +27,11 @@
// // Check and consume `val`.
//
// TODO(hjiang):
// 1. Add a `GetOrCreate` interface, which takes factory function to creation value.
// 2. For thread-safe cache, add a sharded container wrapper to reduce lock contention.
// For thread-safe cache, add a sharded container wrapper to reduce lock contention.

#pragma once

#include <condition_variable>
#include <cstdint>
#include <list>
#include <memory>
@@ -194,6 +194,68 @@ class ThreadSafeSharedLruCache final {
return cache_.Get(std::forward<KeyLike>(key));
}

// Get or create a cached key-value pair.
//
// WARNING: Currently the factory must not throw exceptions.
// TODO(hjiang): [factory] should be a template parameter.
template <typename KeyLike>
std::shared_ptr<Val> GetOrCreate(
KeyLike &&key, std::function<std::shared_ptr<Val>(const Key &)> factory) {
Contributor: The factory function should be taken as a const lvalue, never moved.

std::shared_ptr<CreationToken> creation_token;

{
std::unique_lock lck(mu_);
auto cached_val = cache_.Get(key);
if (cached_val != nullptr) {
return cached_val;
}

auto creation_iter = ongoing_creation_.find(key);

// Another thread has requested the same key-value pair; simply wait for its
// completion.
if (creation_iter != ongoing_creation_.end()) {
creation_token = creation_iter->second;
++creation_token->count;
creation_token->cv.wait(lck, [creation_token = creation_token.get()]() {
return creation_token->val != nullptr;
});

// Creation finished.
--creation_token->count;
if (creation_token->count == 0) {
// [creation_iter] could be invalidated here due to new insertion/deletion.
ongoing_creation_.erase(key);
}
return creation_token->val;
}

// The current thread is the first to request this key-value pair; run the
// factory function.
creation_iter =
ongoing_creation_.emplace(key, std::make_shared<CreationToken>()).first;
creation_token = creation_iter->second;
creation_token->count = 1;
}

// Run the factory outside the critical section.
std::shared_ptr<Val> val = factory(key);

{
std::lock_guard lck(mu_);
cache_.Put(key, val);
Contributor: Forward the key here.

Contributor Author (dentiny): No, you cannot, because key is accessed later; if the passed-in key is an rvalue we would get invalidated access.

Contributor: Oh, gotcha, with the erase. Why not reorder so that this happens after the `if (new_count == 0) erase(key)`? The whole block is mutex-covered anyway.

Contributor Author (dentiny, Feb 10, 2025): Yes we can, but I usually don't do move/forward operations here, because they're error-prone. Suppose we add more logic to this code block in the future: it's easy to forget to move the forward semantics around. I would like to defer until we have clang-tidy integrated in our CI, which reports invalid usage like use-after-move (I heavily rely on clang-tidy to detect inefficiency and illegal access).

Contributor: As long as the move/forward is the last thing that happens it should be OK, and the person changing it afterwards should be aware; a lot of our code can break if people are not aware when adding more logic. But fine either way, micro-opt.

Contributor Author (dentiny): "and person changing after should be aware, a lot of our code can break if ppl are not aware when adding more logic" -- sigh, I have forgotten about it a few times myself, and I heavily rely on static analysis.


creation_token->val = val;
creation_token->cv.notify_all();
Contributor: Use a unique_lock to lock, and then notify after unlocking at the end, so the waiter doesn't fail the initial lock acquisition.

Contributor Author (dentiny): No, I strongly discourage notifying without the lock held.

Contributor:
https://en.cppreference.com/w/cpp/thread/condition_variable
https://en.cppreference.com/w/cpp/thread/condition_variable/notify_one
https://en.cppreference.com/w/cpp/thread/condition_variable/notify_all

All three cppreference examples here unlock before notifying, and point 3 on the first link notes you can notify after unlock. See the paragraph in the notify_one doc, which describes the pros and cons of both. Here we're not in a situation where the wait succeeding would cause the cv to be invalidated. Quoting:

"The notifying thread does not need to hold the lock on the same mutex as the one held by the waiting thread(s); in fact doing so is a pessimization, since the notified thread would immediately block again, waiting for the notifying thread to release the lock. However, some implementations (in particular many implementations of pthreads) recognize this situation and avoid this 'hurry up and wait' scenario by transferring the waiting thread from the condition variable's queue directly to the queue of the mutex within the notify call, without waking it up.

Notifying while under the lock may nevertheless be necessary when precise scheduling of events is required, e.g. if the waiting thread would exit the program if the condition is satisfied, causing destruction of the notifying thread's condition variable. A spurious wakeup after mutex unlock but before notify would result in notify called on a destroyed object."

Contributor Author (dentiny): Thanks for the link; checking the code again. I think we can unlock right before cv notification, but you would need to reorder the code and unlock right before the last line; TBH I don't see too much value.
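The unlock-then-notify reorder debated here can be sketched generically. The `Ready`, `Produce`, and `Consume` names below are hypothetical illustrations of the standard pattern, not the PR's code: mutate shared state under the lock, release the lock, then notify, so a woken waiter does not immediately block on a mutex the notifier still holds.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>

struct Ready {
  std::mutex mu;
  std::condition_variable cv;
  bool done = false;
};

inline void Produce(Ready &r) {
  {
    std::lock_guard<std::mutex> lck(r.mu);
    r.done = true;  // Mutate shared state under the lock.
  }                   // Unlock first...
  r.cv.notify_all();  // ...then notify.
}

inline bool Consume(Ready &r) {
  std::unique_lock<std::mutex> lck(r.mu);
  // The predicate handles both spurious wakeups and the case where Produce
  // already ran before we started waiting.
  r.cv.wait(lck, [&r] { return r.done; });
  return r.done;
}
```

Either ordering is correct here since the waiter's predicate is re-checked under the lock; the unlock-first variant just avoids the "hurry up and wait" contention the cppreference quote describes.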

int new_count = --creation_token->count;
if (new_count == 0) {
// [creation_iter] could be invalidated here due to new insertion/deletion.
Contributor (suggested change: delete the line `// [creation_iter] could be invalidated here due to new insertion/deletion.`): Note not necessary? creation_iter is not even in scope here.

Contributor Author (dentiny): Are you sure? creation_iter is just a few lines above.

Contributor Author (dentiny): The comment here is used to explain why we cannot access the value via the iterator created earlier.

Contributor: Ya, it's created in the block with the unique_lock, not this one.

Contributor Author (dentiny): It's created at L213, and still valid here.


ongoing_creation_.erase(key);
}
}

return val;
}

// Clear the cache.
void Clear() {
std::lock_guard lck(mu_);
@@ -204,8 +266,19 @@ class ThreadSafeSharedLruCache final {
size_t max_entries() const { return cache_.max_entries(); }

private:
struct CreationToken {
std::condition_variable cv;
// Nullptr indicates the creation is unfinished.
std::shared_ptr<Val> val;
// Count of threads involved in the ongoing creation.
int count = 0;
};

std::mutex mu_;
SharedLruCache<Key, Val> cache_;

// Ongoing creations, keyed by cache key.
absl::flat_hash_map<Key, std::shared_ptr<CreationToken>> ongoing_creation_;
Contributor: Why does CreationToken have to be boxed in a shared_ptr?

Contributor Author (dentiny): What's your suggestion? I don't think there's another way.

Contributor Author (dentiny): You cannot use a raw value, since the token is not copyable; you cannot use a unique pointer, because you don't know which thread is accessing the last reference.

Contributor: I think it might be possible to use unique_ptr, since the last thread accessing is the one erasing, and you can copy out val before; but ya, not worth it, this is safer.

Contributor Author (dentiny): A unique pointer is also not possible; inside the token there are a cond var and mutex, which are not copyable.

};

// Same interfaces as `SharedLruCache`, but all cached values are
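The remaining TODO at the top of this file mentions a sharded container wrapper to reduce lock contention. A rough sketch of that idea follows; the `ShardedCache` name and structure are illustrative assumptions (no LRU eviction, no creation tokens), not a proposed Ray API.

```cpp
#include <array>
#include <cassert>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

// Hash each key onto one of N independently locked shards so unrelated keys
// do not contend on a single global mutex.
template <typename Key, typename Val, size_t kNumShards = 8>
class ShardedCache {
 public:
  std::shared_ptr<Val> GetOrCreate(
      const Key &key, std::function<std::shared_ptr<Val>(const Key &)> factory) {
    Shard &shard = shards_[std::hash<Key>{}(key) % kNumShards];
    std::lock_guard<std::mutex> lck(shard.mu);  // Only this shard is locked.
    auto iter = shard.map.find(key);
    if (iter != shard.map.end()) return iter->second;
    auto val = factory(key);
    shard.map.emplace(key, val);
    return val;
  }

 private:
  struct Shard {
    std::mutex mu;
    std::unordered_map<Key, std::shared_ptr<Val>> map;
  };
  std::array<Shard, kNumShards> shards_;
};
```

Two threads touching keys in different shards never block each other; the worst case degenerates to the single-mutex behavior only when keys hash to the same shard.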
37 changes: 37 additions & 0 deletions src/ray/util/tests/shared_lru_test.cc
@@ -16,7 +16,9 @@

#include <gtest/gtest.h>

#include <future>
#include <string>
#include <thread>
#include <type_traits>

namespace ray::utils::container {
@@ -80,6 +82,41 @@ TEST(SharedLruCache, SameKeyTest) {
EXPECT_EQ(2, *val);
}

TEST(SharedLruCache, FactoryTest) {
using CacheType = ThreadSafeSharedLruCache<std::string, std::string>;

std::atomic<bool> invoked = {false};  // Used to check the factory is only invoked once.
auto factory = [&invoked](const std::string &key) -> std::shared_ptr<std::string> {
EXPECT_FALSE(invoked.exchange(true));
// Sleep for a while so multiple threads can kick in and get blocked.
std::this_thread::sleep_for(std::chrono::seconds(3));
return std::make_shared<std::string>(key);
};

CacheType cache{1};

constexpr size_t kFutureNum = 100;
std::vector<std::future<std::shared_ptr<std::string>>> futures;
futures.reserve(kFutureNum);

const std::string key = "key";
for (size_t idx = 0; idx < kFutureNum; ++idx) {
futures.emplace_back(std::async(std::launch::async, [&cache, &key, &factory]() {
return cache.GetOrCreate(key, factory);
}));
}
for (auto &fut : futures) {
auto val = fut.get();
ASSERT_NE(val, nullptr);
ASSERT_EQ(*val, key);
}

// After we're sure key-value pair exists in cache, make one more call.
auto cached_val = cache.GetOrCreate(key, factory);
ASSERT_NE(cached_val, nullptr);
ASSERT_EQ(*cached_val, key);
}

TEST(SharedLruConstCache, TypeAliasAssertion) {
static_assert(
std::is_same_v<SharedLruConstCache<int, int>, SharedLruCache<int, const int>>);