diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index dc802d940..575870c77 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -60,7 +60,7 @@ auto ProcessPortMappings(const WSLA_CONTAINER_OPTIONS& options, WSLAVirtualMachi { try { - vm.MapPort(e.Family, e.HostPort, e.VmPort, true); + vm.UnmapPort(e.Family, e.HostPort, e.VmPort); } CATCH_LOG(); } @@ -115,7 +115,7 @@ auto ProcessPortMappings(const WSLA_CONTAINER_OPTIONS& options, WSLAVirtualMachi // Map Windows <-> VM ports. for (auto& e : *mappedPorts) { - vm.MapPort(e.Family, e.HostPort, e.VmPort, false); + vm.MapPort(e.Family, e.HostPort, e.VmPort); e.MappedToHost = true; } @@ -202,7 +202,7 @@ WSLAContainerImpl::~WSLAContainerImpl() try { - m_parentVM->MapPort(e.Family, e.HostPort, e.VmPort, true); + m_parentVM->UnmapPort(e.Family, e.HostPort, e.VmPort); } CATCH_LOG(); diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index ef1e2318b..1ab42d5ec 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -624,13 +624,24 @@ try } CATCH_RETURN(); -HRESULT WSLASession::MapVmPort(int Family, short WindowsPort, short LinuxPort, BOOL Remove) +HRESULT WSLASession::MapVmPort(int Family, short WindowsPort, short LinuxPort) try { std::lock_guard lock{m_lock}; THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); - m_virtualMachine->MapPort(Family, WindowsPort, LinuxPort, Remove != FALSE); + m_virtualMachine->MapPort(Family, WindowsPort, LinuxPort); + return S_OK; +} +CATCH_RETURN(); + +HRESULT WSLASession::UnmapVmPort(int Family, short WindowsPort, short LinuxPort) +try +{ + std::lock_guard lock{m_lock}; + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); + + m_virtualMachine->UnmapPort(Family, WindowsPort, LinuxPort); return S_OK; } CATCH_RETURN(); diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 8cb5f5a54..815ca6c16 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -74,7 +74,8 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession // Testing. IFACEMETHOD(MountWindowsFolder)(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly) override; IFACEMETHOD(UnmountWindowsFolder)(_In_ LPCSTR LinuxPath) override; - IFACEMETHOD(MapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) override; + IFACEMETHOD(MapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) override; + IFACEMETHOD(UnmapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) override; void OnUserSessionTerminating(); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index df1829cf0..38a307807 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -1137,7 +1137,7 @@ void WSLAVirtualMachine::LaunchPortRelay() writePipe.release(); } -void WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) +void WSLAVirtualMachine::MapPortImpl(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ bool Remove) { std::lock_guard lock(m_portRelaylock); @@ -1160,6 +1160,16 @@ void WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ s THROW_IF_FAILED_MSG(result, "Failed to map port: WindowsPort=%d, LinuxPort=%d, Family=%d, Remove=%d", WindowsPort, LinuxPort, Family, Remove); } +void WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) +{ + MapPortImpl(Family, WindowsPort, LinuxPort, false); +} + +void WSLAVirtualMachine::UnmapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) +{ + MapPortImpl(Family, WindowsPort, LinuxPort, true); +} + HRESULT WSLAVirtualMachine::MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly) { return MountWindowsFolderImpl(WindowsPath, LinuxPath, ReadOnly ? WSLAMountFlagsReadOnly : WSLAMountFlagsNone); diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 8a2087f5f..f39760cf4 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -72,7 +72,8 @@ class WSLAVirtualMachine void Start(); void OnSessionTerminated(); - void MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove); + void MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort); + void UnmapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort); void Unmount(_In_ const char* Path); HRESULT MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly); @@ -107,6 +108,8 @@ class WSLAVirtualMachine } private: + void MapPortImpl(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ bool Remove); + static void Mount(wsl::shared::SocketChannel& Channel, LPCSTR Source, _In_ LPCSTR Target, _In_ LPCSTR Type, _In_ LPCSTR Options, _In_ ULONG Flags); void MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint); static void CALLBACK s_OnExit(_In_ HCS_EVENT* Event, _In_opt_ void* Context); diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 7722b43a7..9face7c18 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -324,7 +324,8 @@ interface IWSLASession : IUnknown // Used only for testing. TODO: Think about moving them to a dedicated testing-only interface. HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly); HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath); - HRESULT MapVmPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove); + HRESULT MapVmPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort); + HRESULT UnmapVmPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort); } struct WSLA_SESSION_INFORMATION diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 3ed0470e3..7d1687fe3 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -788,10 +788,10 @@ class WSLATests }; // Map port - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80)); // Validate that the same port can't be bound twice - VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80, false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); + VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); // Check simple case listen(80, "port80", false); @@ -805,34 +805,34 @@ class WSLATests expectContent(1234, AF_INET, ""); // Add a ipv6 binding - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80)); // Validate that ipv6 bindings work as well. listen(80, "port80ipv6", true); expectContent(1234, AF_INET6, "port80ipv6"); // Unmap the ipv4 port - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, true)); + VERIFY_SUCCEEDED(session->UnmapVmPort(AF_INET, 1234, 80)); // Verify that a proper error is returned if the mapping doesn't exist - VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + VERIFY_ARE_EQUAL(session->UnmapVmPort(AF_INET, 1234, 80), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); // Unmap the v6 port - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80, true)); + VERIFY_SUCCEEDED(session->UnmapVmPort(AF_INET6, 1234, 80)); // Map another port as v6 only - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81)); listen(81, "port81ipv6", true); expectContent(1235, AF_INET6, "port81ipv6"); expectNotBound(1235, AF_INET); - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81, true)); - VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET6, 1235, 81, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + VERIFY_SUCCEEDED(session->UnmapVmPort(AF_INET6, 1235, 81)); + VERIFY_ARE_EQUAL(session->UnmapVmPort(AF_INET6, 1235, 81), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); expectNotBound(1235, AF_INET6); // Create a forking relay and stress test - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80)); auto process = WSLAProcessLauncher{"/usr/bin/socat", {"/usr/bin/socat", "-dd", "TCP-LISTEN:80,fork,reuseaddr", "system:'echo -n OK'"}} @@ -845,7 +845,7 @@ class WSLATests expectContent(1234, AF_INET, "OK"); } - VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, true)); + VERIFY_SUCCEEDED(session->UnmapVmPort(AF_INET, 1234, 80)); } TEST_METHOD(StuckVmTermination) @@ -2431,6 +2431,9 @@ class WSLATests VERIFY_SUCCEEDED(session1Copy->Terminate()); expectSessions({}); + + // Validate that a new session is created if WSLASessionFlagsOpenExisting is set and no match is found. + auto session2 = create(L"session-2", static_cast(WSLASessionFlagsOpenExisting)); } } };