From bb737b96098557091480adecaf8863e8087d34cc Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 1 Oct 2024 15:24:10 +0200 Subject: [PATCH] PushAndPopContext outside callback_device_memory() --- cpp/include/kvikio/remote_handle.hpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 1bcf983ed5..436d07c91d 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -90,8 +90,6 @@ inline std::size_t callback_device_memory(char* data, return CURL_WRITEFUNC_ERROR; } - CUcontext cuda_ctx = get_context_from_pointer(ctx->buf); - PushAndPopContext c(cuda_ctx); CUstream stream = detail::StreamsByThread::get(); CUDA_DRIVER_TRY(cudaAPI::instance().MemcpyHtoDAsync( convert_void2deviceptr(ctx->buf + ctx->offset), data, nbytes, stream)); @@ -225,15 +223,15 @@ class RemoteHandle { << " bytes file (" << _endpoint->str() << ")"; throw std::invalid_argument(ss.str()); } - - auto curl = create_curl_handle(); + const bool is_host_mem = is_host_memory(buf); + auto curl = create_curl_handle(); _endpoint->setopt(curl); std::string const byte_range = std::to_string(file_offset) + "-" + std::to_string(file_offset + size - 1); curl.setopt(CURLOPT_RANGE, byte_range.c_str()); - if (is_host_memory(buf)) { + if (is_host_mem) { curl.setopt(CURLOPT_WRITEFUNCTION, detail::callback_host_memory); } else { curl.setopt(CURLOPT_WRITEFUNCTION, detail::callback_device_memory); @@ -243,7 +241,12 @@ class RemoteHandle { curl.setopt(CURLOPT_WRITEDATA, &ctx); try { - curl.perform(); + if (is_host_mem) { + curl.perform(); + } else { + PushAndPopContext c(get_context_from_pointer(buf)); + curl.perform(); + } } catch (std::runtime_error const& e) { if (ctx.overflow_error) { std::stringstream ss;