From c555b0f42bd52f550a57b991e2e53e4c93129680 Mon Sep 17 00:00:00 2001 From: gxu <69813939+gxuu@users.noreply.github.com> Date: Fri, 19 Sep 2025 11:55:49 -0400 Subject: [PATCH 1/3] Swap asio to ymq for object storage server (#224) * Make YMQ more ergonomic for use cases - Bytes `const std::string&` constructor for debug and testing usage - Bytes equality comparator - Extra error that likely helps user develop Signed-off-by: gxu * Make YMQ more aligned to the MITM test It used to be the case that YMQ always tries to reconnect, but it turns out that we don't really need this: graceful disconnect are consider to be a real disconnect; while ungraceful reconnect (ECONNRESET) will trigers reconnect. Signed-off-by: gxu * Swap Object Storage Server to use YMQ Signed-off-by: gxu * Change signals to SIGKILL as YMQ doesn't support signal handling Signed-off-by: gxu * Resolve comment for rafa Signed-off-by: gxu --------- Signed-off-by: gxu --- .github/workflows/linter.yml | 5 +- examples/task_capabilities.py | 10 +- scaler/client/client.py | 13 +- scaler/io/sync_object_storage_connector.py | 7 +- scaler/io/ymq/bytes.h | 11 +- .../io/ymq/examples/automated_echo_client.cpp | 2 +- scaler/io/ymq/message_connection_tcp.cpp | 1 - scaler/io/ymq/pymod_ymq/async.h | 3 +- scaler/io/ymq/pymod_ymq/exception.h | 2 +- scaler/io/ymq/pymod_ymq/io_context.h | 4 +- scaler/io/ymq/pymod_ymq/io_socket.h | 10 +- scaler/io/ymq/pymod_ymq/ymq.h | 2 +- scaler/io/ymq/simple_interface.cpp | 29 +- scaler/io/ymq/simple_interface.h | 8 +- scaler/io/ymq/ymq.pyi | 14 +- scaler/io/ymq/ymq_test.py | 20 - .../object_storage/object_storage_server.cpp | 13 +- scaler/object_storage/object_storage_server.h | 3 +- scaler/worker/agent/heartbeat_manager.py | 5 +- scaler/worker/worker.py | 2 +- scripts/build.sh | 2 +- tests/CMakeLists.txt | 5 +- tests/cc_ymq/CMakeLists.txt | 1 + tests/cc_ymq/common.h | 496 +++++++++++++++++ tests/cc_ymq/py_mitm/__init__.py | 0 tests/cc_ymq/py_mitm/main.py | 152 ++++++ tests/cc_ymq/py_mitm/passthrough.py | 23 + tests/cc_ymq/py_mitm/randomly_drop_packets.py | 28 + tests/cc_ymq/py_mitm/send_rst_to_client.py | 48 ++ tests/cc_ymq/py_mitm/types.py | 54 ++ tests/cc_ymq/test_cc_ymq.cpp | 508 ++++++++++++++++++ .../test_object_storage_server.cpp | 30 +- tests/pymod_ymq/__init__.py | 0 tests/pymod_ymq/test_pymod_ymq.py | 150 ++++++ tests/pymod_ymq/test_types.py | 90 ++++ tests/test_graph.py | 8 +- 36 files changed, 1639 insertions(+), 120 deletions(-) delete mode 100644 scaler/io/ymq/ymq_test.py create mode 100644 tests/cc_ymq/CMakeLists.txt create mode 100644 tests/cc_ymq/common.h create mode 100644 tests/cc_ymq/py_mitm/__init__.py create mode 100644 tests/cc_ymq/py_mitm/main.py create mode 100644 tests/cc_ymq/py_mitm/passthrough.py create mode 100644 tests/cc_ymq/py_mitm/randomly_drop_packets.py create mode 100644 tests/cc_ymq/py_mitm/send_rst_to_client.py create mode 100644 tests/cc_ymq/py_mitm/types.py create mode 100644 tests/cc_ymq/test_cc_ymq.cpp create mode 100644 tests/pymod_ymq/__init__.py create mode 100644 tests/pymod_ymq/test_pymod_ymq.py create mode 100644 tests/pymod_ymq/test_types.py diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 97cb16760..139703fc8 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -74,8 +74,11 @@ jobs: sudo ./scripts/download_install_dependencies.sh capnp install - name: Build and test C++ Components + + # TODO: Figure out how to run man-in-the-middle tests in CI + # TODO: Why does TestIncompleteIdentity work when run locally, but fail in CI? run: | - CXX=$(which g++-14) ./scripts/build.sh + CXX=$(which g++-14) GTEST_FILTER="-*Mitm*:*TestIncompleteIdentity*" ./scripts/build.sh - name: Install Python Dependent Packages run: | diff --git a/examples/task_capabilities.py b/examples/task_capabilities.py index 45dc12734..34c0292d8 100644 --- a/examples/task_capabilities.py +++ b/examples/task_capabilities.py @@ -56,18 +56,12 @@ def main(): # Submit a task that requires GPU capabilities, this will be redirected to the GPU worker. gpu_future = client.submit_verbose( - gpu_task, - args=(16.0,), - kwargs={}, - capabilities={"gpu": 1} # Requires a GPU capability + gpu_task, args=(16.0,), kwargs={}, capabilities={"gpu": 1} # Requires a GPU capability ) # Submit a task that does not require GPU capabilities, this will be routed to any available worker. cpu_future = client.submit_verbose( - cpu_task, - args=(16.0,), - kwargs={}, - capabilities={} # No GPU capability required + cpu_task, args=(16.0,), kwargs={}, capabilities={} # No GPU capability required ) # Waits for the tasks for finish diff --git a/scaler/client/client.py b/scaler/client/client.py index 9807adb84..2868df929 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -210,11 +210,7 @@ def submit(self, fn: Callable, *args, **kwargs) -> ScalerFuture: return self.submit_verbose(fn, args, kwargs) def submit_verbose( - self, - fn: Callable, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - capabilities: Optional[Dict[str, int]] = None, + self, fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any], capabilities: Optional[Dict[str, int]] = None ) -> ScalerFuture: """ Submit a single task (function with arguments) to the scheduler, and return a future. Possibly route the task to @@ -242,10 +238,7 @@ def submit_verbose( return future def map( - self, - fn: Callable, - iterable: Iterable[Tuple[Any, ...]], - capabilities: Optional[Dict[str, int]] = None + self, fn: Callable, iterable: Iterable[Tuple[Any, ...]], capabilities: Optional[Dict[str, int]] = None ) -> List[Any]: if not all(isinstance(args, (tuple, list)) for args in iterable): raise TypeError("iterable should be list of arguments(list or tuple-like) of function") @@ -309,7 +302,7 @@ def get( self.__check_graph(node_name_to_argument, call_graph, keys) graph_task, compute_futures, finished_futures = self.__construct_graph( - node_name_to_argument, call_graph, keys, block, capabilities, + node_name_to_argument, call_graph, keys, block, capabilities ) self._object_buffer.commit_send_objects() self._connector_agent.send(graph_task) diff --git a/scaler/io/sync_object_storage_connector.py b/scaler/io/sync_object_storage_connector.py index 40f3c9fbf..2a2538075 100644 --- a/scaler/io/sync_object_storage_connector.py +++ b/scaler/io/sync_object_storage_connector.py @@ -147,10 +147,9 @@ def __send_request( header_bytes = header.get_message().to_bytes() if payload is not None: - self.__send_buffers([struct.pack(" // C++ +#include #include // First-party @@ -33,10 +34,9 @@ class Bytes { public: Bytes(char* data, size_t len): _data(datadup((uint8_t*)data, len)), _len(len) {} - Bytes(): _data {}, _len {} {} + Bytes(const std::string& s): _data(datadup((uint8_t*)s.data(), s.length())), _len(s.length()) {} - // For debug and convenience only - explicit Bytes(const std::string& str): Bytes((char*)str.c_str(), str.size()) {} + Bytes(): _data {}, _len {} {} Bytes(const Bytes& other) noexcept { @@ -92,11 +92,10 @@ class Bytes { [[nodiscard]] constexpr bool is_null() const noexcept { return !this->_data; } - // debugging utility - std::string as_string() const + std::optional as_string() const { if (is_null()) - return "[EMPTY]"; + return std::nullopt; return std::string((char*)_data, _len); } diff --git a/scaler/io/ymq/examples/automated_echo_client.cpp b/scaler/io/ymq/examples/automated_echo_client.cpp index 5a4ca3510..328b3e219 100644 --- a/scaler/io/ymq/examples/automated_echo_client.cpp +++ b/scaler/io/ymq/examples/automated_echo_client.cpp @@ -70,7 +70,7 @@ int main() auto future = x.get_future(); Message msg = future.get().first; if (msg.payload.as_string() != longStr) { - printf("Checksum failed, %s\n", msg.payload.as_string().c_str()); + printf("Checksum failed, %s\n", msg.payload.as_string()->c_str()); exit(1); } } diff --git a/scaler/io/ymq/message_connection_tcp.cpp b/scaler/io/ymq/message_connection_tcp.cpp index 0a3fec824..21efe0d8a 100644 --- a/scaler/io/ymq/message_connection_tcp.cpp +++ b/scaler/io/ymq/message_connection_tcp.cpp @@ -265,7 +265,6 @@ void MessageConnectionTCP::updateReadOperation() _pendingRecvMessageCallbacks->pop(); recvMessageCallback({Message(std::move(address), std::move(payload)), {}}); - } else { assert(_pendingRecvMessageCallbacks->size()); break; diff --git a/scaler/io/ymq/pymod_ymq/async.h b/scaler/io/ymq/pymod_ymq/async.h index 8a2e229a5..ba80d6bb3 100644 --- a/scaler/io/ymq/pymod_ymq/async.h +++ b/scaler/io/ymq/pymod_ymq/async.h @@ -12,7 +12,7 @@ #include "scaler/io/ymq/pymod_ymq/ymq.h" // wraps an async callback that accepts a Python asyncio future -static PyObject* async_wrapper(PyObject* self, const std::function& callback) +static PyObject* async_wrapper(PyObject* self, const std::function&& callback) { auto state = YMQStateFromSelf(self); if (!state) @@ -25,7 +25,6 @@ static PyObject* async_wrapper(PyObject* self, const std::function // First-party -#include "scaler/io/ymq/pymod_ymq/ymq.h" #include "scaler/io/ymq/pymod_ymq/utils.h" +#include "scaler/io/ymq/pymod_ymq/ymq.h" // the order of the members in the exception args tuple const Py_ssize_t YMQException_errorCodeIndex = 0; diff --git a/scaler/io/ymq/pymod_ymq/io_context.h b/scaler/io/ymq/pymod_ymq/io_context.h index 2bc9be556..e922d5388 100644 --- a/scaler/io/ymq/pymod_ymq/io_context.h +++ b/scaler/io/ymq/pymod_ymq/io_context.h @@ -77,8 +77,8 @@ static PyObject* PyIOContext_createIOSocket_( using Identity = Configuration::IOSocketIdentity; // note: references borrowed from args, so no need to manage their lifetime - PyObject* pyIdentity {}; - PyObject* pySocketType {}; + PyObject* pyIdentity = nullptr; + PyObject* pySocketType = nullptr; if (nargs == 1) { pyIdentity = args[0]; } else if (nargs == 2) { diff --git a/scaler/io/ymq/pymod_ymq/io_socket.h b/scaler/io/ymq/pymod_ymq/io_socket.h index 152c785a8..898afb832 100644 --- a/scaler/io/ymq/pymod_ymq/io_socket.h +++ b/scaler/io/ymq/pymod_ymq/io_socket.h @@ -156,10 +156,6 @@ static PyObject* PyIOSocket_recv(PyIOSocket* self, PyObject* args) if (!pyMessage) return YMQ_GetRaisedException(); - // TODO: why is leaking necessary? - address.forget(); - payload.forget(); - return (PyObject*)pyMessage.take(); }); } catch (...) { @@ -221,10 +217,6 @@ static PyObject* PyIOSocket_recv_sync(PyIOSocket* self, PyObject* args) if (!pyMessage) return nullptr; - // TODO: why is leaking necessary? - address.forget(); - payload.forget(); - return (PyObject*)pyMessage.take(); } @@ -378,7 +370,7 @@ static PyObject* PyIOSocket_socket_type_getter(PyIOSocket* self, void* closure) if (!state) return nullptr; - const IOSocketType socketType = self->socket->socketType(); + const IOSocketType socketType = self->socket->socketType(); OwnedPyObject socketTypeIntObj = PyLong_FromLong((long)socketType); if (!socketTypeIntObj) diff --git a/scaler/io/ymq/pymod_ymq/ymq.h b/scaler/io/ymq/pymod_ymq/ymq.h index 879b019d8..b788d05dd 100644 --- a/scaler/io/ymq/pymod_ymq/ymq.h +++ b/scaler/io/ymq/pymod_ymq/ymq.h @@ -304,7 +304,7 @@ static int YMQ_createErrorCodeEnum(PyObject* pyModule, YMQState* state) // docs and examples are unfortunately scarce for this // for now this will work just fine OwnedPyObject item {}; - while (item = PyIter_Next(*iter)) { + while ((item = PyIter_Next(*iter))) { OwnedPyObject fn = PyCMethod_New(&YMQErrorCode_explanation_def, *item, pyModule, nullptr); if (!fn) return -1; diff --git a/scaler/io/ymq/simple_interface.cpp b/scaler/io/ymq/simple_interface.cpp index 3c51ccfc2..d587d09e2 100644 --- a/scaler/io/ymq/simple_interface.cpp +++ b/scaler/io/ymq/simple_interface.cpp @@ -1,6 +1,8 @@ #include "scaler/io/ymq/simple_interface.h" +#include + namespace scaler { namespace ymq { @@ -35,35 +37,42 @@ void syncConnectSocket(std::shared_ptr socket, std::string address) connect_future.wait(); } -std::pair syncRecvMessage(std::shared_ptr socket) +std::expected syncRecvMessage(std::shared_ptr socket) { auto fut = futureRecvMessage(std::move(socket)); return fut.get(); } -std::expected syncSendMessage(std::shared_ptr socket, Message message) +std::optional syncSendMessage(std::shared_ptr socket, Message message) { auto fut = futureSendMessage(std::move(socket), std::move(message)); return fut.get(); } -std::future> futureRecvMessage(std::shared_ptr socket) +std::future> futureRecvMessage(std::shared_ptr socket) { - auto recv_promise_ptr = std::make_unique>>(); + auto recv_promise_ptr = std::make_unique>>(); auto recv_future = recv_promise_ptr->get_future(); - socket->recvMessage([recv_promise = std::move(recv_promise_ptr)](std::pair msg) { - recv_promise->set_value(std::move(msg)); + socket->recvMessage([recv_promise = std::move(recv_promise_ptr)](std::pair result) { + if (result.second._errorCode == Error::ErrorCode::Uninit) + recv_promise->set_value(std::move(result.first)); + else + recv_promise->set_value(std::unexpected {std::move(result.second)}); }); + return {std::move(recv_future)}; } -std::future> futureSendMessage(std::shared_ptr socket, Message message) +std::future> futureSendMessage(std::shared_ptr socket, Message message) { - auto send_promise_ptr = std::make_unique>>(); + auto send_promise_ptr = std::make_unique>>(); auto send_future = send_promise_ptr->get_future(); socket->sendMessage( - std::move(message), [send_promise = std::move(send_promise_ptr)](std::expected msg) { - send_promise->set_value(std::move(msg)); + std::move(message), [send_promise = std::move(send_promise_ptr)](std::expected result) { + if (result) + send_promise->set_value(std::nullopt); + else + send_promise->set_value(std::move(result.error())); }); return {std::move(send_future)}; } diff --git a/scaler/io/ymq/simple_interface.h b/scaler/io/ymq/simple_interface.h index 2f2d7a03d..9b9fa1f39 100644 --- a/scaler/io/ymq/simple_interface.h +++ b/scaler/io/ymq/simple_interface.h @@ -14,11 +14,11 @@ std::shared_ptr syncCreateSocket(IOContext& context, IOSocketType type void syncBindSocket(std::shared_ptr socket, std::string address); void syncConnectSocket(std::shared_ptr socket, std::string address); -std::pair syncRecvMessage(std::shared_ptr socket); -std::expected syncSendMessage(std::shared_ptr socket, Message message); +std::expected syncRecvMessage(std::shared_ptr socket); +std::optional syncSendMessage(std::shared_ptr socket, Message message); -std::future> futureRecvMessage(std::shared_ptr socket); -std::future> futureSendMessage(std::shared_ptr socket, Message message); +std::future> futureRecvMessage(std::shared_ptr socket); +std::future> futureSendMessage(std::shared_ptr socket, Message message); } // namespace ymq } // namespace scaler diff --git a/scaler/io/ymq/ymq.pyi b/scaler/io/ymq/ymq.pyi index 7444d6fad..d436fd136 100644 --- a/scaler/io/ymq/ymq.pyi +++ b/scaler/io/ymq/ymq.pyi @@ -1,6 +1,5 @@ # NOTE: NOT IMPLEMENTATION, TYPE INFORMATION ONLY # This file contains type stubs for the Ymq Python C Extension module -import abc import sys from collections.abc import Awaitable from enum import IntEnum @@ -11,12 +10,17 @@ if sys.version_info >= (3, 12): else: Buffer = object -class Bytes(Buffer, metaclass=abc.ABCMeta): - data: bytes +class Bytes(Buffer): + data: bytes | None len: int - def __init__(self, data: SupportsBytes | bytes) -> None: ... + def __init__(self, data: Buffer | None = None) -> None: ... def __repr__(self) -> str: ... + def __len__(self) -> int: ... + + # this type signature is not 100% accurate because it's implemented in C + # but this satisfies the type check and is good enough + def __buffer__(self, flags: int, /) -> memoryview: ... class Message: address: Bytes | None @@ -99,7 +103,7 @@ class YMQException(Exception): code: ErrorCode message: str - def __init__(self, code: ErrorCode, message: str) -> None: ... + def __init__(self, /, code: ErrorCode, message: str) -> None: ... def __repr__(self) -> str: ... def __str__(self) -> str: ... diff --git a/scaler/io/ymq/ymq_test.py b/scaler/io/ymq/ymq_test.py deleted file mode 100644 index 9201983c7..000000000 --- a/scaler/io/ymq/ymq_test.py +++ /dev/null @@ -1,20 +0,0 @@ -import asyncio - -import ymq - - -async def main(): - ctx = ymq.IOContext() - socket = await ctx.createIOSocket("ident", ymq.IOSocketType.Binder) - print(ctx, ";", socket) - - assert socket.identity == "ident" - assert socket.socket_type == ymq.IOSocketType.Binder - - exc = ymq.YMQException(ymq.ErrorCode.InvalidAddressFormat, "the address has an invalid format") - assert exc.code == ymq.ErrorCode.InvalidAddressFormat - assert exc.message == "the address has an invalid format" - assert exc.code.explanation() - - -asyncio.run(main()) diff --git a/scaler/object_storage/object_storage_server.cpp b/scaler/object_storage/object_storage_server.cpp index c69ff698e..0ec59f9a7 100644 --- a/scaler/object_storage/object_storage_server.cpp +++ b/scaler/object_storage/object_storage_server.cpp @@ -130,14 +130,15 @@ void ObjectStorageServer::processRequests() std::ranges::for_each(_pendingSendMessageFuts, [](auto& fut) { if (fut.wait_for(0s) == std::future_status::ready) { - auto res = fut.get(); - assert(res); + auto error = fut.get(); + assert(!error); } }); - auto [message, error] = ymq::syncRecvMessage(_ioSocket); + auto maybeMessage = ymq::syncRecvMessage(_ioSocket); - if (error._errorCode != ymq::Error::ErrorCode::Uninit) { + if (!maybeMessage) { + auto error = maybeMessage.error(); if (error._errorCode == ymq::Error::ErrorCode::IOSocketStopRequested) { _logger.log( scaler::ymq::Logger::LoggingLevel::info, @@ -151,8 +152,8 @@ void ObjectStorageServer::processRequests() } } - const auto identity = lastMessageIdentity = message.address.as_string(); - const auto headerOrPayload = std::move(message.payload); + const auto identity = *maybeMessage->address.as_string(); + const auto headerOrPayload = std::move(maybeMessage->payload); auto it = identityToFullRequest.find(identity); if (it == identityToFullRequest.end()) { diff --git a/scaler/object_storage/object_storage_server.h b/scaler/object_storage/object_storage_server.h index 20ed8e738..00849950e 100644 --- a/scaler/object_storage/object_storage_server.h +++ b/scaler/object_storage/object_storage_server.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "scaler/io/ymq/configuration.h" @@ -22,7 +23,7 @@ namespace object_storage { class ObjectStorageServer { public: using Identity = ymq::Configuration::IOSocketIdentity; - using SendMessageFuture = std::future>; + using SendMessageFuture = std::future>; ObjectStorageServer(); diff --git a/scaler/worker/agent/heartbeat_manager.py b/scaler/worker/agent/heartbeat_manager.py index ee8ce6ac2..92e9edc08 100644 --- a/scaler/worker/agent/heartbeat_manager.py +++ b/scaler/worker/agent/heartbeat_manager.py @@ -14,10 +14,7 @@ class VanillaHeartbeatManager(Looper, HeartbeatManager): def __init__( - self, - storage_address: Optional[ObjectStorageConfig], - capabilities: Dict[str, int], - task_queue_size: int + self, storage_address: Optional[ObjectStorageConfig], capabilities: Dict[str, int], task_queue_size: int ): self._agent_process = psutil.Process() self._capabilities = capabilities diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index 6a710d946..18092fe40 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -126,7 +126,7 @@ def __initialize(self): self._heartbeat_manager = VanillaHeartbeatManager( storage_address=self._storage_address, capabilities=self._capabilities, - task_queue_size=self._task_queue_size + task_queue_size=self._task_queue_size, ) self._profiling_manager = VanillaProfilingManager() diff --git a/scripts/build.sh b/scripts/build.sh index 43c3d4549..dd9661bba 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -24,4 +24,4 @@ cmake --build --preset $BUILD_PRESET cmake --install $BUILD_DIR # Tests -ctest --preset $BUILD_PRESET +ctest --preset $BUILD_PRESET -VV diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b539a5d93..6b714034c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,6 +11,8 @@ set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) set(BUILD_GTEST ON CACHE BOOL "" FORCE) FetchContent_MakeAvailable(googletest) +find_package(Python3 COMPONENTS Development REQUIRED) + # This function compiles, links, and adds a C++ test executable using Google Test. # It is shared by all test subdirectories. function(add_test_executable test_name source_file) @@ -21,6 +23,7 @@ function(add_test_executable test_name source_file) target_link_libraries( ${test_name} PRIVATE GTest::gtest_main + PRIVATE Python3::Python PRIVATE cc_object_storage_server PRIVATE cc_ymq ) @@ -28,11 +31,11 @@ function(add_test_executable test_name source_file) add_test(NAME ${test_name} COMMAND ${test_name}) endfunction() - if(LINUX OR APPLE) # This directory fetches Google Test, so it must be included first. add_subdirectory(object_storage) # Add the new directory for io tests. add_subdirectory(io/ymq) + add_subdirectory(cc_ymq) endif() diff --git a/tests/cc_ymq/CMakeLists.txt b/tests/cc_ymq/CMakeLists.txt new file mode 100644 index 000000000..9f6abe371 --- /dev/null +++ b/tests/cc_ymq/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_cc_ymq test_cc_ymq.cpp) diff --git a/tests/cc_ymq/common.h b/tests/cc_ymq/common.h new file mode 100644 index 000000000..42dfab9fe --- /dev/null +++ b/tests/cc_ymq/common.h @@ -0,0 +1,496 @@ +#pragma once + +#define PY_SSIZE_T_CLEAN +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define RETURN_FAILURE_IF_FALSE(condition) \ + if (!(condition)) { \ + return TestResult::Failure; \ + } + +using namespace std::chrono_literals; + +enum class TestResult : char { Success = 1, Failure = 2 }; + +inline const char* check_localhost(const char* host) +{ + return std::strcmp(host, "localhost") == 0 ? "127.0.0.1" : host; +} + +inline std::string format_address(std::string host, uint16_t port) +{ + return std::format("tcp://{}:{}", check_localhost(host.c_str()), port); +} + +class OwnedFd { +public: + int fd; + + OwnedFd(int fd): fd(fd) {} + + // move-only + OwnedFd(const OwnedFd&) = delete; + OwnedFd& operator=(const OwnedFd&) = delete; + OwnedFd(OwnedFd&& other) noexcept: fd(other.fd) { other.fd = 0; } + OwnedFd& operator=(OwnedFd&& other) noexcept + { + if (this != &other) { + this->fd = other.fd; + other.fd = 0; + } + return *this; + } + + ~OwnedFd() + { + if (fd > 0 && close(fd) < 0) + std::println(std::cerr, "failed to close fd!"); + } + + size_t write(const void* data, size_t len) + { + auto n = ::write(this->fd, data, len); + if (n < 0) + throw std::system_error(errno, std::generic_category(), "failed to write to socket"); + + return n; + } + + void write_all(const char* data, size_t len) + { + for (size_t cursor = 0; cursor < len;) + cursor += this->write(data + cursor, len - cursor); + } + + void write_all(std::vector data) { this->write_all(data.data(), data.size()); } + + size_t read(void* buffer, size_t len) + { + auto n = ::read(this->fd, buffer, len); + if (n < 0) + throw std::system_error(errno, std::generic_category(), "failed to read from socket"); + return n; + } + + void read_exact(char* buffer, size_t len) + { + for (size_t cursor = 0; cursor < len;) + cursor += this->read(buffer + cursor, len - cursor); + } + + operator int() { return fd; } +}; + +class Socket: public OwnedFd { +public: + Socket(int fd): OwnedFd(fd) {} + + void connect(const char* host, uint16_t port, bool nowait = false) + { + sockaddr_in addr { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = {.s_addr = inet_addr(check_localhost(host))}, + .sin_zero = {0}}; + + connect: + if (::connect(this->fd, (sockaddr*)&addr, sizeof(addr)) < 0) { + if (errno == ECONNREFUSED && !nowait) { + std::this_thread::sleep_for(300ms); + goto connect; + } + + throw std::system_error(errno, std::generic_category(), "failed to connect"); + } + } + + void bind(const char* host, int port) + { + sockaddr_in addr { + .sin_family = AF_INET, + .sin_port = htons(port), + .sin_addr = {.s_addr = inet_addr(check_localhost(host))}, + .sin_zero = {0}}; + + auto status = ::bind(this->fd, (sockaddr*)&addr, sizeof(addr)); + if (status < 0) + throw std::system_error(errno, std::generic_category(), "failed to bind"); + } + + void listen(int n = 32) + { + auto status = ::listen(this->fd, n); + if (status < 0) + throw std::system_error(errno, std::generic_category(), "failed to listen on socket"); + } + + std::pair accept(int flags = 0) + { + sockaddr_in peer_addr {}; + socklen_t len = sizeof(peer_addr); + auto fd = ::accept4(this->fd, (sockaddr*)&peer_addr, &len, flags); + if (fd < 0) + throw std::system_error(errno, std::generic_category(), "failed to accept socket"); + + return std::make_pair(Socket(fd), peer_addr); + } + + void write_message(std::string message) + { + uint64_t header = message.length(); + this->write_all((char*)&header, 8); + this->write_all(message.data(), message.length()); + } + + std::string read_message() + { + uint64_t header = 0; + this->read_exact((char*)&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); + } +}; + +class TcpSocket: public Socket { +public: + TcpSocket(): Socket(0) + { + this->fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (this->fd < 0) + throw std::system_error(errno, std::generic_category(), "failed to create socket"); + + int on = 1; + if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); + + if (setsockopt(this->fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) + throw std::system_error(errno, std::generic_category(), "failed to set reuseaddr"); + } +}; + +inline void fork_wrapper(std::function fn, int timeout_secs, OwnedFd pipe_wr) +{ + TestResult result = TestResult::Failure; + try { + result = fn(); + } catch (const std::exception& e) { + std::println(stderr, "Exception: {}", e.what()); + result = TestResult::Failure; + } + + pipe_wr.write_all((char*)&result, sizeof(TestResult)); +} + +// this function along with `wait_for_python_ready_sigwait()` +// work together to wait on a signal from the python process +// indicating that the tuntap interface has been created, and that the mitm is ready +inline void wait_for_python_ready_sigblock() +{ + sigset_t set {}; + int sig = 0; + + if (sigemptyset(&set) < 0) + throw std::system_error(errno, std::generic_category(), "failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + throw std::system_error(errno, std::generic_category(), "failed to add sigusr1 to the signal set"); + + if (sigprocmask(SIG_BLOCK, &set, nullptr) < 0) + throw std::system_error(errno, std::generic_category(), "failed to mask sigusr1"); + + std::println("blocked signal..."); +} + +inline void wait_for_python_ready_sigwait(int timeout_secs) +{ + sigset_t set {}; + siginfo_t sig {}; + + if (sigemptyset(&set) < 0) + throw std::system_error(errno, std::generic_category(), "failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + throw std::system_error(errno, std::generic_category(), "failed to add sigusr1 to the signal set"); + + std::println("waiting for python to be ready..."); + timespec ts {.tv_sec = timeout_secs, .tv_nsec = 0}; + if (sigtimedwait(&set, &sig, &ts) < 0) + throw std::system_error(errno, std::generic_category(), "failed to wait on sigusr1"); + + sigprocmask(SIG_UNBLOCK, &set, nullptr); + std::println("signal received; python is ready"); +} + +// run a test +// forks and runs each of the provided closures +// if `wait_for_python` is true, wait for SIGUSR1 after forking and executing the first closure +inline TestResult test( + int timeout_secs, std::vector> closures, bool wait_for_python = false) +{ + std::vector> pipes {}; + std::vector pids {}; + for (size_t i = 0; i < closures.size(); i++) { + int pipe[2] = {0}; + if (pipe2(pipe, O_NONBLOCK) < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { + close(pipe.first); + close(pipe.second); + }); + + throw std::system_error(errno, std::generic_category(), "failed to create pipe: "); + } + pipes.push_back(std::make_pair(pipe[0], pipe[1])); + } + + for (size_t i = 0; i < closures.size(); i++) { + if (wait_for_python && i == 0) + wait_for_python_ready_sigblock(); + + auto pid = fork(); + if (pid < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { + close(pipe.first); + close(pipe.second); + }); + + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to fork"); + } + + if (pid == 0) { + // close all pipes except our write half + for (size_t j = 0; j < pipes.size(); j++) { + if (i == j) + close(pipes[i].first); + else { + close(pipes[j].first); + close(pipes[j].second); + } + } + + fork_wrapper(closures[i], timeout_secs, pipes[i].second); + std::exit(EXIT_SUCCESS); + } + + pids.push_back(pid); + + if (wait_for_python && i == 0) + wait_for_python_ready_sigwait(3); + } + + // close all write halves of the pipes + for (auto pipe: pipes) + close(pipe.second); + + std::vector pfds {}; + + OwnedFd timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK); + if (timerfd < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to create timerfd"); + } + + pfds.push_back({.fd = timerfd.fd, .events = POLL_IN, .revents = 0}); + for (auto pipe: pipes) + pfds.push_back({ + .fd = pipe.first, + .events = POLL_IN, + .revents = 0, + }); + + itimerspec spec { + .it_interval = + { + .tv_sec = 0, + .tv_nsec = 0, + }, + .it_value = { + .tv_sec = timeout_secs, + .tv_nsec = 0, + }}; + + if (timerfd_settime(timerfd, 0, &spec, nullptr) < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to set timerfd"); + } + + std::vector> results(pids.size(), std::nullopt); + + for (;;) { + auto n = poll(pfds.data(), pfds.size(), -1); + if (n < 0) { + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + throw std::system_error(errno, std::generic_category(), "failed to poll: "); + } + + for (auto& pfd: std::vector(pfds)) { + if (pfd.revents == 0) + continue; + + // timed out + if (pfd.fd == timerfd) { + std::println("Timed out!"); + + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + return TestResult::Failure; + } + + TestResult result = TestResult::Failure; + char buffer = 0; + if (read(pfd.fd, &buffer, sizeof(TestResult)) <= 0) + result = TestResult::Failure; + else + result = (TestResult)buffer; + + auto elem = std::find_if(pipes.begin(), pipes.end(), [fd = pfd.fd](auto pipe) { return pipe.first == fd; }); + auto idx = elem - pipes.begin(); + results[idx] = result; + + std::println("Process[{}] completed with {}", idx, result == TestResult::Success ? "Success" : "Failure"); + + // this subprocess is done, remove its pipe from the poll fds + pfds.erase(std::remove_if(pfds.begin(), pfds.end(), [&](auto p) { return p.fd == pfd.fd; }), pfds.end()); + + auto done = std::all_of(results.begin(), results.end(), [](auto result) { return result.has_value(); }); + if (done) + goto end; // justification for goto: breaks out of two levels of loop + } + } + +end: + + std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); + + int status = 0; + std::for_each(pids.begin(), pids.end(), [&status](const auto& pid) { + if (waitpid(pid, &status, 0) < 0) + std::println(stderr, "failed to wait on a subprocess"); + }); + + if (std::ranges::any_of(results, [](auto x) { return x == TestResult::Failure; })) + return TestResult::Failure; + + return TestResult::Success; +} + +inline TestResult run_python(const char* path, std::vector argv = {}) +{ + // insert the pid at the start of the argv, this is important for signalling readiness + pid_t pid = getppid(); + auto pid_ws = std::to_wstring(pid); + argv.insert(argv.begin(), pid_ws.c_str()); + + PyStatus status; + PyConfig config; + PyConfig_InitPythonConfig(&config); + + status = PyConfig_SetBytesString(&config, &config.program_name, "mitm"); + if (PyStatus_Exception(status)) + goto exception; + + status = Py_InitializeFromConfig(&config); + if (PyStatus_Exception(status)) + goto exception; + PyConfig_Clear(&config); + + argv.insert(argv.begin(), L"mitm"); + PySys_SetArgv(argv.size(), (wchar_t**)argv.data()); + + { + auto file = fopen(path, "r"); + if (!file) + throw std::system_error(errno, std::generic_category(), "failed to open python file"); + + PyRun_SimpleFile(file, path); + fclose(file); + } + + if (Py_FinalizeEx() < 0) { + std::println("finalization failure"); + return TestResult::Failure; + } + + return TestResult::Success; + +exception: + PyConfig_Clear(&config); + Py_ExitStatusException(status); + + return TestResult::Failure; +} + +inline TestResult run_mitm( + std::string testcase, + std::string mitm_ip, + uint16_t mitm_port, + std::string remote_ip, + uint16_t remote_port, + std::vector extra_args = {}) +{ + // we build the args for the user to make calling the function more convenient + std::vector args { + testcase, mitm_ip, std::to_string(mitm_port), remote_ip, std::to_string(remote_port)}; + + for (auto arg: extra_args) + args.push_back(arg); + + // we need to convert to wide strings to pass to Python + std::vector wide_args_owned {}; + + // the strings are ascii so we can just make them into wstrings + for (const auto& str: args) + wide_args_owned.emplace_back(str.begin(), str.end()); + + std::vector wide_args {}; + for (const auto& wstr: wide_args_owned) + wide_args.push_back(wstr.c_str()); + + return run_python("tests/cc_ymq/py_mitm/main.py", wide_args); +} diff --git a/tests/cc_ymq/py_mitm/__init__.py b/tests/cc_ymq/py_mitm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/cc_ymq/py_mitm/main.py b/tests/cc_ymq/py_mitm/main.py new file mode 100644 index 000000000..edaeba569 --- /dev/null +++ b/tests/cc_ymq/py_mitm/main.py @@ -0,0 +1,152 @@ +# flake8: noqa: E402 + +""" +This script provides a framework for running MITM test cases +""" + +import argparse +import os +import sys +import importlib +import signal +import subprocess +from tests.cc_ymq.py_mitm.types import MITMProtocol, TCPConnection +from scapy.all import IP, TCP, TunTapInterface # type: ignore + + +def echo_call(cmd: list[str]): + print(f"+ {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def create_tuntap_interface(iface_name: str, mitm_ip: str, remote_ip: str) -> TunTapInterface: + """ + Creates a TUNTAP interface and sets brings it up and adds ips using the `ip` program + + Args: + iface_name: The name of the TUNTAP interface, usually like `tun0`, `tun1`, etc. + mitm_ip: The desired ip address of the mitm. This is the ip that clients can use to connect to the mitm + remote_ip: The ip that routes to/from the tuntap interface. + packets sent to `mitm_ip` will appear to come from `remote_ip`,\ + and conversely the tuntap interface can connect/send packets + to `remote_ip`, making it a suitable ip for binding a server + + Returns: + The TUNTAP interface + """ + iface = TunTapInterface(iface_name, mode="tun") + + try: + echo_call(["sudo", "ip", "link", "set", iface_name, "up"]) + echo_call(["sudo", "ip", "addr", "add", remote_ip, "peer", mitm_ip, "dev", iface_name]) + print(f"[+] Interface {iface_name} up with IP {mitm_ip}") + except subprocess.CalledProcessError: + print("[!] Could not bring up interface. Run as root or set manually.") + raise + + return iface + + +def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: MITMProtocol): + """ + This function serves as a framework for man in the middle implementations + A client connects to the MITM, then the MITM connects to a remote server + The MITM sits inbetween the client and the server, manipulating the packets sent depending on the test case + This function: + 1. creates a TUNTAP interface and prepares it for MITM + 2. handles connecting clients and handling connection closes + 3. delegates additional logic to a pluggable callable, `mitm` + 4. returns when both connections have terminated (via ) + + Args: + pid: this is the pid of the test process, used for signaling readiness \ + we send SIGUSR1 to this process when the mitm is ready + mitm_ip: The desired ip address of the mitm server + mitm_port: The desired port of the mitm server. \ + This is the port used to connect to the server, but the client is free to connect on any port + remote_ip: The desired remote ip for the TUNTAP interface. This is the only ip address \ + reachable by the interface and is thus the src ip for clients, and the ip that the remote server \ + must be bound to + server_port: The port that the remote server is bound to + mitm: The core logic for a MITM test case. This callable may maintain its own state and is responsible \ + for sending packets over the TUNTAP interface (if it doesn't, nothing will happen) + """ + + tuntap = create_tuntap_interface("tun0", mitm_ip, remote_ip) + + # signal the caller that the tuntap interface has been created + if pid > 0: + os.kill(pid, signal.SIGUSR1) + + # these track information about our connections + # we already know what to expect for the server connection, we are the connector + client_conn = None + server_conn = TCPConnection(mitm_ip, mitm_port, remote_ip, server_port) + + # tracks the state of each connection + client_sent_fin_ack = False + client_closed = False + server_sent_fin_ack = False + server_closed = False + + while True: + pkt = tuntap.recv() + if not pkt.haslayer(IP) or not pkt.haslayer(TCP): + continue + ip = pkt[IP] + tcp = pkt[TCP] + + # for a received packet, the destination ip and port are our local ip and port + # and the source ip and port will be the remote ip and port + sender = TCPConnection(pkt.dst, pkt.dport, pkt.src, pkt.sport) + + if sender == client_conn: + print(f"-> [{tcp.flags}]{(': ' + str(bytes(tcp.payload))) if tcp.payload else ''}") + elif sender == server_conn: + print(f"<- [{tcp.flags}]{(': ' + str(bytes(tcp.payload))) if tcp.payload else ''}") + + if tcp.flags == "S": # SYN from client + print("-> [S]") + print(f"[*] New connection from {ip.src}:{tcp.sport} to {ip.dst}:{tcp.dport}") + client_conn = sender + + if tcp.flags == "SA": # SYN-ACK from server + if sender == server_conn: + print(f"[*] Connection to server established: {ip.src}:{tcp.sport} to {ip.dst}:{tcp.dport}") + + if tcp.flags == "FA": # FIN-ACK + if sender == client_conn: + client_sent_fin_ack = True + if sender == server_conn: + server_sent_fin_ack = True + + if tcp.flags == "A": # ACK + if sender == client_conn and server_sent_fin_ack: + server_closed = True + if sender == server_conn and client_sent_fin_ack: + client_closed = True + + mitm.proxy(tuntap, pkt, sender, client_conn, server_conn) + + if client_closed and server_closed: + print("[*] Both connections closed") + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Man in the middle test framework") + parser.add_argument("pid", type=int, help="The pid of the test process, used for signaling") + parser.add_argument("testcase", type=str, help="The MITM test case module name") + parser.add_argument("mitm_ip", type=str, help="The desired ip address of the mitm server") + parser.add_argument("mitm_port", type=int, help="The desired port of the mitm server") + parser.add_argument("remote_ip", type=str, help="The desired remote ip for the TUNTAP interface") + parser.add_argument("server_port", type=int, help="The port that the remote server is bound to") + + args, unknown = parser.parse_known_args() + + # add the script's directory to path + sys.path.append(os.path.dirname(os.path.realpath(__file__))) + + # load the module dynamically + module = importlib.import_module(args.testcase) + main(args.pid, args.mitm_ip, args.mitm_port, args.remote_ip, args.server_port, module.MITM(*unknown)) diff --git a/tests/cc_ymq/py_mitm/passthrough.py b/tests/cc_ymq/py_mitm/passthrough.py new file mode 100644 index 000000000..20d8a9069 --- /dev/null +++ b/tests/cc_ymq/py_mitm/passthrough.py @@ -0,0 +1,23 @@ +""" +This MITM acts as a transparent passthrough, it simply forwards packets as they are, +minus necessary header changes to retransmit +This MITM should have no effect on the client and server, +and they should behave as if the MITM is not present +""" + +from tests.cc_ymq.py_mitm.types import MITMProtocol, TunTapInterface, IP, TCPConnection + + +class MITM(MITMProtocol): + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: TCPConnection | None, + server_conn: TCPConnection, + ) -> None: + if sender == client_conn: + tuntap.send(server_conn.rewrite(pkt)) + elif sender == server_conn: + tuntap.send(client_conn.rewrite(pkt)) diff --git a/tests/cc_ymq/py_mitm/randomly_drop_packets.py b/tests/cc_ymq/py_mitm/randomly_drop_packets.py new file mode 100644 index 000000000..a197ac3c8 --- /dev/null +++ b/tests/cc_ymq/py_mitm/randomly_drop_packets.py @@ -0,0 +1,28 @@ +""" +This MITM drops a % of packets +""" + +import random +from tests.cc_ymq.py_mitm.types import MITMProtocol, TunTapInterface, IP, TCPConnection + + +class MITM(MITMProtocol): + def __init__(self, drop_pcent: str): + self.drop_pcent = float(drop_pcent) + + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: TCPConnection | None, + server_conn: TCPConnection, + ) -> None: + if random.random() < self.drop_pcent: + print("[!] Dropping packet") + return + + if sender == client_conn: + tuntap.send(server_conn.rewrite(pkt)) + elif sender == server_conn: + tuntap.send(client_conn.rewrite(pkt)) diff --git a/tests/cc_ymq/py_mitm/send_rst_to_client.py b/tests/cc_ymq/py_mitm/send_rst_to_client.py new file mode 100644 index 000000000..fc70355e5 --- /dev/null +++ b/tests/cc_ymq/py_mitm/send_rst_to_client.py @@ -0,0 +1,48 @@ +""" +This MITM inserts an unexpected TCP RST +""" + +from tests.cc_ymq.py_mitm.types import IP, TCP, MITMProtocol, TCPConnection, TunTapInterface + + +class MITM(MITMProtocol): + def __init__(self): + # count the number of psh-acks sent by the client + self.client_pshack_counter = 0 + + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: TCPConnection | None, + server_conn: TCPConnection, + ) -> None: + if sender == client_conn or client_conn is None: + if pkt[TCP].flags == "PA": + self.client_pshack_counter += 1 + + # on the second psh-ack, send a rst instead + if self.client_pshack_counter == 2: + rst_pkt = IP(src=client_conn.local_ip, dst=client_conn.remote_ip) / TCP( + sport=client_conn.local_port, dport=client_conn.remote_port, flags="R", seq=pkt[TCP].ack + ) + print(f"<- [{rst_pkt[TCP].flags}] (simulated)") + tuntap.send(rst_pkt) + return + + tuntap.send(server_conn.rewrite(pkt)) + elif sender == server_conn: + tuntap.send(client_conn.rewrite(pkt)) + + +# client -> mitm -> server +# server -> mitm -> client + +# client: 127.0.0.1:8080 +# mitm: 127.0.0.1:8081 +# server: 127.0.0.1:8081 + + +# client -> mitm == src = client.ip, sport = client.port ;; dst = mitm.ip, dport = mitm.port +# mitm -> server == src = mitm.ip, sport = mitm.port ;; dst = server.ip, dport = server.port diff --git a/tests/cc_ymq/py_mitm/types.py b/tests/cc_ymq/py_mitm/types.py new file mode 100644 index 000000000..4a22ee01a --- /dev/null +++ b/tests/cc_ymq/py_mitm/types.py @@ -0,0 +1,54 @@ +""" +This is the common code for implementing man in the middle in Python +""" + +import dataclasses +from typing import Protocol +from scapy.all import TunTapInterface, IP, TCP # type: ignore + + +@dataclasses.dataclass +class TCPConnection: + """ + Represents a TCP connection over the TUNTAP interface + local_ip and local_port are the mitm's ip and port, and + remote_ip and remote_port are the port for the remote peer + """ + + local_ip: str + local_port: int + remote_ip: str + remote_port: int + + def rewrite(self, pkt: IP, ack: int | None = None, data=None): + """ + Rewrite a TCP/IP packet as a packet originating + from (local_ip, local_port) and going to (remote_ip, remote_port) + This function is useful for taking a packet received from one connection, and redirecting it to another + + Args: + pkt: A scapy TCP/IP packet to rewrite + ack: An optional ack number to use instead of the one found in `pkt` + data: An optional payload to use instead of the one found int `pkt` + + Returns: + The rewritten packet, suitable for sending over TUNTAP + """ + tcp = pkt[TCP] + + return ( + IP(src=self.local_ip, dst=self.remote_ip) + / TCP(sport=self.local_port, dport=self.remote_port, flags=tcp.flags, seq=tcp.seq, ack=ack or tcp.ack) + / bytes(data or tcp.payload) + ) + + +class MITMProtocol(Protocol): + def proxy( + self, + tuntap: TunTapInterface, + pkt: IP, + sender: TCPConnection, + client_conn: TCPConnection | None, + server_conn: TCPConnection, + ) -> None: ... diff --git a/tests/cc_ymq/test_cc_ymq.cpp b/tests/cc_ymq/test_cc_ymq.cpp new file mode 100644 index 000000000..1e7872b5c --- /dev/null +++ b/tests/cc_ymq/test_cc_ymq.cpp @@ -0,0 +1,508 @@ +// this file contains the tests for the C++ interface of YMQ +// each test case is comprised of at least one client and one server, and possibly a middleman +// the clients and servers used in these tests are defined in the first part of this file +// +// the men in the middle (mitm) are implemented using Python and are found in py_mitm/ +// in that directory, `main.py` is the entrypoint and framework for all the mitm, +// and the individual mitm implementations are found in their respective files +// +// the test cases are at the bottom of this file, after the clients and servers +// the documentation for each case is found on the TEST() definition + +#include +#include + +#include +#include +#include +#include +#include + +#include "common.h" +#include "scaler/io/ymq/bytes.h" +#include "scaler/io/ymq/io_context.h" +#include "scaler/io/ymq/simple_interface.h" +#include "tests/cc_ymq/common.h" + +using namespace scaler::ymq; +using namespace std::chrono_literals; + +// ━━━━━━━━━━━━━━━━━━━ +// clients and servers +// ━━━━━━━━━━━━━━━━━━━ + +TestResult basic_server_ymq(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "yi er san si wu liu"); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult basic_client_ymq(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); + auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("yi er san si wu liu")}); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult basic_server_raw(std::string host, uint16_t port) +{ + TcpSocket socket; + + socket.bind(host.c_str(), port); + socket.listen(); + auto [client, _] = socket.accept(); + client.write_message("server"); + auto client_identity = client.read_message(); + RETURN_FAILURE_IF_FALSE(client_identity == "client"); + auto msg = client.read_message(); + RETURN_FAILURE_IF_FALSE(msg == "yi er san si wu liu"); + + return TestResult::Success; +} + +TestResult basic_client_raw(int delay, std::string host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host.c_str(), port); + socket.write_message("client"); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + socket.write_message("yi er san si wu liu"); + + if (delay) + std::this_thread::sleep_for(std::chrono::seconds(delay)); + + return TestResult::Success; +} + +TestResult server_receives_big_message(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.len() == 500'000'000); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_sends_big_message(int delay, std::string host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host.c_str(), port); + socket.write_message("client"); + auto remote_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(remote_identity == "server"); + std::string msg(500'000'000, '.'); + socket.write_message(msg); + + if (delay) + std::this_thread::sleep_for(std::chrono::seconds(delay)); + + return TestResult::Success; +} + +TestResult reconnect_server_main(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "hello!!"); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult reconnect_client_main(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); + auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("hello!!")}); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_simulated_slow_network(const char* host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host, port); + socket.write_message("client"); + auto remote_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(remote_identity == "server"); + + std::string message = "yi er san si wu liu"; + uint64_t header = message.length(); + + socket.write_all((char*)&header, 4); + std::this_thread::sleep_for(5s); + socket.write_all((char*)&header + 4, 4); + std::this_thread::sleep_for(3s); + socket.write_all(message.data(), header / 2); + std::this_thread::sleep_for(5s); + socket.write_all(message.data() + header / 2, header - header / 2); + std::this_thread::sleep_for(3s); + + return TestResult::Success; +} + +TestResult client_sends_incomplete_identity(const char* host, uint16_t port) +{ + // open a socket, write an incomplete identity and exit + { + TcpSocket socket; + + socket.connect(host, port); + + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + + // write incomplete identity and exit + std::string identity = "client"; + uint64_t header = identity.length(); + socket.write_all((char*)&header, 8); + socket.write_all(identity.data(), identity.length() - 2); + std::this_thread::sleep_for(3s); + } + + // connect again and try to send a message + { + TcpSocket socket; + socket.connect(host, port); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + socket.write_message("client"); + socket.write_message("yi er san si wu liu"); + std::this_thread::sleep_for(3s); + } + + return TestResult::Success; +} + +TestResult server_receives_huge_header(const char* host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + auto result = syncRecvMessage(socket); + + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "yi er san si wu liu"); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_sends_huge_header(const char* host, uint16_t port) +{ + TcpSocket socket; + + socket.connect(host, port); + socket.write_message("client"); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + + // write the huge header + uint64_t header = std::numeric_limits::max(); + socket.write_all((char*)&header, 8); + + // TODO: this sleep shouldn't be necessary + std::this_thread::sleep_for(3s); + + return TestResult::Success; +} + +TestResult server_receives_empty_messages(const char* host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); + + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == ""); + + auto result2 = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result2.has_value()); + RETURN_FAILURE_IF_FALSE(result2->payload.as_string() == ""); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +TestResult client_sends_empty_messages(std::string host, uint16_t port) +{ + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); + + auto error = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes()}); + RETURN_FAILURE_IF_FALSE(!error); + + auto error2 = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes("")}); + RETURN_FAILURE_IF_FALSE(!error2); + + context.removeIOSocket(socket); + + return TestResult::Success; +} + +// ━━━━━━━━━━━━━ +// test cases +// ━━━━━━━━━━━━━ + +// this is a 'basic' test which sends a single message from a client to a server +// in this variant, both the client and server are implemented using YMQ +// +// this case includes a _delay_ +// this is a thread sleep that happens after the client sends the message, to delay the close() of the socket +// at the moment, if this delay is missing, YMQ will not shut down correctly +TEST(CcYmqTestSuite, TestBasicYMQClientYMQServer) +{ + auto host = "localhost"; + auto port = 2889; + + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = + test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_ymq(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +// same as above, except YMQs protocol is directly implemented on top of a TCP socket +TEST(CcYmqTestSuite, TestBasicRawClientYMQServer) +{ + auto host = "localhost"; + auto port = 2890; + + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = + test(10, {[=] { return basic_client_raw(5, host, port); }, [=] { return basic_server_ymq(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +TEST(CcYmqTestSuite, TestBasicRawClientRawServer) +{ + auto host = "localhost"; + auto port = 2891; + + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = + test(10, {[=] { return basic_client_raw(0, host, port); }, [=] { return basic_server_raw(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: this should pass +// this is the same as above, except that it has no delay before calling close() on the socket +// this test hangs +TEST(CcYmqTestSuite, DISABLED_TestBasicRawClientRawServerNoDelay) +{ + auto host = "localhost"; + auto port = 2892; + + auto result = + test(10, {[=] { return basic_client_raw(0, host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +TEST(CcYmqTestSuite, TestBasicDelayYMQClientRawServer) +{ + auto host = "localhost"; + auto port = 2893; + + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = + test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_raw(host, port); }}); + + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} + +// in this test case, the client sends a large message to the server +// YMQ should be able to handle this without issue +TEST(CcYmqTestSuite, TestClientSendBigMessageToServer) +{ + auto host = "localhost"; + auto port = 2894; + + auto result = test( + 10, + {[=] { return client_sends_big_message(5, host, port); }, + [=] { return server_receives_big_message(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// this is the no-op/passthrough man in the middle test +// for this test case we use YMQ on both the client side and the server side +// the client connects to the mitm, and the mitm connects to the server +// when the mitm receives packets from the client, it forwards it to the server without changing it +// and similarly when it receives packets from the server, it forwards them to the client +// +// the mitm is implemented in Python. we pass the name of the test case, which corresponds to the Python filename, +// and a list of arguments, which are: mitm ip, mitm port, remote ip, remote port +// this defines the address of the mitm, and the addresses that can connect to it +// for more, see the python mitm files +TEST(CcYmqTestSuite, TestMitmPassthrough) +{ + auto mitm_ip = "192.0.2.4"; + auto mitm_port = 2323; + auto remote_ip = "192.0.2.3"; + auto remote_port = 23571; + + // the Python program must be the first and only the first function passed to test() + // we must also pass `true` as the third argument to ensure that Python is fully started + // before beginning the test + auto result = test( + 20, + {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return basic_client_ymq(mitm_ip, mitm_port); }, + [=] { return basic_server_ymq(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); +} + +// this test uses the mitm to test the reconnect logic of YMQ by sending RST packets +// this test is disabled until fixes arrive in the core +TEST(CcYmqTestSuite, DISABLED_TestMitmReconnect) +{ + auto mitm_ip = "192.0.2.4"; + auto mitm_port = 2525; + auto remote_ip = "192.0.2.3"; + auto remote_port = 23575; + + auto result = test( + 10, + {[=] { return run_mitm("send_rst_to_client", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return reconnect_client_main(mitm_ip, mitm_port); }, + [=] { return reconnect_server_main(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: Make this more reliable, and re-enable it +// in this test, the mitm drops a random % of packets arriving from the client and server +TEST(CcYmqTestSuite, DISABLED_TestMitmRandomlyDropPackets) +{ + auto mitm_ip = "192.0.2.4"; + auto mitm_port = 2828; + auto remote_ip = "192.0.2.3"; + auto remote_port = 23591; + + auto result = test( + 60, + {[=] { return run_mitm("randomly_drop_packets", mitm_ip, mitm_port, remote_ip, remote_port, {"0.3"}); }, + [=] { return basic_client_ymq(mitm_ip, mitm_port); }, + [=] { return basic_server_ymq(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); +} + +// in this test the client is sending a message to the server +// but we simulate a slow network connection by sending the message in segmented chunks +TEST(CcYmqTestSuite, TestSlowNetwork) +{ + auto host = "localhost"; + auto port = 2895; + + auto result = test( + 20, {[=] { return client_simulated_slow_network(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: figure out why this test fails in ci sometimes, and re-enable +// +// in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects +// then a new client connection is established, and this one sends a complete identity and message +// YMQ should be able to recover from a poorly-behaved client like this +TEST(CcYmqTestSuite, DISABLED_TestClientSendIncompleteIdentity) +{ + auto host = "localhost"; + auto port = 2896; + + auto result = test( + 20, + {[=] { return client_sends_incomplete_identity(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// TODO: this should pass +// in this test, the client sends an unrealistically-large header +// it is important that YMQ checks the header size before allocating memory +// both for resilence against attacks and to guard against errors +// +// at the moment YMQ does not perform this check and throws std::bad_alloc +// this test can be re-enabled after this is fixed +TEST(CcYmqTestSuite, DISABLED_TestClientSendHugeHeader) +{ + auto host = "localhost"; + auto port = 2897; + + auto result = test( + 20, + {[=] { return client_sends_huge_header(host, port); }, + [=] { return server_receives_huge_header(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} + +// in this test, the client sends empty messages to the server +// there are in effect two kinds of empty messages: Bytes() and Bytes("") +// in the former case, the bytes contains a nullptr +// in the latter case, the bytes contains a zero-length allocation +// it's important that the behaviour of YMQ is known for both of these cases +TEST(CcYmqTestSuite, TestClientSendEmptyMessage) +{ + auto host = "localhost"; + auto port = 2898; + + auto result = test( + 20, + {[=] { return client_sends_empty_messages(host, port); }, + [=] { return server_receives_empty_messages(host, port); }}); + EXPECT_EQ(result, TestResult::Success); +} diff --git a/tests/object_storage/test_object_storage_server.cpp b/tests/object_storage/test_object_storage_server.cpp index 0d08a8f48..3392197c6 100644 --- a/tests/object_storage/test_object_storage_server.cpp +++ b/tests/object_storage/test_object_storage_server.cpp @@ -1,6 +1,5 @@ #include -#include #include #include #include @@ -8,7 +7,6 @@ #include "scaler/io/ymq/io_context.h" #include "scaler/io/ymq/io_socket.h" -#include "scaler/io/ymq/logging.h" #include "scaler/io/ymq/simple_interface.h" #include "scaler/object_storage/object_storage_server.h" @@ -39,8 +37,8 @@ class ObjectStorageClient { void writeYMQMessage(Message message) { - auto res = syncSendMessage(ioSocket, std::move(message)); - ASSERT_TRUE(res.has_value()); + auto error = syncSendMessage(ioSocket, std::move(message)); + ASSERT_TRUE(!error); } auto readYMQMessage() { return syncRecvMessage(ioSocket); } @@ -63,17 +61,17 @@ class ObjectStorageClient { void readResponse(ObjectResponseHeader& header, std::optional& payload) { std::array buf {}; - auto [message, error] = syncRecvMessage(ioSocket); - ASSERT_EQ(error._errorCode, Error::ErrorCode::Uninit); + auto result = syncRecvMessage(ioSocket); + ASSERT_TRUE(result.has_value()); - memcpy(buf.begin(), message.payload.data(), CAPNP_HEADER_SIZE); - ASSERT_EQ(message.payload.size(), CAPNP_HEADER_SIZE); + memcpy(buf.begin(), result->payload.data(), CAPNP_HEADER_SIZE); + ASSERT_EQ(result->payload.size(), CAPNP_HEADER_SIZE); header = ObjectResponseHeader::fromBuffer(buf); if (header.payloadLength > 0) { - auto [message2, error2] = syncRecvMessage(ioSocket); - ASSERT_EQ(error2._errorCode, Error::ErrorCode::Uninit); - payload.emplace(message2.payload); + auto result2 = syncRecvMessage(ioSocket); + ASSERT_TRUE(result2.has_value()); + payload.emplace(result2->payload); } else { payload.reset(); } @@ -529,7 +527,10 @@ TEST_F(ObjectStorageServerTest, TestClientDisconnect) } } -TEST_F(ObjectStorageServerTest, TestMalformedHeader) +// TODO: why does this not pass? +// the message connection tcp is removed from the remote socket's list +// but the object is never destructued, and so the connection is never closed +TEST_F(ObjectStorageServerTest, DISABLED_TestMalformedHeader) { ObjectResponseHeader responseHeader; std::optional responsePayload; @@ -546,8 +547,9 @@ TEST_F(ObjectStorageServerTest, TestMalformedHeader) client->writeYMQMessage(std::move(message)); // Server should disconnect before or while we are reading the response - auto [msg, err] = client->readYMQMessage(); - EXPECT_EQ(err._errorCode, Error::ErrorCode::ConnectorSocketClosedByRemoteEnd); + auto result = client->readYMQMessage(); + EXPECT_TRUE(!result); + EXPECT_EQ(result.error()._errorCode, Error::ErrorCode::ConnectorSocketClosedByRemoteEnd); } // Server must still answers to requests from other clients diff --git a/tests/pymod_ymq/__init__.py b/tests/pymod_ymq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/pymod_ymq/test_pymod_ymq.py b/tests/pymod_ymq/test_pymod_ymq.py new file mode 100644 index 000000000..ee96914e6 --- /dev/null +++ b/tests/pymod_ymq/test_pymod_ymq.py @@ -0,0 +1,150 @@ +import multiprocessing.connection +import unittest +from scaler.io.ymq import ymq +import asyncio +import multiprocessing + + +class TestPymodYMQ(unittest.IsolatedAsyncioTestCase): + async def test_basic(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + self.assertEqual(binder.identity, "binder") + self.assertEqual(binder.socket_type, ymq.IOSocketType.Binder) + + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + self.assertEqual(connector.identity, "connector") + self.assertEqual(connector.socket_type, ymq.IOSocketType.Connector) + + await binder.bind("tcp://127.0.0.1:35791") + await connector.connect("tcp://127.0.0.1:35791") + + await connector.send(ymq.Message(address=None, payload=b"payload")) + msg = await binder.recv() + + assert msg.address is not None + self.assertEqual(msg.address.data, b"connector") + self.assertEqual(msg.payload.data, b"payload") + + @unittest.skip("this test currently hangs, see comment in the code") + async def test_no_address(self): + # this test requires special care because it hangs and doesn't shut down the worker threads properly + # we use a subprocess to shield us from any effects + pipe_parent, pipe_child = multiprocessing.Pipe(duplex=False) + + def test(pipe: multiprocessing.connection.Connection) -> None: + async def main(): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + + await binder.bind("tcp://127.0.0.1:35791") + await connector.connect("tcp://127.0.0.1:35791") + + try: + # TODO: change to `asyncio.timeout()` in python >3.10 + await asyncio.wait_for(binder.send(ymq.Message(address=None, payload=b"payload")), 30) + + # TODO: solve the hang and write the rest of the test + pipe.send(True) + except asyncio.TimeoutError: + pipe.send(False) + + asyncio.run(main()) + + p = multiprocessing.Process(target=test, args=(pipe_child,)) + p.start() + result = pipe_parent.recv() + p.join(5) + if p.exitcode is None: + p.kill() + + if not result: + self.fail() + + async def test_routing(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + connector1 = await ctx.createIOSocket("connector1", ymq.IOSocketType.Connector) + connector2 = await ctx.createIOSocket("connector2", ymq.IOSocketType.Connector) + + await binder.bind("tcp://127.0.0.1:35791") + await connector1.connect("tcp://127.0.0.1:35791") + await connector2.connect("tcp://127.0.0.1:35791") + + await binder.send(ymq.Message(b"connector2", b"2")) + await binder.send(ymq.Message(b"connector1", b"1")) + + msg1 = await connector1.recv() + self.assertEqual(msg1.payload.data, b"1") + + msg2 = await connector2.recv() + self.assertEqual(msg2.payload.data, b"2") + + async def test_pingpong(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + + await binder.bind("tcp://127.0.0.1:35791") + await connector.connect("tcp://127.0.0.1:35791") + + async def binder_routine(binder: ymq.IOSocket, limit: int) -> bool: + i = 0 + while i < limit: + await binder.send(ymq.Message(address=b"connector", payload=f"{i}".encode())) + msg = await binder.recv() + assert msg.payload.data is not None + + recv_i = int(msg.payload.data.decode()) + if recv_i - i > 1: + return False + i = recv_i + 1 + return True + + async def connector_routine(connector: ymq.IOSocket, limit: int) -> bool: + i = 0 + while True: + msg = await connector.recv() + assert msg.payload.data is not None + recv_i = int(msg.payload.data.decode()) + if recv_i - i > 1: + return False + i = recv_i + 1 + await connector.send(ymq.Message(address=None, payload=f"{i}".encode())) + + # when the connector sends `limit - 1`, we're done + if i >= limit - 1: + break + return True + + binder_success, connector_success = await asyncio.gather( + binder_routine(binder, 100), connector_routine(connector, 100) + ) + + if not binder_success: + self.fail("binder failed") + + if not connector_success: + self.fail("connector failed") + + async def test_big_message(self): + ctx = ymq.IOContext() + binder = await ctx.createIOSocket("binder", ymq.IOSocketType.Binder) + self.assertEqual(binder.identity, "binder") + self.assertEqual(binder.socket_type, ymq.IOSocketType.Binder) + + connector = await ctx.createIOSocket("connector", ymq.IOSocketType.Connector) + self.assertEqual(connector.identity, "connector") + self.assertEqual(connector.socket_type, ymq.IOSocketType.Connector) + + await binder.bind("tcp://127.0.0.1:35791") + await connector.connect("tcp://127.0.0.1:35791") + + for _ in range(10): + await connector.send(ymq.Message(address=None, payload=b"." * 500_000_000)) + msg = await binder.recv() + + assert msg.address is not None + self.assertEqual(msg.address.data, b"connector") + self.assertEqual(msg.payload.data, b"." * 500_000_000) diff --git a/tests/pymod_ymq/test_types.py b/tests/pymod_ymq/test_types.py new file mode 100644 index 000000000..e461856e0 --- /dev/null +++ b/tests/pymod_ymq/test_types.py @@ -0,0 +1,90 @@ +import unittest +from enum import IntEnum +from scaler.io.ymq import ymq +import array + + +class TestTypes(unittest.TestCase): + def test_exception(self): + # type checkers misidentify this as "unnecessary" due to the type hints file + self.assertTrue(issubclass(ymq.YMQException, Exception)) # type: ignore + + exc = ymq.YMQException(ymq.ErrorCode.CoreBug, "oh no") + self.assertEqual(exc.args, (ymq.ErrorCode.CoreBug, "oh no")) + self.assertEqual(exc.code, ymq.ErrorCode.CoreBug) + self.assertEqual(exc.message, "oh no") + + def test_interrupted_exception(self): + self.assertTrue(issubclass(ymq.YMQInterruptedException, Exception)) # type: ignore + + exc = ymq.YMQInterruptedException() + self.assertEqual(exc.args, tuple()) + + def test_error_code(self): + self.assertTrue(issubclass(ymq.ErrorCode, IntEnum)) # type: ignore + self.assertEqual( + ymq.ErrorCode.ConfigurationError.explanation(), + "An error generated by system call that's likely due to mis-configuration", + ) + + def test_bytes(self): + b = ymq.Bytes(b"data") + self.assertEqual(b.len, len(b)) + self.assertEqual(b.len, 4) + self.assertEqual(b.data, b"data") + + # would raise an exception if ymq.Bytes didn't support the buffer interface + m = memoryview(b) + self.assertTrue(m.obj is b) + self.assertEqual(m.tobytes(), b"data") + + b = ymq.Bytes() + self.assertEqual(b.len, 0) + self.assertTrue(b.data is None) + + b = ymq.Bytes(b"") + self.assertEqual(b.len, 0) + self.assertEqual(b.data, b"") + + b = ymq.Bytes(array.array("B", [115, 99, 97, 108, 101, 114])) + assert b.len == 6 + assert b.data == b"scaler" + + def test_message(self): + m = ymq.Message(b"address", b"payload") + assert m.address is not None + self.assertEqual(m.address.data, b"address") + self.assertEqual(m.payload.data, b"payload") + + m = ymq.Message(address=None, payload=ymq.Bytes(b"scaler")) + self.assertTrue(m.address is None) + self.assertEqual(m.payload.data, b"scaler") + + m = ymq.Message(b"address", payload=b"payload") + assert m.address is not None + self.assertEqual(m.address.data, b"address") + self.assertEqual(m.payload.data, b"payload") + + def test_io_context(self): + ctx = ymq.IOContext() + self.assertEqual(ctx.num_threads, 1) + + ctx = ymq.IOContext(2) + self.assertEqual(ctx.num_threads, 2) + + ctx = ymq.IOContext(num_threads=3) + self.assertEqual(ctx.num_threads, 3) + + def test_io_socket(self): + # check that we can't create io socket instances directly + self.assertRaises(TypeError, lambda: ymq.IOSocket()) + + def test_io_socket_type(self): + self.assertTrue(issubclass(ymq.IOSocketType, IntEnum)) # type: ignore + + def test_bad_socket_type(self): + ctx = ymq.IOContext() + + # TODO: should the core reject this? + socket = ctx.createIOSocket_sync("identity", ymq.IOSocketType.Uninit) + self.assertEqual(socket.socket_type, ymq.IOSocketType.Uninit) diff --git a/tests/test_graph.py b/tests/test_graph.py index ed2da0295..c6971ecf5 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -208,13 +208,7 @@ def test_graph_capabilities(self): base_cluster = self.combo._cluster with Client(self.address) as client: - graph = { - "a": 1.3, - "b": 2.6, - "c": (round, "a"), - "d": (round, "b"), - "e": (add, "c", "d") - } + graph = {"a": 1.3, "b": 2.6, "c": (round, "a"), "d": (round, "b"), "e": (add, "c", "d")} future = client.get(graph, keys=["e"], capabilities={"gpu": 1}, block=False)["e"] From 627dddbc0f6373f1093de3e6601f4e0b753f4f6e Mon Sep 17 00:00:00 2001 From: gxu Date: Wed, 24 Sep 2025 10:02:20 +0800 Subject: [PATCH 2/3] YMQ Bug Fix Signed-off-by: gxu --- scaler/io/ymq/configuration.h | 3 +- scaler/io/ymq/message_connection_tcp.cpp | 39 +++++++++++++------ scaler/io/ymq/message_connection_tcp.h | 3 ++ tests/cc_ymq/test_cc_ymq.cpp | 24 ++++++++---- .../test_object_storage_server.cpp | 2 +- 5 files changed, 50 insertions(+), 21 deletions(-) diff --git a/scaler/io/ymq/configuration.h b/scaler/io/ymq/configuration.h index ad7a2de49..45e5cee86 100644 --- a/scaler/io/ymq/configuration.h +++ b/scaler/io/ymq/configuration.h @@ -34,7 +34,8 @@ template using MoveOnlyFunction = std::function; #endif -constexpr const uint64_t IOCP_SOCKET_CLOSED = 4; +constexpr const uint64_t IOCP_SOCKET_CLOSED = 4; +constexpr const uint64_t LARGEST_PAYLOAD_SIZE = 6000'000'000'000; // 6TB struct Configuration { #ifdef __linux__ diff --git a/scaler/io/ymq/message_connection_tcp.cpp b/scaler/io/ymq/message_connection_tcp.cpp index 21efe0d8a..2af348a3c 100644 --- a/scaler/io/ymq/message_connection_tcp.cpp +++ b/scaler/io/ymq/message_connection_tcp.cpp @@ -3,6 +3,8 @@ #include +#include "scaler/io/ymq/configuration.h" + #ifdef __linux__ #include #endif // __linux__ @@ -146,6 +148,9 @@ std::expected MessageConnectionTCP::tryRead readTo = (char*)&message._header + message._cursor; remainingSize = HEADER_SIZE - message._cursor; } else if (message._cursor == HEADER_SIZE) { + if (message._header >= LARGEST_PAYLOAD_SIZE) { + return std::unexpected {IOError::MessageTooLarge}; + } message._payload = Bytes::alloc(message._header); readTo = (char*)message._payload.data(); remainingSize = message._payload.len(); @@ -260,7 +265,6 @@ void MessageConnectionTCP::updateReadOperation() Bytes address(_remoteIOSocketIdentity->data(), _remoteIOSocketIdentity->size()); Bytes payload(std::move(_receivedReadOperations.front()._payload)); _receivedReadOperations.pop(); - auto recvMessageCallback = std::move(_pendingRecvMessageCallbacks->front()); _pendingRecvMessageCallbacks->pop(); @@ -272,6 +276,18 @@ void MessageConnectionTCP::updateReadOperation() } } +void MessageConnectionTCP::setRemoteIdentity() noexcept +{ + if (!_remoteIOSocketIdentity && + (_receivedReadOperations.size() || isCompleteMessage(_receivedReadOperations.front()))) { + auto id = std::move(_receivedReadOperations.front()); + _remoteIOSocketIdentity.emplace((char*)id._payload.data(), id._payload.len()); + _receivedReadOperations.pop(); + auto sock = this->_eventLoopThread->_identityToIOSocket[_localIOSocketIdentity]; + sock->onConnectionIdentityReceived(this); + } +} + void MessageConnectionTCP::onRead() { if (_connFd == 0) { @@ -279,11 +295,19 @@ void MessageConnectionTCP::onRead() } auto maybeCloseConn = [this](IOError err) -> std::expected { + setRemoteIdentity(); + + if (_remoteIOSocketIdentity) { + updateReadOperation(); + } + switch (err) { case IOError::Drained: return {}; - case IOError::Disconnected: _disconnect = true; break; case IOError::Aborted: _disconnect = false; break; + case IOError::Disconnected: _disconnect = true; break; + case IOError::MessageTooLarge: _disconnect = false; break; } + onClose(); return std::unexpected {err}; }; @@ -293,16 +317,7 @@ void MessageConnectionTCP::onRead() auto _ = tryReadOneMessage() .or_else(maybeCloseConn) // .and_then([this]() -> std::expected { - if (_receivedReadOperations.empty() || - !isCompleteMessage(_receivedReadOperations.front())) { - return {}; - } - - auto id = std::move(_receivedReadOperations.front()); - _remoteIOSocketIdentity.emplace((char*)id._payload.data(), id._payload.len()); - _receivedReadOperations.pop(); - auto sock = this->_eventLoopThread->_identityToIOSocket[_localIOSocketIdentity]; - sock->onConnectionIdentityReceived(this); + setRemoteIdentity(); return {}; }); return _remoteIOSocketIdentity; diff --git a/scaler/io/ymq/message_connection_tcp.h b/scaler/io/ymq/message_connection_tcp.h index 073693cb8..78390874e 100644 --- a/scaler/io/ymq/message_connection_tcp.h +++ b/scaler/io/ymq/message_connection_tcp.h @@ -54,6 +54,7 @@ class MessageConnectionTCP: public MessageConnection { Drained, Aborted, Disconnected, + MessageTooLarge, }; void onRead(); @@ -71,6 +72,8 @@ class MessageConnectionTCP: public MessageConnection { void updateWriteOperations(size_t n); void updateReadOperation(); + void setRemoteIdentity() noexcept; + std::unique_ptr _eventManager; int _connFd; sockaddr _localAddr; diff --git a/tests/cc_ymq/test_cc_ymq.cpp b/tests/cc_ymq/test_cc_ymq.cpp index 1e7872b5c..c4a51d761 100644 --- a/tests/cc_ymq/test_cc_ymq.cpp +++ b/tests/cc_ymq/test_cc_ymq.cpp @@ -136,6 +136,9 @@ TestResult reconnect_server_main(std::string host, uint16_t port) RETURN_FAILURE_IF_FALSE(result.has_value()); RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "hello!!"); + auto result2 = syncSendMessage(socket, {.address = Bytes("client"), .payload = Bytes("goodbye!!")}); + assert(result2); + context.removeIOSocket(socket); return TestResult::Success; @@ -149,6 +152,11 @@ TestResult reconnect_client_main(std::string host, uint16_t port) syncConnectSocket(socket, format_address(host, port)); auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("hello!!")}); + printf("BEFORE CLI\n"); + auto result2 = syncRecvMessage(socket); + printf("AFTER CLI\n"); + assert(result2); + context.removeIOSocket(socket); return TestResult::Success; @@ -219,8 +227,7 @@ TestResult server_receives_huge_header(const char* host, uint16_t port) syncBindSocket(socket, format_address(host, port)); auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "yi er san si wu liu"); + // RETURN_FAILURE_IF_FALSE(result.error()._errorCode == scaler::ymq::Error::ErrorCode::MessageTooLarge); context.removeIOSocket(socket); @@ -240,9 +247,6 @@ TestResult client_sends_huge_header(const char* host, uint16_t port) uint64_t header = std::numeric_limits::max(); socket.write_all((char*)&header, 8); - // TODO: this sleep shouldn't be necessary - std::this_thread::sleep_for(3s); - return TestResult::Success; } @@ -340,7 +344,7 @@ TEST(CcYmqTestSuite, TestBasicRawClientRawServer) // TODO: this should pass // this is the same as above, except that it has no delay before calling close() on the socket // this test hangs -TEST(CcYmqTestSuite, DISABLED_TestBasicRawClientRawServerNoDelay) +TEST(CcYmqTestSuite, TestBasicRawClientRawServerNoDelay) { auto host = "localhost"; auto port = 2892; @@ -407,6 +411,9 @@ TEST(CcYmqTestSuite, TestMitmPassthrough) EXPECT_EQ(result, TestResult::Success); } +// TODO: This test should be redesigned so that the ACK is send from the remote end +// before the Man in the Middle sends RST. Please also make sure that the client does +// not exits before the Man in the Middle sends RST. // this test uses the mitm to test the reconnect logic of YMQ by sending RST packets // this test is disabled until fixes arrive in the core TEST(CcYmqTestSuite, DISABLED_TestMitmReconnect) @@ -460,7 +467,7 @@ TEST(CcYmqTestSuite, TestSlowNetwork) // in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects // then a new client connection is established, and this one sends a complete identity and message // YMQ should be able to recover from a poorly-behaved client like this -TEST(CcYmqTestSuite, DISABLED_TestClientSendIncompleteIdentity) +TEST(CcYmqTestSuite, TestClientSendIncompleteIdentity) { auto host = "localhost"; auto port = 2896; @@ -478,6 +485,9 @@ TEST(CcYmqTestSuite, DISABLED_TestClientSendIncompleteIdentity) // // at the moment YMQ does not perform this check and throws std::bad_alloc // this test can be re-enabled after this is fixed +// TODO: maglinoquency should redesign the test so it does not halt. When the core +// receives a big header, it closes the connection on user's behalf. +// This however, makes the server waiting forever on a recvMessage call. TEST(CcYmqTestSuite, DISABLED_TestClientSendHugeHeader) { auto host = "localhost"; diff --git a/tests/object_storage/test_object_storage_server.cpp b/tests/object_storage/test_object_storage_server.cpp index 3392197c6..cb38d0674 100644 --- a/tests/object_storage/test_object_storage_server.cpp +++ b/tests/object_storage/test_object_storage_server.cpp @@ -530,7 +530,7 @@ TEST_F(ObjectStorageServerTest, TestClientDisconnect) // TODO: why does this not pass? // the message connection tcp is removed from the remote socket's list // but the object is never destructued, and so the connection is never closed -TEST_F(ObjectStorageServerTest, DISABLED_TestMalformedHeader) +TEST_F(ObjectStorageServerTest, TestMalformedHeader) { ObjectResponseHeader responseHeader; std::optional responsePayload; From 7f0cb6f91a6276586ccf6cfe2278bc727d8779a8 Mon Sep 17 00:00:00 2001 From: gxu Date: Thu, 25 Sep 2025 01:12:22 +0800 Subject: [PATCH 3/3] re-enable mitm test (still failing) Signed-off-by: gxu --- tests/cc_ymq/test_cc_ymq.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cc_ymq/test_cc_ymq.cpp b/tests/cc_ymq/test_cc_ymq.cpp index c4a51d761..ebe8304d3 100644 --- a/tests/cc_ymq/test_cc_ymq.cpp +++ b/tests/cc_ymq/test_cc_ymq.cpp @@ -416,7 +416,7 @@ TEST(CcYmqTestSuite, TestMitmPassthrough) // not exits before the Man in the Middle sends RST. // this test uses the mitm to test the reconnect logic of YMQ by sending RST packets // this test is disabled until fixes arrive in the core -TEST(CcYmqTestSuite, DISABLED_TestMitmReconnect) +TEST(CcYmqTestSuite, TestMitmReconnect) { auto mitm_ip = "192.0.2.4"; auto mitm_port = 2525;