diff --git a/include/ylt/coro_io/coro_file.hpp b/include/ylt/coro_io/coro_file.hpp index bddf79063..aa3342a90 100644 --- a/include/ylt/coro_io/coro_file.hpp +++ b/include/ylt/coro_io/coro_file.hpp @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -44,6 +43,11 @@ #include "coro_io.hpp" +#if defined(ASIO_WINDOWS) +#include +#include +#endif + namespace coro_io { /* @@ -94,15 +98,14 @@ enum flags { }; enum class read_type { + init, #if defined(YLT_ENABLE_FILE_IO_URING) uring, uring_random, #else fread, #endif -#if defined(__GNUC__) pread, -#endif }; class coro_file { @@ -124,7 +127,13 @@ class coro_file { : executor_wrapper_(executor) {} #endif - bool is_open() { return stream_file_ != nullptr || fd_file_ != nullptr; } + bool is_open() { + if (type_ == read_type::pread) { + return fd_file_ != nullptr; + } + + return stream_file_ != nullptr; + } void flush() { #if defined(YLT_ENABLE_FILE_IO_URING) @@ -158,75 +167,72 @@ class coro_file { return size; } -#if defined(__GNUC__) - bool open_fd(std::string_view filepath, int open_mode = flags::read_write) { - if (fd_file_) { - return true; - } - - int fd = open(filepath.data(), open_mode); - if (fd < 0) { - return false; + async_simple::coro::Lazy> async_pread( + size_t offset, char* data, size_t size) { + if (type_ != read_type::pread) { + co_return std::make_pair( + std::make_error_code(std::errc::bad_file_descriptor), 0); } - - fd_file_ = std::shared_ptr(new int(fd), [](int* ptr) { - ::close(*ptr); - delete ptr; - }); - return true; - } - - async_simple::coro::Lazy> async_prw( - auto io_func, bool is_read, size_t offset, char* buf, size_t size) { - std::function func = [=, this] { - int fd = *fd_file_; - return io_func(fd, buf, size, offset); - }; - - std::error_code ec{}; - size_t op_size = 0; - - auto len_val = co_await coro_io::post(std::move(func), &executor_wrapper_); - int len = len_val.value(); - if (len == 0) { - if (is_read) { - eof_ = true; +#if defined(ASIO_WINDOWS) + auto pread = [](int fd, void* buf, uint64_t count, + uint64_t offset) -> int64_t { + DWORD bytes_read = 0; + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(OVERLAPPED)); + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + BOOL ok = ReadFile(reinterpret_cast(_get_osfhandle(fd)), buf, + count, &bytes_read, &overlapped); + if (!ok && (errno = GetLastError()) != ERROR_HANDLE_EOF) { + return -1; } - } - else if (len > 0) { - op_size = len; - } - else { - ec = std::make_error_code(std::errc::io_error); - } - co_return std::make_pair(ec, op_size); - } - - async_simple::coro::Lazy> async_pread( - size_t offset, char* data, size_t size) { + return bytes_read; + }; +#endif co_return co_await async_prw(pread, true, offset, data, size); } async_simple::coro::Lazy async_pwrite(size_t offset, const char* data, size_t size) { + if (type_ != read_type::pread) { + co_return std::make_error_code(std::errc::bad_file_descriptor); + } +#if defined(ASIO_WINDOWS) + auto pwrite = [](int fd, const void* buf, uint64_t count, + uint64_t offset) -> int64_t { + DWORD bytes_write = 0; + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(OVERLAPPED)); + overlapped.Offset = offset & 0xFFFFFFFF; + overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF; + + BOOL ok = WriteFile(reinterpret_cast(_get_osfhandle(fd)), buf, + count, &bytes_write, &overlapped); + if (!ok) { + return -1; + } + + return bytes_write; + }; +#endif auto result = co_await async_prw(pwrite, false, offset, (char*)data, size); co_return result.first; } -#endif #if defined(YLT_ENABLE_FILE_IO_URING) async_simple::coro::Lazy async_open(std::string_view filepath, int open_mode = flags::read_write, read_type type = read_type::uring) { type_ = type; - if (type == read_type::pread) { + if (type_ == read_type::pread) { co_return open_fd(filepath, open_mode); } try { - if (type == read_type::uring) { + if (type_ == read_type::uring) { stream_file_ = std::make_shared( executor_wrapper_.get_asio_executor()); } @@ -235,7 +241,9 @@ class coro_file { executor_wrapper_.get_asio_executor()); } } catch (std::exception& ex) { - std::cout << ex.what() << "\n"; + stream_file_ = nullptr; + std::cout << "line " << __LINE__ << " coro_file create failed" + << ex.what() << "\n"; co_return false; } @@ -244,7 +252,9 @@ class coro_file { static_cast(open_mode), ec); if (ec) { - std::cout << ec.message() << "\n"; + stream_file_ = nullptr; + std::cout << "line " << __LINE__ << " coro_file open failed" + << ec.message() << "\n"; co_return false; } @@ -256,7 +266,6 @@ class coro_file { return false; } - assert(stream_file_); std::error_code seek_ec; reinterpret_cast(stream_file_.get()) ->seek(offset, static_cast(whence), @@ -269,8 +278,10 @@ class coro_file { async_simple::coro::Lazy> async_read_at( uint64_t offset, char* data, size_t size) { - assert(stream_file_); - assert(type_ == read_type::uring_random); + if (type_ != read_type::uring_random) { + co_return std::make_pair( + std::make_error_code(std::errc::bad_file_descriptor), 0); + } auto [ec, read_size] = co_await coro_io::async_read_at( offset, @@ -288,8 +299,9 @@ class coro_file { async_simple::coro::Lazy async_write_at(uint64_t offset, const char* data, size_t size) { - assert(stream_file_); - assert(type_ == read_type::uring_random); + if (type_ != read_type::uring_random) { + co_return std::make_error_code(std::errc::bad_file_descriptor); + } auto [ec, write_size] = co_await coro_io::async_write_at( offset, @@ -300,8 +312,10 @@ class coro_file { async_simple::coro::Lazy> async_read( char* data, size_t size) { - assert(stream_file_); - assert(type_ == read_type::uring); + if (type_ != read_type::uring) { + co_return std::make_pair( + std::make_error_code(std::errc::bad_file_descriptor), 0); + } auto [ec, read_size] = co_await coro_io::async_read( *reinterpret_cast(stream_file_.get()), @@ -316,8 +330,9 @@ class coro_file { async_simple::coro::Lazy async_write(const char* data, size_t size) { - assert(stream_file_); - assert(type_ == read_type::uring); + if (type_ != read_type::uring) { + co_return std::make_error_code(std::errc::bad_file_descriptor); + } auto [ec, write_size] = co_await coro_io::async_write( *reinterpret_cast(stream_file_.get()), @@ -347,7 +362,9 @@ class coro_file { } bool seek(long offset, int whence) { - assert(fd_file_ == nullptr); + if (stream_file_ == nullptr) { + return false; + } return fseek(stream_file_.get(), offset, whence) == 0; } @@ -355,11 +372,10 @@ class coro_file { async_simple::coro::Lazy async_open(std::string filepath, int open_mode = flags::read_write, read_type type = read_type::fread) { -#if defined(__GNUC__) - if (type == read_type::pread) { + type_ = type; + if (type_ == read_type::pread) { co_return open_fd(filepath, open_mode); } -#endif if (stream_file_ != nullptr) { co_return true; @@ -369,8 +385,8 @@ class coro_file { [this, &filepath, open_mode] { auto fptr = fopen(filepath.data(), str_mode(open_mode).data()); if (fptr == nullptr) { - std::cout << "open file " << filepath << " failed " - << "\n"; + std::cout << "line " << __LINE__ << " coro_file open failed " + << filepath << "\n"; return false; } stream_file_ = std::shared_ptr(fptr, [](FILE* ptr) { @@ -384,6 +400,10 @@ class coro_file { async_simple::coro::Lazy> async_read( char* data, size_t size) { + if (type_ != read_type::fread) { + co_return std::make_pair( + std::make_error_code(std::errc::bad_file_descriptor), 0); + } auto result = co_await coro_io::post( [this, data, size] { auto fptr = stream_file_.get(); @@ -403,6 +423,9 @@ class coro_file { async_simple::coro::Lazy async_write(const char* data, size_t size) { + if (type_ != read_type::fread) { + co_return std::make_error_code(std::errc::bad_file_descriptor); + } auto result = co_await coro_io::post( [this, data, size] { auto fptr = stream_file_.get(); @@ -419,9 +442,95 @@ class coro_file { #endif private: + async_simple::coro::Lazy> async_prw( + auto io_func, bool is_read, size_t offset, char* buf, size_t size) { + std::function func = [=, this] { + int fd = *fd_file_; + return io_func(fd, buf, size, offset); + }; + + std::error_code ec{}; + size_t op_size = 0; + + auto len_val = co_await coro_io::post(std::move(func), &executor_wrapper_); + int len = len_val.value(); + if (len == 0) { + if (is_read) { + eof_ = true; + } + } + else if (len > 0) { + op_size = len; + } + else { + ec = std::make_error_code(std::errc::io_error); + } + + co_return std::make_pair(ec, op_size); + } + + bool open_fd(std::string_view filepath, int open_mode = flags::read_write) { + if (fd_file_) { + return true; + } + +#if defined(ASIO_WINDOWS) + int fd = _open(filepath.data(), adjust_open_mode(open_mode)); +#else + int fd = open(filepath.data(), open_mode); +#endif + if (fd < 0) { + return false; + } + + fd_file_ = std::shared_ptr(new int(fd), [](int* ptr) { +#if defined(ASIO_WINDOWS) + _close(*ptr); +#else + ::close(*ptr); +#endif + delete ptr; + }); + return true; + } + +#if defined(ASIO_WINDOWS) + static int adjust_open_mode(int open_mode) { + switch (open_mode) { + case flags::read_only: + return _O_RDONLY; + case flags::write_only: + return _O_WRONLY; + case flags::read_write: + return _O_RDWR; + case flags::append: + return _O_APPEND; + case flags::create: + return _O_CREAT; + case flags::exclusive: + return _O_EXCL; + case flags::truncate: + return _O_TRUNC; + case flags::create_write: + return _O_CREAT | _O_WRONLY; + case flags::create_write_trunc: + return _O_CREAT | _O_WRONLY | _O_TRUNC; + case flags::create_read_write_trunc: + return _O_RDWR | _O_CREAT | _O_TRUNC; + case flags::create_read_write_append: + return _O_RDWR | _O_CREAT | _O_APPEND; + case flags::sync_all_on_write: + default: + return open_mode; + break; + } + return open_mode; + } +#endif + private: + read_type type_ = read_type::init; #if defined(YLT_ENABLE_FILE_IO_URING) std::shared_ptr> stream_file_; - read_type type_ = read_type::uring; #else std::shared_ptr stream_file_; #endif diff --git a/src/coro_io/tests/test_corofile.cpp b/src/coro_io/tests/test_corofile.cpp index d6fea1264..28c5ef13c 100644 --- a/src/coro_io/tests/test_corofile.cpp +++ b/src/coro_io/tests/test_corofile.cpp @@ -69,7 +69,85 @@ void create_files(const std::vector& files, size_t file_size) { } } -#if defined(__GNUC__) +TEST_CASE("validate corofile") { + std::string filename = "validate.tmp"; + create_files({filename}, 190); + { + coro_io::coro_file file{}; + async_simple::coro::syncAwait(file.async_open( + filename.data(), coro_io::flags::read_only, coro_io::read_type::pread)); + CHECK(file.is_open()); + + char buf[100]; + std::error_code ec; + size_t size; + std::tie(ec, size) = + async_simple::coro::syncAwait(file.async_read(buf, 10)); + CHECK(ec == std::make_error_code(std::errc::bad_file_descriptor)); + CHECK(size == 0); + + auto write_ec = async_simple::coro::syncAwait(file.async_write(buf, 10)); + CHECK(write_ec == std::make_error_code(std::errc::bad_file_descriptor)); + } +#if defined(YLT_ENABLE_FILE_IO_URING) + { + coro_io::coro_file file{}; + async_simple::coro::syncAwait( + file.async_open(filename.data(), coro_io::flags::read_only, + coro_io::read_type::uring_random)); + CHECK(file.is_open()); + + char buf[100]; + std::error_code ec; + size_t size; + std::tie(ec, size) = + async_simple::coro::syncAwait(file.async_read(buf, 10)); + CHECK(ec == std::make_error_code(std::errc::bad_file_descriptor)); + CHECK(size == 0); + + ec = async_simple::coro::syncAwait(file.async_write(buf, 10)); + CHECK(ec == std::make_error_code(std::errc::bad_file_descriptor)); + } + + { + coro_io::coro_file file{}; + async_simple::coro::syncAwait(file.async_open( + filename.data(), coro_io::flags::read_only, coro_io::read_type::uring)); + CHECK(file.is_open()); + + char buf[100]; + std::error_code ec; + size_t size; + std::tie(ec, size) = + async_simple::coro::syncAwait(file.async_read_at(0, buf, 10)); + CHECK(ec == std::make_error_code(std::errc::bad_file_descriptor)); + CHECK(size == 0); + + ec = async_simple::coro::syncAwait(file.async_write_at(0, buf, 10)); + CHECK(ec == std::make_error_code(std::errc::bad_file_descriptor)); + } +#else + { + coro_io::coro_file file{}; + async_simple::coro::syncAwait(file.async_open( + filename.data(), coro_io::flags::read_only, coro_io::read_type::fread)); + CHECK(file.is_open()); + + char buf[100]; + std::error_code ec; + size_t size; + std::tie(ec, size) = + async_simple::coro::syncAwait(file.async_pread(0, buf, 10)); + CHECK(ec == std::make_error_code(std::errc::bad_file_descriptor)); + CHECK(size == 0); + + auto write_ec = + async_simple::coro::syncAwait(file.async_pwrite(0, buf, 10)); + CHECK(write_ec == std::make_error_code(std::errc::bad_file_descriptor)); + } +#endif +} + TEST_CASE("coro_file pread and pwrite basic test") { std::string filename = "test.tmp"; create_files({filename}, 190); @@ -179,7 +257,6 @@ TEST_CASE("coro_file pread and pwrite basic test") { CHECK(std::string_view(buf2, pair.second) == "dddddddddd"); } } -#endif async_simple::coro::Lazy test_basic_read(std::string filename) { coro_io::coro_file file{};