diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in index 0ef367e0d..00c53fae1 100644 --- a/msipackage/package.wix.in +++ b/msipackage/package.wix.in @@ -268,11 +268,6 @@ - - - - - @@ -287,14 +282,6 @@ - - - - - - - - diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp index a69c404bb..eea006aa5 100644 --- a/src/linux/init/WSLAInit.cpp +++ b/src/linux/init/WSLAInit.cpp @@ -656,59 +656,6 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_PORT_RELA RunLocalHostRelay(SocketAddress, ListenSocket.get()); } -void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_WAITPID& Message, const gsl::span& Buffer) -{ - WSLA_WAITPID_RESULT response{}; - response.State = WSLAOpenFlagsUnknown; - - auto sendResponse = wil::scope_exit([&]() { Channel.SendMessage(response); }); - - wil::unique_fd process = syscall(SYS_pidfd_open, Message.Pid, 0); - if (!process) - { - LOG_ERROR("pidfd_open({}) failed, {}", Message.Pid, errno); - response.Errno = errno; - return; - } - - pollfd pollResult{}; - pollResult.fd = process.get(); - pollResult.events = POLLIN | POLLERR; - - int result = poll(&pollResult, 1, Message.TimeoutMs); - if (result < 0) - { - LOG_ERROR("poll failed {}", errno); - response.Errno = errno; - return; - } - else if (result == 0) // Timed out - { - response.State = WSLAOpenFlagsRunning; - response.Errno = 0; - return; - } - - if (WI_IsFlagSet(pollResult.revents, POLLIN)) - { - siginfo_t childState{}; - auto result = waitid(P_PIDFD, process.get(), &childState, WEXITED); - if (result < 0) - { - LOG_ERROR("waitid({}) failed, {}", process.get(), errno); - response.Errno = errno; - return; - } - - response.Code = childState.si_status; - response.Errno = 0; - response.State = childState.si_code == CLD_EXITED ? WSLAOpenFlagsExited : WSLAOpenFlagsSignaled; - return; - } - - LOG_ERROR("Poll returned an unexpected error state on fd: {} for pid: {}", process.get(), Message.Pid); -} - void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_SIGNAL& Message, const gsl::span& Buffer) { auto result = kill(Message.Pid, Message.Signal); @@ -866,7 +813,7 @@ void ProcessMessage(wsl::shared::SocketChannel& Channel, LX_MESSAGE_TYPE Type, c { try { - HandleMessage( + HandleMessage( Channel, Type, Buffer); } catch (...) @@ -882,7 +829,7 @@ void ProcessMessages(wsl::shared::SocketChannel& Channel) while (Channel.Connected()) { auto [Message, Range] = Channel.ReceiveMessageOrClosed(); - if (Message == nullptr || Message->MessageType == LxMessageWSLAShutdown) + if (Message == nullptr) { break; } diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h index 3aa633968..2c596459e 100644 --- a/src/shared/inc/lxinitshared.h +++ b/src/shared/inc/lxinitshared.h @@ -390,7 +390,6 @@ typedef enum _LX_MESSAGE_TYPE LxMessageWSLAWaitPid, LxMessageWSLAWaitPidResponse, LxMessageWSLASignal, - LxMessageWSLAShutdown, LxMessageWSLARelayTty, LxMessageWSLAMapPort, LxMessageWSLAConnectRelay, @@ -502,7 +501,6 @@ inline auto ToString(LX_MESSAGE_TYPE messageType) X(LxMessageWSLAWaitPid) X(LxMessageWSLAWaitPidResponse) X(LxMessageWSLASignal) - X(LxMessageWSLAShutdown) X(LxMessageWSLARelayTty) X(LxMessageWSLAMapPort) X(LxMessageWSLAConnectRelay) @@ -1734,33 +1732,6 @@ enum WSLAOpenFlags WSLAOpenFlagsSignaled }; -struct WSLA_WAITPID_RESULT -{ - static inline auto Type = LxMessageWSLAWaitPidResponse; - - DECLARE_MESSAGE_CTOR(WSLA_WAITPID_RESULT); - - MESSAGE_HEADER Header; - WSLAOpenFlags State = WSLAOpenFlagsUnknown; - int32_t Code = -1; - int32_t Errno = -1; - PRETTY_PRINT(FIELD(Header), FIELD(State), FIELD(Code), FIELD(Errno)); -}; - -struct WSLA_WAITPID -{ - static inline auto Type = LxMessageWSLAWaitPid; - using TResponse = WSLA_WAITPID_RESULT; - - DECLARE_MESSAGE_CTOR(WSLA_WAITPID); - - MESSAGE_HEADER Header; - int32_t Pid = -1; - uint64_t TimeoutMs = 0; - - PRETTY_PRINT(FIELD(Header), FIELD(Pid), FIELD(TimeoutMs)); -}; - struct WSLA_SIGNAL { static inline auto Type = LxMessageWSLASignal; @@ -1775,17 +1746,6 @@ struct WSLA_SIGNAL PRETTY_PRINT(FIELD(Header), FIELD(Pid), FIELD(Signal)); }; -struct WSLA_SHUTDOWN -{ - static inline auto Type = LxMessageWSLAShutdown; - using TResponse = RESULT_MESSAGE; - - DECLARE_MESSAGE_CTOR(WSLA_SHUTDOWN); - MESSAGE_HEADER Header; - - PRETTY_PRINT(FIELD(Header)); -}; - struct WSLA_MAP_PORT { static inline auto Type = LxMessageWSLAMapPort; diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp index 3b7b8ba05..7f713854b 100644 --- a/src/windows/common/WslClient.cpp +++ b/src/windows/common/WslClient.cpp @@ -1596,7 +1596,6 @@ int WslaShell(_In_ std::wstring_view commandLine) THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); - wil::com_ptr virtualMachine; wil::com_ptr session; if (!rootVhdOverride.empty()) @@ -1623,8 +1622,7 @@ int WslaShell(_In_ std::wstring_view commandLine) } else { - THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, &session)); - THROW_IF_FAILED(session->GetVirtualMachine(&virtualMachine)); + THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, WSLASessionFlagsNone, &session)); wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); } diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp index 330c650fd..dc802d940 100644 --- a/src/windows/wslaservice/exe/WSLAContainer.cpp +++ b/src/windows/wslaservice/exe/WSLAContainer.cpp @@ -58,12 +58,11 @@ auto ProcessPortMappings(const WSLA_CONTAINER_OPTIONS& options, WSLAVirtualMachi { if (e.MappedToHost) { - LOG_IF_FAILED_MSG( - vm.MapPort(e.Family, e.HostPort, e.VmPort, true), - "Failed to unmap port (family=%i, guestPort=%u, hostPort=%u)", - e.Family, - e.VmPort, - e.HostPort); + try + { + vm.MapPort(e.Family, e.HostPort, e.VmPort, true); + } + CATCH_LOG(); } } }); @@ -116,7 +115,7 @@ auto ProcessPortMappings(const WSLA_CONTAINER_OPTIONS& options, WSLAVirtualMachi // Map Windows <-> VM ports. for (auto& e : *mappedPorts) { - THROW_IF_FAILED(vm.MapPort(e.Family, e.HostPort, e.VmPort, false)); + vm.MapPort(e.Family, e.HostPort, e.VmPort, false); e.MappedToHost = true; } @@ -201,12 +200,11 @@ WSLAContainerImpl::~WSLAContainerImpl() { WI_ASSERT(e.MappedToHost); - LOG_IF_FAILED_MSG( - m_parentVM->MapPort(e.Family, e.HostPort, e.VmPort, true), - "Failed to delete port mapping (family=%i, guestPort=%u, hostPort=%u)", - e.Family, - e.VmPort, - e.HostPort); + try + { + m_parentVM->MapPort(e.Family, e.HostPort, e.VmPort, true); + } + CATCH_LOG(); allocatedGuestPorts.insert(e.VmPort); } diff --git a/src/windows/wslaservice/exe/WSLAProcessControl.cpp b/src/windows/wslaservice/exe/WSLAProcessControl.cpp index 4632aa1fc..65e27a28f 100644 --- a/src/windows/wslaservice/exe/WSLAProcessControl.cpp +++ b/src/windows/wslaservice/exe/WSLAProcessControl.cpp @@ -181,7 +181,7 @@ void VMProcessControl::Signal(int Signal) std::lock_guard lock{m_lock}; THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_vm == nullptr || m_exitEvent.is_signaled()); - THROW_IF_FAILED(m_vm->Signal(m_pid, Signal)); + m_vm->Signal(m_pid, Signal); } void VMProcessControl::ResizeTty(ULONG Rows, ULONG Columns) diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp index 08e4d299d..ef1e2318b 100644 --- a/src/windows/wslaservice/exe/WSLASession.cpp +++ b/src/windows/wslaservice/exe/WSLASession.cpp @@ -46,7 +46,7 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs { WSL_LOG("SessionCreated", TraceLoggingValue(m_displayName.c_str(), "DisplayName")); - m_virtualMachine = wil::MakeOrThrow(CreateVmSettings(Settings), userSessionImpl.GetUserSid()); + m_virtualMachine.emplace(CreateVmSettings(Settings), userSessionImpl.GetUserSid()); if (Settings.TerminationCallback != nullptr) { @@ -74,7 +74,7 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs {"/usr/bin/dockerd" /*, "--debug"*/}, // TODO: Flag for --debug. {{"PATH=/bin:/usr/local/sbin:/usr/bin:/usr/sbin:/sbin"}}, common::ProcessFlags::Stdout | common::ProcessFlags::Stderr}; - m_containerdThread = std::thread(&WSLASession::MonitorContainerd, this, launcher.Launch(*m_virtualMachine.Get())); + m_containerdThread = std::thread(&WSLASession::MonitorContainerd, this, launcher.Launch(*m_virtualMachine)); // Wait for containerd to be ready before starting the event tracker. // TODO: Configurable timeout. @@ -134,34 +134,7 @@ WSLASession::~WSLASession() { WSL_LOG("SessionTerminated", TraceLoggingValue(m_displayName.c_str(), "DisplayName")); - std::lock_guard lock{m_lock}; - - // Stop the event tracker - if (m_eventTracker.has_value()) - { - m_eventTracker->Stop(); - } - - // This will delete all containers. Needs to be done before the VM is terminated. - // TODO: If callers still have references to containers, the instances won't actually be deleted. - m_containers.clear(); - - m_sessionTerminatingEvent.SetEvent(); - - // N.B. The containerd thread can only run if the VM is running. - if (m_containerdThread.joinable()) - { - m_containerdThread.join(); - } - - if (m_virtualMachine) - { - // N.B. containerd has exited by this point, so unmounting the VHD is safe since no container can be running. - LOG_IF_FAILED(m_virtualMachine->Unmount(c_containerdStorage)); - m_virtualMachine->OnSessionTerminated(); - - m_virtualMachine.Reset(); - } + Terminate(); } void WSLASession::ConfigureStorage(const WSLA_SESSION_SETTINGS& Settings, PSID UserSid) @@ -484,7 +457,7 @@ try containerOptions->Name, WSLAContainerImpl::Create( *containerOptions, - *m_virtualMachine.Get(), + *m_virtualMachine, std::bind(&WSLASession::OnContainerDeleted, this, std::placeholders::_1), m_eventTracker.value(), m_dockerClient.value())); @@ -535,15 +508,6 @@ try } CATCH_RETURN(); -HRESULT WSLASession::GetVirtualMachine(IWSLAVirtualMachine** VirtualMachine) -{ - std::lock_guard lock{m_lock}; - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); - - THROW_IF_FAILED(m_virtualMachine->QueryInterface(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine)); - return S_OK; -} - HRESULT WSLASession::CreateRootNamespaceProcess(const WSLA_PROCESS_OPTIONS* Options, IWSLAProcess** Process, int* Errno) try { @@ -555,7 +519,10 @@ try std::lock_guard lock{m_lock}; THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); - return m_virtualMachine->CreateLinuxProcess(Options, Process, Errno); + auto process = m_virtualMachine->CreateLinuxProcess(*Options, Errno); + THROW_IF_FAILED(process.CopyTo(Process)); + + return S_OK; } CATCH_RETURN(); @@ -563,7 +530,7 @@ void WSLASession::Ext4Format(const std::string& Device) { constexpr auto mkfsPath = "/usr/sbin/mkfs.ext4"; ServiceProcessLauncher launcher(mkfsPath, {mkfsPath, Device}); - auto result = launcher.Launch(*m_virtualMachine.Get()).WaitAndCaptureOutput(); + auto result = launcher.Launch(*m_virtualMachine).WaitAndCaptureOutput(); THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "%hs", launcher.FormatResult(result).c_str()); } @@ -590,25 +557,80 @@ try CATCH_RETURN(); void WSLASession::OnUserSessionTerminating() +{ + LOG_IF_FAILED(Terminate()); +} + +HRESULT WSLASession::Terminate() +try { // m_sessionTerminatingEvent is always valid, so it can be signalled with the lock. // This allows a session to be unblocked if a stuck operation is holding the lock. m_sessionTerminatingEvent.SetEvent(); std::lock_guard lock{m_lock}; + + // Stop the event tracker + if (m_eventTracker.has_value()) + { + m_eventTracker->Stop(); + } + + // This will delete all containers. Needs to be done before the VM is terminated. + m_containers.clear(); + m_dockerClient.reset(); - m_virtualMachine.Reset(); + + // N.B. The containerd thread can only run if the VM is running. + if (m_containerdThread.joinable()) + { + m_containerdThread.join(); + } + + if (m_virtualMachine) + { + // N.B. containerd has exited by this point, so unmounting the VHD is safe since no container can be running. + try + { + m_virtualMachine->Unmount(c_containerdStorage); + } + CATCH_LOG(); + + m_virtualMachine->OnSessionTerminated(); + m_virtualMachine.reset(); + } + + return S_OK; } +CATCH_RETURN(); -HRESULT WSLASession::Shutdown(ULONG Timeout) +HRESULT WSLASession::MountWindowsFolder(LPCWSTR WindowsPath, LPCSTR LinuxPath, BOOL ReadOnly) try { std::lock_guard lock{m_lock}; THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); - THROW_IF_FAILED(m_virtualMachine->Shutdown(Timeout)); + return m_virtualMachine->MountWindowsFolder(WindowsPath, LinuxPath, ReadOnly); +} +CATCH_RETURN(); - m_virtualMachine.Reset(); +HRESULT WSLASession::UnmountWindowsFolder(LPCSTR LinuxPath) +try +{ + std::lock_guard lock{m_lock}; + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); + + return m_virtualMachine->UnmountWindowsFolder(LinuxPath); +} +CATCH_RETURN(); + +HRESULT WSLASession::MapVmPort(int Family, short WindowsPort, short LinuxPort, BOOL Remove) +try +{ + std::lock_guard lock{m_lock}; + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine); + + m_virtualMachine->MapPort(Family, WindowsPort, LinuxPort, Remove != FALSE); return S_OK; } CATCH_RETURN(); @@ -627,4 +649,10 @@ HRESULT WSLASession::GetImplNoRef(_Out_ WSLASession** Session) // beyond that lifetime. *Session = this; return S_OK; +} + +bool WSLASession::Terminated() +{ + std::lock_guard lock{m_lock}; + return !m_virtualMachine; } \ No newline at end of file diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h index 4e1792d38..8cb5f5a54 100644 --- a/src/windows/wslaservice/exe/WSLASession.h +++ b/src/windows/wslaservice/exe/WSLASession.h @@ -62,18 +62,24 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession IFACEMETHOD(ListContainers)(_Out_ WSLA_CONTAINER** Images, _Out_ ULONG* Count) override; // VM management. - IFACEMETHOD(GetVirtualMachine)(IWSLAVirtualMachine** VirtualMachine) override; IFACEMETHOD(CreateRootNamespaceProcess)(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** VirtualMachine, _Out_ int* Errno) override; // Disk management. IFACEMETHOD(FormatVirtualDisk)(_In_ LPCWSTR Path) override; - IFACEMETHOD(Shutdown(_In_ ULONG)) override; + IFACEMETHOD(Terminate()) override; IFACEMETHOD(GetImplNoRef)(_Out_ WSLASession** Session) override; + // Testing. + IFACEMETHOD(MountWindowsFolder)(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly) override; + IFACEMETHOD(UnmountWindowsFolder)(_In_ LPCSTR LinuxPath) override; + IFACEMETHOD(MapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) override; + void OnUserSessionTerminating(); + bool Terminated(); + private: ULONG m_id = 0; @@ -88,7 +94,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not std::optional m_dockerClient; - Microsoft::WRL::ComPtr m_virtualMachine; + std::optional m_virtualMachine; std::optional m_eventTracker; wil::unique_event m_containerdReadyEvent{wil::EventOptions::ManualReset}; std::thread m_containerdThread; diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp index 0e006e79b..7ae9f307b 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.cpp +++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp @@ -34,26 +34,49 @@ PSID WSLAUserSessionImpl::GetUserSid() const return m_tokenInfo->User.Sid; } -HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession) +HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, IWSLASession** WslaSession) +try { ULONG id = m_nextSessionId++; - auto session = wil::MakeOrThrow(id, *Settings, *this); - Microsoft::WRL::ComPtr weakRef; - THROW_IF_FAILED(session->GetWeakReference(&weakRef)); + std::lock_guard lock(m_wslaSessionsLock); + + // Check for an existing session first. + auto result = ForEachSession([&](auto& session) -> std::optional { + // TODO: ACL check. + if (session.DisplayName() == Settings->DisplayName) + { + RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), WI_IsFlagClear(Flags, WSLASessionFlagsOpenExisting)); + + return session.QueryInterface(__uuidof(IWSLASession), (void**)WslaSession); + } + + return std::optional{}; + }); + if (result.has_value()) { - std::lock_guard lock(m_wslaSessionsLock); - m_sessions.emplace_back(std::move(weakRef)); + return result.value(); } - // Client now owns the session. - // TODO: Add a flag for the client to specify that the session should outlive its process. + // No session was found, create a new one. + auto session = wil::MakeOrThrow(id, *Settings, *this); + + if (WI_IsFlagSet(Flags, WSLASessionFlagsPersistent)) + { + m_persistentSessions.push_back(session); + } + + Microsoft::WRL::ComPtr weakRef; + THROW_IF_FAILED(session->GetWeakReference(&weakRef)); + + m_sessions.emplace_back(std::move(weakRef)); THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)WslaSession)); return S_OK; } +CATCH_RETURN(); HRESULT WSLAUserSessionImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLASession** Session) { @@ -113,13 +136,13 @@ HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSLA_VERS return S_OK; } -HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession) +HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, IWSLASession** WslaSession) try { auto session = m_session.lock(); RETURN_HR_IF(RPC_E_DISCONNECTED, !session); - return session->CreateSession(Settings, WslaSession); + return session->CreateSession(Settings, Flags, WslaSession); } CATCH_RETURN(); diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h index b1f07a4b9..74e5aabad 100644 --- a/src/windows/wslaservice/exe/WSLAUserSession.h +++ b/src/windows/wslaservice/exe/WSLAUserSession.h @@ -33,7 +33,7 @@ class WSLAUserSessionImpl PSID GetUserSid() const; - HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession); + HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, IWSLASession** WslaSession); HRESULT OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session); HRESULT ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount); @@ -57,6 +57,18 @@ class WSLAUserSessionImpl WSLASession* SessionImpl{}; THROW_IF_FAILED(lockedSession->GetImplNoRef(&SessionImpl)); + // If the session is terminated, drop its reference so it can be deleted (in case of persistent sessions) + if (SessionImpl->Terminated()) + { + auto remove = + std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return SessionImpl->GetId() == e->GetId(); }); + + WI_ASSERT(remove.end() - remove.begin() <= 1); + + m_persistentSessions.erase(remove.begin(), remove.end()); + return true; + } + if constexpr (std::is_same_v) { Routine(*SessionImpl); @@ -90,6 +102,8 @@ class WSLAUserSessionImpl std::atomic m_nextSessionId{1}; std::recursive_mutex m_wslaSessionsLock; + // Persistent sessions that outlive their creating process. + std::vector> m_persistentSessions; std::vector> m_sessions; }; @@ -102,7 +116,7 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession WSLAUserSession& operator=(const WSLAUserSession&) = delete; IFACEMETHOD(GetVersion)(_Out_ WSLA_VERSION* Version) override; - IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, IWSLASession** WslaSession) override; + IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, WSLASessionFlags Flags, IWSLASession** WslaSession) override; IFACEMETHOD(ListSessions)(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount) override; IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLASession** Session) override; IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session) override; diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp index 659c0550d..370db51fc 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp @@ -736,8 +736,7 @@ std::pair WSLAVirtualMachine::AttachDisk(_In_ PCWSTR Path, _ return {Lun, Device}; } -HRESULT WSLAVirtualMachine::Unmount(_In_ const char* Path) -try +void WSLAVirtualMachine::Unmount(_In_ const char* Path) { auto [pid, _, subChannel] = Fork(WSLA_FORK::Thread); @@ -749,10 +748,7 @@ try // TODO: Return errno to caller THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), response.Result == EINVAL); THROW_HR_IF(E_FAIL, response.Result != 0); - - return S_OK; } -CATCH_RETURN() void WSLAVirtualMachine::DetachDisk(_In_ ULONG Lun) { @@ -859,15 +855,6 @@ void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, cons THROW_HR_IF_MSG(E_FAIL, result != 0, "Failed to open %hs (flags: %u), %i", Path, Flags, result); } -HRESULT WSLAVirtualMachine::CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno) -try -{ - CreateLinuxProcess(*Options, Errno).CopyTo(Process); - - return S_OK; -} -CATCH_RETURN(); - Microsoft::WRL::ComPtr WSLAVirtualMachine::CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS& Options, int* Errno, const TPrepareCommandLine& PrepareCommandLine) { // N.B This check is there to prevent processes from being started before the VM is done initializing. @@ -1046,46 +1033,7 @@ int32_t WSLAVirtualMachine::ExpectClosedChannelOrError(wsl::shared::SocketChanne } } -HRESULT WSLAVirtualMachine::WaitPid(LONG Pid, ULONGLONG TimeoutMs, ULONG* State, int* Code) -try -{ - auto [pid, _, subChannel] = Fork(WSLA_FORK::Thread); - - WSLA_WAITPID message{}; - message.Pid = Pid; - message.TimeoutMs = TimeoutMs; - - const auto& response = subChannel.Transaction(message); - - THROW_HR_IF(E_FAIL, response.State == WSLAOpenFlagsUnknown); - - *State = response.State; - *Code = response.Code; - - return S_OK; -} -CATCH_RETURN(); - -HRESULT WSLAVirtualMachine::Shutdown(ULONGLONG TimeoutMs) -try -{ - std::lock_guard lock(m_lock); - - THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running); - - WSLA_SHUTDOWN message{}; - m_initChannel.SendMessage(message); - auto response = m_initChannel.ReceiveMessageOrClosed(static_cast(TimeoutMs)); - - RETURN_HR_IF(E_UNEXPECTED, response.first != nullptr); - - m_running = false; - return S_OK; -} -CATCH_RETURN(); - -HRESULT WSLAVirtualMachine::Signal(_In_ LONG Pid, _In_ int Signal) -try +void WSLAVirtualMachine::Signal(_In_ LONG Pid, _In_ int Signal) { std::lock_guard lock(m_lock); THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running); @@ -1095,10 +1043,8 @@ try message.Signal = Signal; const auto& response = m_initChannel.Transaction(message); - RETURN_HR_IF(E_FAIL, response.Result != 0); - return S_OK; + THROW_HR_IF(E_FAIL, response.Result != 0); } -CATCH_RETURN(); void WSLAVirtualMachine::RegisterCallback(ITerminationCallback* callback) { @@ -1192,12 +1138,11 @@ void WSLAVirtualMachine::LaunchPortRelay() writePipe.release(); } -HRESULT WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) -try +void WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) { std::lock_guard lock(m_portRelaylock); - RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, !m_portRelayChannelWrite); + THROW_HR_IF(E_ILLEGAL_STATE_CHANGE, !m_portRelayChannelWrite); WSLA_MAP_PORT message; message.WindowsPort = WindowsPort; @@ -1213,10 +1158,8 @@ try THROW_IF_WIN32_BOOL_FALSE(ReadFile(m_portRelayChannelRead.get(), &result, sizeof(result), &bytesTransfered, nullptr)); THROW_HR_IF(E_UNEXPECTED, bytesTransfered != sizeof(result)); - - return result; + THROW_IF_FAILED_MSG(result, "Failed to map port: WindowsPort=%d, LinuxPort=%d, Family=%d, Remove=%d", WindowsPort, LinuxPort, Family, Remove); } -CATCH_RETURN(); HRESULT WSLAVirtualMachine::MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly) { @@ -1332,7 +1275,6 @@ void WSLAVirtualMachine::RemoveShare(_In_ const MountedFolderInfo& MountInfo) } HRESULT WSLAVirtualMachine::UnmountWindowsFolder(_In_ LPCSTR LinuxPath) -try { std::lock_guard lock(m_lock); @@ -1341,7 +1283,7 @@ try THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_mountedWindowsFolders.end()); // Unmount the folder from the guest. If the mount is not found, this most likely means that the guest unmounted it. - auto result = Unmount(LinuxPath); + auto result = wil::ResultFromException([&]() { Unmount(LinuxPath); }); THROW_HR_IF(result, FAILED(result) && result != HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); auto mountInfo = it->second; @@ -1352,7 +1294,6 @@ try return S_OK; } -CATCH_RETURN(); void WSLAVirtualMachine::MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint) { diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h index 6ac41c3d8..8a2087f5f 100644 --- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h +++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h @@ -34,8 +34,7 @@ enum WSLAMountFlags class WSLAUserSessionImpl; -class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine - : public Microsoft::WRL::RuntimeClass, IWSLAVirtualMachine, IFastRundown> +class WSLAVirtualMachine { public: @@ -73,14 +72,12 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine void Start(); void OnSessionTerminated(); - IFACEMETHOD(CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno)) override; - IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override; - IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override; - IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override; - IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override; - IFACEMETHOD(Unmount(_In_ const char* Path)) override; - IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override; - IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override; + void MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove); + void Unmount(_In_ const char* Path); + + HRESULT MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly); + HRESULT UnmountWindowsFolder(_In_ LPCSTR LinuxPath); + void Signal(_In_ LONG Pid, _In_ int Signal); void OnProcessReleased(int Pid); void RegisterCallback(_In_ ITerminationCallback* callback); diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl index 517b61f64..7722b43a7 100644 --- a/src/windows/wslaservice/inc/wslaservice.idl +++ b/src/windows/wslaservice/inc/wslaservice.idl @@ -236,25 +236,6 @@ interface IWSLAProcess : IUnknown // Note: the SDK can offer a convenience Wait() method, but that doesn't need to be part of the service API. } - -// TODO: Delete once the new API is wired. -[ - uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761), - pointer_default(unique), - object -] -interface IWSLAVirtualMachine : IUnknown -{ - HRESULT CreateLinuxProcess([in] const struct WSLA_PROCESS_OPTIONS* Options, [out] IWSLAProcess** Process, [out] int* Errno); - HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code); - HRESULT Signal([in] LONG Pid, [in] int Signal); - HRESULT Shutdown([in] ULONGLONG TimeoutMs); - HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove); - HRESULT Unmount([in] LPCSTR Path); - HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly); - HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath); -} - typedef enum _WSLANetworkingMode { WSLANetworkingModeNone, @@ -338,10 +319,12 @@ interface IWSLASession : IUnknown HRESULT FormatVirtualDisk([in] LPCWSTR Path); // Terminate the VM and containers. - HRESULT Shutdown([in] ULONG TimeoutMs); + HRESULT Terminate(); - // To be deleted. - HRESULT GetVirtualMachine([out] IWSLAVirtualMachine **VirtualMachine); + // Used only for testing. TODO: Think about moving them to a dedicated testing-only interface. + HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly); + HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath); + HRESULT MapVmPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove); } struct WSLA_SESSION_INFORMATION @@ -351,6 +334,13 @@ struct WSLA_SESSION_INFORMATION wchar_t DisplayName[256]; }; +typedef enum _WSLASessionFlags +{ + WSLASessionFlagsNone = 0, + WSLASessionFlagsPersistent = 1, // Session remains active after its COM reference is released. + WSLASessionFlagsOpenExisting = 2 // Open an existing session if the name is in use. +} WSLASessionFlags; + [ uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760), pointer_default(unique), @@ -361,7 +351,7 @@ interface IWSLAUserSession : IUnknown HRESULT GetVersion([out] WSLA_VERSION* Version); // Session managment. - HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session); + HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, [out]IWSLASession** Session); HRESULT ListSessions([out, size_is(, *SessionsCount)] struct WSLA_SESSION_INFORMATION** Sessions, [out] ULONG* SessionsCount); HRESULT OpenSession([in] ULONG Id, [out]IWSLASession** Session); HRESULT OpenSessionByName([in] LPCWSTR DisplayName, [out] IWSLASession** Session); diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp index 0a49c5d4e..ab90c2636 100644 --- a/test/windows/WSLATests.cpp +++ b/test/windows/WSLATests.cpp @@ -87,7 +87,7 @@ class WSLATests return settings; } - wil::com_ptr CreateSession(const WSLA_SESSION_SETTINGS& sessionSettings = GetDefaultSessionSettings()) + wil::com_ptr CreateSession(const WSLA_SESSION_SETTINGS& sessionSettings = GetDefaultSessionSettings(), WSLASessionFlags Flags = WSLASessionFlagsNone) { wil::com_ptr userSession; VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); @@ -95,7 +95,7 @@ class WSLATests wil::com_ptr session; - VERIFY_SUCCEEDED(userSession->CreateSession(&sessionSettings, &session)); + VERIFY_SUCCEEDED(userSession->CreateSession(&sessionSettings, Flags, &session)); wsl::windows::common::security::ConfigureForCOMImpersonation(session.get()); return session; @@ -208,8 +208,7 @@ class WSLATests wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); - wil::com_ptr session; - VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session)); + wil::com_ptr session = CreateSession(settings); // Act: list sessions { @@ -226,9 +225,8 @@ class WSLATests // List multiple sessions. { - wil::com_ptr session2; settings.DisplayName = L"wsla-test-list-2"; - VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session2)); + auto session2 = CreateSession(settings); wil::unique_cotaskmem_array_ptr sessions; VERIFY_SUCCEEDED(userSession->ListSessions(&sessions, sessions.size_address())); @@ -261,8 +259,7 @@ class WSLATests wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); - wil::com_ptr created; - VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &created)); + wil::com_ptr created = CreateSession(settings); // Act: open by the same display name wil::com_ptr opened; @@ -409,7 +406,7 @@ class WSLATests auto session = CreateSession(settings); auto detach = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { - session->Shutdown(30 * 1000); + session.reset(); if (thread.joinable()) { thread.join(); @@ -420,7 +417,7 @@ class WSLATests ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo DmesgTest > /dev/kmsg"}, 0); - VERIFY_ARE_EQUAL(session->Shutdown(30 * 1000), S_OK); + session.reset(); detach.reset(); auto contentString = std::string(dmesgContent.begin(), dmesgContent.end()); @@ -487,9 +484,7 @@ class WSLATests auto session = CreateSession(sessionSettings); - wil::com_ptr vm; - VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); - VERIFY_SUCCEEDED(vm->Shutdown(30 * 1000)); + session.reset(); auto future = promise.get_future(); auto result = future.wait_for(std::chrono::seconds(30)); auto [reason, details] = future.get(); @@ -760,9 +755,6 @@ class WSLATests auto session = CreateSession(settings); - wil::com_ptr vm; - VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); - auto listen = [&](short port, const char* content, bool ipv6) { auto cmd = std::format("echo -n '{}' | /usr/bin/socat -dd TCP{}-LISTEN:{},reuseaddr -", content, ipv6 ? "6" : "", port); auto process = WSLAProcessLauncher("/bin/sh", {"/bin/sh", "-c", cmd}).Launch(*session); @@ -796,10 +788,10 @@ class WSLATests }; // Map port - VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, false)); // Validate that the same port can't be bound twice - VERIFY_ARE_EQUAL(vm->MapPort(AF_INET, 1234, 80, false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); + VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80, false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); // Check simple case listen(80, "port80", false); @@ -813,34 +805,34 @@ class WSLATests expectContent(1234, AF_INET, ""); // Add a ipv6 binding - VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1234, 80, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80, false)); // Validate that ipv6 bindings work as well. listen(80, "port80ipv6", true); expectContent(1234, AF_INET6, "port80ipv6"); // Unmap the ipv4 port - VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, true)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, true)); // Verify that a proper error is returned if the mapping doesn't exist - VERIFY_ARE_EQUAL(vm->MapPort(AF_INET, 1234, 80, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); // Unmap the v6 port - VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1234, 80, true)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80, true)); // Map another port as v6 only - VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1235, 81, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81, false)); listen(81, "port81ipv6", true); expectContent(1235, AF_INET6, "port81ipv6"); expectNotBound(1235, AF_INET); - VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1235, 81, true)); - VERIFY_ARE_EQUAL(vm->MapPort(AF_INET6, 1235, 81, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81, true)); + VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET6, 1235, 81, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); expectNotBound(1235, AF_INET6); // Create a forking relay and stress test - VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, false)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, false)); auto process = WSLAProcessLauncher{"/usr/bin/socat", {"/usr/bin/socat", "-dd", "TCP-LISTEN:80,fork,reuseaddr", "system:'echo -n OK'"}} @@ -853,7 +845,7 @@ class WSLATests expectContent(1234, AF_INET, "OK"); } - VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, true)); + VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, true)); } TEST_METHOD(StuckVmTermination) @@ -876,10 +868,6 @@ class WSLATests auto session = CreateSession(settings); - wil::com_ptr vm; - VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); - wsl::windows::common::security::ConfigureForCOMImpersonation(vm.get()); - auto expectedMountOptions = [&](bool readOnly) -> std::string { if (enableVirtioFs) { @@ -898,45 +886,45 @@ class WSLATests // Validate writeable mount. { - VERIFY_SUCCEEDED(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", false)); + VERIFY_SUCCEEDED(session->MountWindowsFolder(testFolder.c_str(), "/win-path", false)); ExpectMount(session.get(), "/win-path", expectedMountOptions(false)); // Validate that mount can't be stacked on each other - VERIFY_ARE_EQUAL(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); + VERIFY_ARE_EQUAL(session->MountWindowsFolder(testFolder.c_str(), "/win-path", false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); // Validate that folder is writeable from linux ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo -n content > /win-path/file.txt && sync"}, 0); VERIFY_ARE_EQUAL(ReadFileContent(testFolder / "file.txt"), L"content"); - VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path")); + VERIFY_SUCCEEDED(session->UnmountWindowsFolder("/win-path")); ExpectMount(session.get(), "/win-path", {}); } // Validate read-only mount. { - VERIFY_SUCCEEDED(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", true)); + VERIFY_SUCCEEDED(session->MountWindowsFolder(testFolder.c_str(), "/win-path", true)); ExpectMount(session.get(), "/win-path", expectedMountOptions(true)); // Validate that folder is not writeable from linux ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo -n content > /win-path/file.txt"}, 1); - VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path")); + VERIFY_SUCCEEDED(session->UnmountWindowsFolder("/win-path")); ExpectMount(session.get(), "/win-path", {}); } // Validate various error paths { - VERIFY_ARE_EQUAL(vm->MountWindowsFolder(L"relative-path", "/win-path", true), E_INVALIDARG); - VERIFY_ARE_EQUAL(vm->MountWindowsFolder(L"C:\\does-not-exist", "/win-path", true), HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)); - VERIFY_ARE_EQUAL(vm->UnmountWindowsFolder("/not-mounted"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); - VERIFY_ARE_EQUAL(vm->UnmountWindowsFolder("/proc"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + VERIFY_ARE_EQUAL(session->MountWindowsFolder(L"relative-path", "/win-path", true), E_INVALIDARG); + VERIFY_ARE_EQUAL(session->MountWindowsFolder(L"C:\\does-not-exist", "/win-path", true), HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)); + VERIFY_ARE_EQUAL(session->UnmountWindowsFolder("/not-mounted"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); + VERIFY_ARE_EQUAL(session->UnmountWindowsFolder("/proc"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND)); // Validate that folders that are manually unmounted from the guest are handled properly - VERIFY_SUCCEEDED(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", true)); + VERIFY_SUCCEEDED(session->MountWindowsFolder(testFolder.c_str(), "/win-path", true)); ExpectMount(session.get(), "/win-path", expectedMountOptions(true)); ExpectCommandResult(session.get(), {"/usr/bin/umount", "/win-path"}, 0); - VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path")); + VERIFY_SUCCEEDED(session->UnmountWindowsFolder("/win-path")); } } @@ -1000,9 +988,6 @@ class WSLATests WI_ClearFlag(settings.FeatureFlags, WslaFeatureFlagsGPU); session = CreateSession(settings); - wil::com_ptr vm; - VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm)); - // Validate that the GPU device is not available. ExpectMount(session.get(), "/usr/lib/wsl/drivers", {}); ExpectMount(session.get(), "/usr/lib/wsl/lib", {}); @@ -1164,24 +1149,9 @@ class WSLATests VERIFY_ARE_EQUAL(process.Get().GetStdHandle(3, reinterpret_cast(&dummyHandle)), E_INVALIDARG); // Validate that the process object correctly handle requests after the VM has terminated. - VERIFY_SUCCEEDED(session->Shutdown(30 * 1000)); + session.reset(); VERIFY_ARE_EQUAL(process.Get().Signal(WSLASignalSIGKILL), HRESULT_FROM_WIN32(ERROR_INVALID_STATE)); } - - { - - // Validate that new processes cannot be created after the VM is terminated. - const char* executable = "dummy"; - WSLA_PROCESS_OPTIONS options{}; - options.CommandLine = &executable; - options.Executable = executable; - options.CommandLineCount = 1; - - wil::com_ptr process; - int error{}; - VERIFY_ARE_EQUAL(session->CreateRootNamespaceProcess(&options, &process, &error), HRESULT_FROM_WIN32(ERROR_INVALID_STATE)); - VERIFY_ARE_EQUAL(error, -1); - } } TEST_METHOD(CrashDumpCollection) @@ -1266,7 +1236,7 @@ class WSLATests wsl::core::filesystem::CreateVhd(formatedVhd, 100 * 1024 * 1024, tokenInfo->User.Sid, false, false); auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() { - LOG_IF_FAILED(session->Shutdown(30 * 1000)); + session.reset(); LOG_IF_WIN32_BOOL_FALSE(DeleteFileW(formatedVhd)); }); @@ -2371,4 +2341,99 @@ class WSLATests VERIFY_ARE_EQUAL(wil::ResultFromException([&]() { runTest(input, "", ""); }), E_INVALIDARG); } } + + TEST_METHOD(PersistentSession) + { + WSL2_TEST_ONLY(); + + wil::com_ptr userSession; + VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession))); + wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get()); + + auto expectSessions = [&](const std::vector& expectedSessions) { + wil::unique_cotaskmem_array_ptr sessions; + VERIFY_SUCCEEDED(userSession->ListSessions(&sessions, sessions.size_address())); + + std::set displayNames; + for (const auto& e : sessions) + { + auto [_, inserted] = displayNames.insert(e.DisplayName); + + VERIFY_IS_TRUE(inserted); + } + + for (const auto& e : expectedSessions) + { + auto it = displayNames.find(e); + if (it == displayNames.end()) + { + LogError("Session not found: %ls", e.c_str()); + VERIFY_FAIL(); + } + + displayNames.erase(it); + } + + for (const auto& e : displayNames) + { + LogError("Unexpected session found: %ls", e.c_str()); + VERIFY_FAIL(); + } + }; + + auto create = [this](LPCWSTR Name, WSLASessionFlags Flags) { + auto settings = GetDefaultSessionSettings(); + settings.DisplayName = Name; + settings.NetworkingMode = WSLANetworkingModeNone; + + return CreateSession(settings, Flags); + }; + + // Validate that non-persistent sessions are dropped when released + { + auto session1 = create(L"session-1", WSLASessionFlagsNone); + expectSessions({L"session-1"}); + + session1.reset(); + expectSessions({}); + } + + // Validate that persistent sessions are only dropped when explicitly terminated. + { + auto session1 = create(L"session-1", WSLASessionFlagsPersistent); + expectSessions({L"session-1"}); + + session1.reset(); + expectSessions({L"session-1"}); + session1 = create(L"session-1", WSLASessionFlagsOpenExisting); + + VERIFY_SUCCEEDED(session1->Terminate()); + session1.reset(); + expectSessions({}); + } + + // Validate that sessions can be reopened by name. + { + auto session1 = create(L"session-1", WSLASessionFlagsPersistent); + expectSessions({L"session-1"}); + + session1.reset(); + expectSessions({L"session-1"}); + + auto session1Copy = + create(L"session-1", static_cast(WSLASessionFlagsPersistent | WSLASessionFlagsOpenExisting)); + + expectSessions({L"session-1"}); + + // Verify that name conflicts are correctly handled. + auto settings = GetDefaultSessionSettings(); + settings.DisplayName = L"session-1"; + + wil::com_ptr session; + VERIFY_ARE_EQUAL(userSession->CreateSession(&settings, WSLASessionFlagsPersistent, &session), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS)); + + VERIFY_SUCCEEDED(session1Copy->Terminate()); + expectSessions({}); + } + } };