diff --git a/thread/test/test.cpp b/thread/test/test.cpp index 4fd2a171..2b34b57d 100644 --- a/thread/test/test.cpp +++ b/thread/test/test.cpp @@ -1890,6 +1890,101 @@ TEST(intrusive_list, split) { } +TEST(interrupt, mutex) { + photon::mutex mtx(0); + // lock first + mtx.lock(); + auto th = photon::CURRENT; + int reason = rand(); + while (reason == 0) reason = rand(); + photon::thread_create11([th, reason]() { + // any errno except 0 is able to stop waiting + photon::thread_interrupt(th, reason); + }); + // this time will goto sleep + auto ret = mtx.lock(); + ERRNO err; + EXPECT_EQ(-1, ret); + EXPECT_EQ(reason, err.no); + mtx.unlock(); +} + +TEST(interrupt, condition_variable) { + photon::condition_variable cond; + auto th = photon::CURRENT; + int reason = rand(); + while (reason == 0) reason = rand(); + photon::thread_create11([th, reason]() { + // any errno except 0 is able to stop waiting + photon::thread_interrupt(th, reason); + }); + auto ret = cond.wait_no_lock(); + ERRNO err; + EXPECT_EQ(-1, ret); + EXPECT_EQ(reason, err.no); +} + +TEST(interrupt, semaphore) { + photon::semaphore sem(0); + auto th = photon::CURRENT; + int reason = rand(); + while (reason == 0) reason = rand(); + photon::thread_create11([th, reason]() { + // any errno except 0 is able to stop waiting + photon::thread_interrupt(th, reason); + }); + auto ret = sem.wait_interruptible(1); // nobody + ERRNO err; + EXPECT_EQ(-1, ret); + EXPECT_EQ(reason, err.no); +} + + +TEST(condition_variable, pred) { + photon::condition_variable cond; + int flag = 0; + photon::thread_create11([&cond, &flag]() { + // any errno except 0 is able to stop waiting + flag = 1; + cond.notify_one(); + // first notify should not wake up condition variable + photon::thread_usleep(1000 * 1000); + flag = 2; + cond.notify_one(); + + }); + auto ret = cond.wait_no_lock([&flag](){ return flag == 2;}); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, flag); + ret = cond.wait_no_lock([&flag](){ return flag == 3; }, 1000); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); + flag = 0; + photon::mutex mtx; + SCOPED_LOCK(mtx); + photon::thread_create11([&cond, &flag, &mtx]() { + // any errno except 0 is able to stop waiting + { + SCOPED_LOCK(mtx); + flag = 1; + cond.notify_one(); + } + // first notify should not wake up condition variable + photon::thread_usleep(1000 * 1000); + { + SCOPED_LOCK(mtx); + flag = 2; + cond.notify_one(); + } + }); + ret = cond.wait(mtx, [&flag](){ return flag == 2;}); + EXPECT_EQ(0, ret); + EXPECT_EQ(2, flag); + ret = cond.wait(mtx, [&flag](){ return flag == 3; }, 1000); + EXPECT_EQ(-1, ret); + EXPECT_EQ(ETIMEDOUT, errno); +} + int main(int argc, char** arg) { ::testing::InitGoogleTest(&argc, arg); diff --git a/thread/thread.cpp b/thread/thread.cpp index 964929a2..33163c89 100644 --- a/thread/thread.cpp +++ b/thread/thread.cpp @@ -1494,7 +1494,7 @@ R"( *perrno = ETIMEDOUT; return -1; } - return (*perrno == ECANCELED) ? 0 : -1; + return (*perrno == -1) ? 0 : -1; } int waitq::wait(Timeout timeout) { @@ -1591,7 +1591,7 @@ R"( ScopedLockHead h(m); m->owner.store(h); if (h) - prelocked_thread_interrupt(h, ECANCELED); + prelocked_thread_interrupt(h, -1); } static void mutex_unlock(void* m_) { @@ -1672,7 +1672,7 @@ R"( { return cvar_do_wait((thread_list*)&q, m, timeout, spinlock_lock, spinlock_unlock); } - int semaphore::wait(uint64_t count, Timeout timeout) + int semaphore::wait_interruptible(uint64_t count, Timeout timeout) { if (count == 0) return 0; splock.lock(); @@ -1680,11 +1680,14 @@ R"( int ret = 0; while (!try_substract(count)) { ret = waitq::wait_defer(timeout, spinlock_unlock, &splock); + ERRNO err; splock.lock(); - if (ret < 0 && errno == ETIMEDOUT) { + if (ret < 0) { CURRENT->semaphore_count = 0; - try_resume(); // when timeout, we need to try - splock.unlock(); // to resume next thread(s) in q + // when timeout, we need to try to resume next thread(s) in q + if (err.no == ETIMEDOUT) try_resume(); + splock.unlock(); + errno = err.no; return ret; } } @@ -1704,7 +1707,7 @@ R"( if (qfcount > cnt) break; cnt -= qfcount; qfcount = 0; - prelocked_thread_interrupt(th, ECANCELED); + prelocked_thread_interrupt(th, -1); } } bool semaphore::try_substract(uint64_t count) diff --git a/thread/thread.h b/thread/thread.h index 0e99fc64..ec012a9d 100644 --- a/thread/thread.h +++ b/thread/thread.h @@ -212,9 +212,9 @@ namespace photon protected: int wait(Timeout timeout = {}); int wait_defer(Timeout Timeout, void(*defer)(void*), void* arg); - void resume(thread* th, int error_number = ECANCELED); // `th` must be waiting in this waitq! - int resume_all(int error_number = ECANCELED); - thread* resume_one(int error_number = ECANCELED); + void resume(thread* th, int error_number = -1); // `th` must be waiting in this waitq! + int resume_all(int error_number = -1); + thread* resume_one(int error_number = -1); waitq() = default; waitq(const waitq& rhs) = delete; // not allowed to copy construct waitq(waitq&& rhs) = delete; @@ -362,17 +362,50 @@ namespace photon { return waitq::wait(timeout); } + template ()())> + int wait(LOCK&& lock, PRED&& pred, Timeout timeout = {}) { + return do_wait_pred( + [&] { return wait(std::forward(lock), timeout); }, + std::forward(pred), timeout); + } + template ()())> + int wait_no_lock(PRED&& pred, Timeout timeout = {}) { + return do_wait_pred( + [&] { return wait_no_lock(timeout); }, + std::forward(pred), timeout); + } thread* signal() { return resume_one(); } thread* notify_one() { return resume_one(); } int notify_all() { return resume_all(); } int broadcast() { return resume_all(); } + protected: + template + int do_wait_pred(DO_WAIT&& do_wait, PRED&& pred, Timeout timeout) { + int ret = 0; + int err = ETIMEDOUT; + while (!pred() && !timeout.expired()) { + ret = do_wait(); + err = errno; + } + errno = err; + return ret; + } }; class semaphore : protected waitq { public: explicit semaphore(uint64_t count = 0) : m_count(count) { } - int wait(uint64_t count, Timeout timeout = {}); + int wait(uint64_t count, Timeout timeout = {}) { + int ret = 0; + do { + ret = wait_interruptible(count, timeout); + } while (ret < 0 && (errno != ESHUTDOWN && errno != ETIMEDOUT)); + return ret; + } + int wait_interruptible(uint64_t count, Timeout timeout = {}); int signal(uint64_t count) { if (count == 0) return 0;