Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Apr 22, 2024
1 parent 197b3f1 commit 1c1a596
Show file tree
Hide file tree
Showing 61 changed files with 185 additions and 566 deletions.
5 changes: 1 addition & 4 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1015,9 +1015,7 @@ function(onnxruntime_set_compile_flags target_name)
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang")
#external/protobuf/src/google/protobuf/arena.h:445:18: error: unused parameter 'p'
target_compile_options(${target_name} PRIVATE "-Wno-unused-parameter")
endif()
target_compile_definitions(${target_name} PUBLIC -DNSYNC_ATOMIC_CPP11)
onnxruntime_add_include_to_target(${target_name} nsync::nsync_cpp)
endif()
endif()
foreach(ORT_FLAG ${ORT_PROVIDER_FLAGS})
target_compile_definitions(${target_name} PRIVATE ${ORT_FLAG})
Expand Down Expand Up @@ -1640,7 +1638,6 @@ if (WIN32)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES advapi32)
endif()
else()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync::nsync_cpp)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${ICONV_LIB} ${CMAKE_DL_LIBS} Threads::Threads)
endif()

Expand Down
1 change: 0 additions & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
Expand Down
18 changes: 0 additions & 18 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,6 @@ if (onnxruntime_BUILD_BENCHMARKS)
)
endif()

if (NOT WIN32)
FetchContent_Declare(
google_nsync
URL ${DEP_URL_google_nsync}
URL_HASH SHA1=${DEP_SHA1_google_nsync}
FIND_PACKAGE_ARGS NAMES nsync
)
endif()
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external)

FetchContent_Declare(
Expand Down Expand Up @@ -340,16 +332,6 @@ if (onnxruntime_BUILD_BENCHMARKS)
onnxruntime_fetchcontent_makeavailable(google_benchmark)
endif()

if (NOT WIN32)
#nsync tests failed on Mac Build
set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
onnxruntime_fetchcontent_makeavailable(google_nsync)
if (google_nsync_SOURCE_DIR)
add_library(nsync::nsync_cpp ALIAS nsync_cpp)
target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public)
endif()
endif()

if(onnxruntime_USE_CUDA)
FetchContent_Declare(
GSL
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo)
endif()
if(NOT WIN32)
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${CMAKE_DL_LIBS})
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs})
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_cann.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)

add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync::nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED})
target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED})
target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64)
target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include)

Expand Down
2 changes: 0 additions & 2 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,8 @@

if(APPLE)
set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/cuda/exported_symbols.lst")
target_link_libraries(${target} PRIVATE nsync::nsync_cpp)
elseif(UNIX)
set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/cuda/version_script.lds -Xlinker --gc-sections")
target_link_libraries(${target} PRIVATE nsync::nsync_cpp)
elseif(WIN32)
set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/cuda/symbols.def")
else()
Expand Down
2 changes: 0 additions & 2 deletions cmake/onnxruntime_providers_dnnl.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@
INSTALL_RPATH "@loader_path"
BUILD_WITH_INSTALL_RPATH TRUE
INSTALL_RPATH_USE_LINK_PATH FALSE)
target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp)
elseif(UNIX)
set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/dnnl/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\$ORIGIN")
target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp)
elseif(WIN32)
set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/dnnl/symbols.def")
else()
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_migraphx.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs)
target_link_libraries(onnxruntime_providers_migraphx PRIVATE stdc++fs)

include(CheckLibraryExists)
check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC)
Expand Down
1 change: 0 additions & 1 deletion cmake/onnxruntime_providers_rocm.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@

if(UNIX)
set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync::nsync_cpp)
else()
message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it")
endif()
Expand Down
3 changes: 1 addition & 2 deletions cmake/onnxruntime_providers_tensorrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,10 @@

if(APPLE)
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst")
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp)
elseif(UNIX)
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp stdc++fs)
target_link_libraries(onnxruntime_providers_tensorrt PRIVATE stdc++fs)
elseif(WIN32)
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/tensorrt/symbols.def")
else()
Expand Down
13 changes: 3 additions & 10 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,6 @@ if(MSVC)
else()
target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11)
target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
onnxruntime_add_include_to_target(onnxruntime_test_utils nsync::nsync_cpp)
endif()
if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS})
Expand Down Expand Up @@ -749,7 +748,6 @@ if(NOT IOS)
else()
target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11)
target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp)
endif()
if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
#TODO: fix the warnings, they are dangerous
Expand Down Expand Up @@ -1127,7 +1125,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
# "Global initializer calls a non-constexpr function." BENCHMARK_CAPTURE macro needs this.
target_compile_options(onnxruntime_mlas_benchmark PRIVATE /wd26426)
else()
target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
target_link_libraries(onnxruntime_mlas_benchmark PRIVATE ${CMAKE_DL_LIBS})
endif()
if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
target_link_libraries(onnxruntime_mlas_benchmark PRIVATE cpuinfo)
Expand Down Expand Up @@ -1200,7 +1198,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
${onnxruntime_EXTERNAL_LIBRARIES}
${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
if(NOT WIN32)
list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp)
if(onnxruntime_USE_SNPE)
list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe)
endif()
Expand Down Expand Up @@ -1238,7 +1235,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
# test inference using shared lib
set(onnxruntime_shared_lib_test_LIBS onnxruntime_mocked_allocator onnxruntime_test_utils onnxruntime_common onnx_proto)
if(NOT WIN32)
list(APPEND onnxruntime_shared_lib_test_LIBS nsync::nsync_cpp)
if(onnxruntime_USE_SNPE)
list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_providers_snpe)
endif()
Expand Down Expand Up @@ -1383,7 +1379,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo)
endif()
if(NOT WIN32)
target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
target_link_libraries(onnxruntime_mlas_test PRIVATE ${CMAKE_DL_LIBS})
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs})
Expand Down Expand Up @@ -1556,9 +1552,6 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
${ONNXRUNTIME_CUSTOM_OP_REGISTRATION_TEST_SRC_DIR}/test_registercustomops.cc)

set(onnxruntime_customopregistration_test_LIBS custom_op_library onnxruntime_common onnxruntime_test_utils)
if (NOT WIN32)
list(APPEND onnxruntime_customopregistration_test_LIBS nsync::nsync_cpp)
endif()
if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
list(APPEND onnxruntime_customopregistration_test_LIBS cpuinfo)
endif()
Expand Down Expand Up @@ -1683,7 +1676,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten"
set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils)

if(NOT WIN32)
list(APPEND onnxruntime_logging_apis_test_LIBS nsync::nsync_cpp ${CMAKE_DL_LIBS})
list(APPEND onnxruntime_logging_apis_test_LIBS ${CMAKE_DL_LIBS})
endif()

AddTest(DYN
Expand Down
4 changes: 2 additions & 2 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ target_compile_options(onnx PRIVATE -Wno-unused-parameter -Wno-unused-variable)

if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
bundle_static_library(onnxruntime_webassembly
nsync::nsync_cpp

${PROTOBUF_LIB}
onnx
onnx_proto
Expand Down Expand Up @@ -174,7 +174,7 @@ else()
endif()

target_link_libraries(onnxruntime_webassembly PRIVATE
nsync::nsync_cpp

${PROTOBUF_LIB}
onnx
onnx_proto
Expand Down
69 changes: 4 additions & 65 deletions include/onnxruntime/core/platform/Barrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,76 +6,15 @@
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#pragma once
#include <assert.h>

#include "core/common/spin_pause.h"
#include "core/platform/ort_mutex.h"

#include <mutex>
#include <atomic>
#include <absl/synchronization/barrier.h>

Check warning on line 14 in include/onnxruntime/core/platform/Barrier.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: Barrier.h, c system, c++ system, other. [build/include_order] [4] Raw Output: include/onnxruntime/core/platform/Barrier.h:14: Found C system header after other header. Should be: Barrier.h, c system, c++ system, other. [build/include_order] [4]
#include <absl/synchronization/notification.h>

Check warning on line 15 in include/onnxruntime/core/platform/Barrier.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: Barrier.h, c system, c++ system, other. [build/include_order] [4] Raw Output: include/onnxruntime/core/platform/Barrier.h:15: Found C system header after other header. Should be: Barrier.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime {
class Barrier {
public:
explicit Barrier(unsigned int count, bool spin = false)
: state_(count << 1), notified_(false), spin_(spin) {
assert(((count << 1) >> 1) == count);
}
#ifdef NDEBUG
~Barrier() = default;
#else
~Barrier() {
assert((state_ >> 1) == 0);
}
#endif

void Notify(unsigned int c = 1) {
unsigned int delta = c << 1;
unsigned int v = state_.fetch_sub(delta, std::memory_order_acq_rel) - delta;
if (v != 1) {
// Clear the lowest bit (waiter flag) and check that the original state
// value was not zero. If it was zero, it means that notify was called
// more times than the original count.
assert(((v + delta) & ~1) != 0);
return; // either count has not dropped to 0, or waiter is not waiting
}
std::unique_lock<OrtMutex> l(mu_);
assert(!notified_);
notified_ = true;
cv_.notify_all();
}
using Notification = absl::Notification;

void Wait() {
if (spin_) {
while ((state_ >> 1) != 0) {
onnxruntime::concurrency::SpinPause();
}
} else {
unsigned int v = state_.fetch_or(1, std::memory_order_acq_rel);
if ((v >> 1) == 0)
return;
std::unique_lock<OrtMutex> l(mu_);
while (!notified_) {
cv_.wait(l);
}
}
}

private:
OrtMutex mu_;
OrtCondVar cv_;
std::atomic<unsigned int> state_; // low bit is waiter flag
bool notified_;
const bool spin_;
};

// Notification is an object that allows a user to to wait for another
// thread to signal a notification that an event has occurred.
//
// Multiple threads can wait on the same Notification object,
// but only one caller must call Notify() on the object.
struct Notification : Barrier {
Notification() : Barrier(1){};
};
} // namespace onnxruntime
40 changes: 23 additions & 17 deletions include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ class RunQueue {
#ifdef USE_LOCK_FREE_QUEUE
std::lock_guard<OrtSpinLock> mtx(spin_lock_);
#else
std::lock_guard<OrtMutex> lock(mutex_);
absl::MutexLock lock(&mutex_);
#endif
unsigned back = back_.load(std::memory_order_relaxed);
Elem& e = array_[(back - 1) & kMask];
Expand All @@ -484,7 +484,7 @@ class RunQueue {
#ifdef USE_LOCK_FREE_QUEUE
std::lock_guard<OrtSpinLock> mtx(spin_lock_);
#else
std::lock_guard<OrtMutex> lock(mutex_);
absl::MutexLock lock(&mutex_);
#endif
unsigned back = back_.load(std::memory_order_relaxed);
w_idx = (back - 1) & kMask;
Expand All @@ -509,7 +509,7 @@ class RunQueue {
#ifdef USE_LOCK_FREE_QUEUE
std::lock_guard<OrtSpinLock> mtx(spin_lock_);
#else
std::lock_guard<OrtMutex> lock(mutex_);
absl::MutexLock lock(&mutex_);
#endif
unsigned back;
Elem* e;
Expand Down Expand Up @@ -555,7 +555,7 @@ class RunQueue {
#ifdef USE_LOCK_FREE_QUEUE
std::lock_guard<OrtSpinLock> mtx(spin_lock_);
#else
std::lock_guard<OrtMutex> lock(mutex_);
absl::MutexLock lock(&mutex_);
#endif
Elem& e = array_[w_idx];
ElemState s = e.state.load(std::memory_order_relaxed);
Expand Down Expand Up @@ -1440,17 +1440,22 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
ThreadStatus seen = GetStatus();
if (seen == ThreadStatus::Blocking ||
seen == ThreadStatus::Blocked) {
std::unique_lock<OrtMutex> lk(mutex);
// Blocking state exists only transiently during the SetBlock() method
// while holding the lock. We may observe it at the start of this
// function, but after acquiring the lock then the target thread
// will either be blocked or not.
seen = status.load(std::memory_order_relaxed);
assert(seen != ThreadStatus::Blocking);
if (seen == ThreadStatus::Blocked) {
status.store(ThreadStatus::Waking, std::memory_order_relaxed);
lk.unlock();
cv.notify_one();
bool sig = false;
{
absl::MutexLock lk(&mutex);
// Blocking state exists only transiently during the SetBlock() method
// while holding the lock. We may observe it at the start of this
// function, but after acquiring the lock then the target thread
// will either be blocked or not.
seen = status.load(std::memory_order_relaxed);
assert(seen != ThreadStatus::Blocking);
if (seen == ThreadStatus::Blocked) {
status.store(ThreadStatus::Waking, std::memory_order_relaxed);
sig = true;
}
}
if (sig) {
cv.Signal();
}
}
}
Expand All @@ -1470,17 +1475,18 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter

void SetBlocked(std::function<bool()> should_block,
std::function<void()> post_block) {
std::unique_lock<OrtMutex> lk(mutex);
mutex.Lock();
assert(GetStatus() == ThreadStatus::Spinning);
status.store(ThreadStatus::Blocking, std::memory_order_relaxed);
if (should_block()) {
status.store(ThreadStatus::Blocked, std::memory_order_relaxed);
do {
cv.wait(lk);
cv.Wait(&mutex);
} while (status.load(std::memory_order_relaxed) == ThreadStatus::Blocked);
post_block();
}
status.store(ThreadStatus::Spinning, std::memory_order_relaxed);
mutex.Unlock();
}

private:
Expand Down
Loading

0 comments on commit 1c1a596

Please sign in to comment.