diff --git a/msipackage/package.wix.in b/msipackage/package.wix.in
index 0ef367e0d..00c53fae1 100644
--- a/msipackage/package.wix.in
+++ b/msipackage/package.wix.in
@@ -268,11 +268,6 @@
-
-
-
-
-
@@ -287,14 +282,6 @@
-
-
-
-
-
-
-
-
diff --git a/src/linux/init/WSLAInit.cpp b/src/linux/init/WSLAInit.cpp
index a69c404bb..eea006aa5 100644
--- a/src/linux/init/WSLAInit.cpp
+++ b/src/linux/init/WSLAInit.cpp
@@ -656,59 +656,6 @@ void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_PORT_RELA
RunLocalHostRelay(SocketAddress, ListenSocket.get());
}
-void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_WAITPID& Message, const gsl::span& Buffer)
-{
- WSLA_WAITPID_RESULT response{};
- response.State = WSLAOpenFlagsUnknown;
-
- auto sendResponse = wil::scope_exit([&]() { Channel.SendMessage(response); });
-
- wil::unique_fd process = syscall(SYS_pidfd_open, Message.Pid, 0);
- if (!process)
- {
- LOG_ERROR("pidfd_open({}) failed, {}", Message.Pid, errno);
- response.Errno = errno;
- return;
- }
-
- pollfd pollResult{};
- pollResult.fd = process.get();
- pollResult.events = POLLIN | POLLERR;
-
- int result = poll(&pollResult, 1, Message.TimeoutMs);
- if (result < 0)
- {
- LOG_ERROR("poll failed {}", errno);
- response.Errno = errno;
- return;
- }
- else if (result == 0) // Timed out
- {
- response.State = WSLAOpenFlagsRunning;
- response.Errno = 0;
- return;
- }
-
- if (WI_IsFlagSet(pollResult.revents, POLLIN))
- {
- siginfo_t childState{};
- auto result = waitid(P_PIDFD, process.get(), &childState, WEXITED);
- if (result < 0)
- {
- LOG_ERROR("waitid({}) failed, {}", process.get(), errno);
- response.Errno = errno;
- return;
- }
-
- response.Code = childState.si_status;
- response.Errno = 0;
- response.State = childState.si_code == CLD_EXITED ? WSLAOpenFlagsExited : WSLAOpenFlagsSignaled;
- return;
- }
-
- LOG_ERROR("Poll returned an unexpected error state on fd: {} for pid: {}", process.get(), Message.Pid);
-}
-
void HandleMessageImpl(wsl::shared::SocketChannel& Channel, const WSLA_SIGNAL& Message, const gsl::span& Buffer)
{
auto result = kill(Message.Pid, Message.Signal);
@@ -866,7 +813,7 @@ void ProcessMessage(wsl::shared::SocketChannel& Channel, LX_MESSAGE_TYPE Type, c
{
try
{
- HandleMessage(
+ HandleMessage(
Channel, Type, Buffer);
}
catch (...)
@@ -882,7 +829,7 @@ void ProcessMessages(wsl::shared::SocketChannel& Channel)
while (Channel.Connected())
{
auto [Message, Range] = Channel.ReceiveMessageOrClosed();
- if (Message == nullptr || Message->MessageType == LxMessageWSLAShutdown)
+ if (Message == nullptr)
{
break;
}
diff --git a/src/shared/inc/lxinitshared.h b/src/shared/inc/lxinitshared.h
index 3aa633968..2c596459e 100644
--- a/src/shared/inc/lxinitshared.h
+++ b/src/shared/inc/lxinitshared.h
@@ -390,7 +390,6 @@ typedef enum _LX_MESSAGE_TYPE
LxMessageWSLAWaitPid,
LxMessageWSLAWaitPidResponse,
LxMessageWSLASignal,
- LxMessageWSLAShutdown,
LxMessageWSLARelayTty,
LxMessageWSLAMapPort,
LxMessageWSLAConnectRelay,
@@ -502,7 +501,6 @@ inline auto ToString(LX_MESSAGE_TYPE messageType)
X(LxMessageWSLAWaitPid)
X(LxMessageWSLAWaitPidResponse)
X(LxMessageWSLASignal)
- X(LxMessageWSLAShutdown)
X(LxMessageWSLARelayTty)
X(LxMessageWSLAMapPort)
X(LxMessageWSLAConnectRelay)
@@ -1734,33 +1732,6 @@ enum WSLAOpenFlags
WSLAOpenFlagsSignaled
};
-struct WSLA_WAITPID_RESULT
-{
- static inline auto Type = LxMessageWSLAWaitPidResponse;
-
- DECLARE_MESSAGE_CTOR(WSLA_WAITPID_RESULT);
-
- MESSAGE_HEADER Header;
- WSLAOpenFlags State = WSLAOpenFlagsUnknown;
- int32_t Code = -1;
- int32_t Errno = -1;
- PRETTY_PRINT(FIELD(Header), FIELD(State), FIELD(Code), FIELD(Errno));
-};
-
-struct WSLA_WAITPID
-{
- static inline auto Type = LxMessageWSLAWaitPid;
- using TResponse = WSLA_WAITPID_RESULT;
-
- DECLARE_MESSAGE_CTOR(WSLA_WAITPID);
-
- MESSAGE_HEADER Header;
- int32_t Pid = -1;
- uint64_t TimeoutMs = 0;
-
- PRETTY_PRINT(FIELD(Header), FIELD(Pid), FIELD(TimeoutMs));
-};
-
struct WSLA_SIGNAL
{
static inline auto Type = LxMessageWSLASignal;
@@ -1775,17 +1746,6 @@ struct WSLA_SIGNAL
PRETTY_PRINT(FIELD(Header), FIELD(Pid), FIELD(Signal));
};
-struct WSLA_SHUTDOWN
-{
- static inline auto Type = LxMessageWSLAShutdown;
- using TResponse = RESULT_MESSAGE;
-
- DECLARE_MESSAGE_CTOR(WSLA_SHUTDOWN);
- MESSAGE_HEADER Header;
-
- PRETTY_PRINT(FIELD(Header));
-};
-
struct WSLA_MAP_PORT
{
static inline auto Type = LxMessageWSLAMapPort;
diff --git a/src/windows/common/WslClient.cpp b/src/windows/common/WslClient.cpp
index 3b7b8ba05..7f713854b 100644
--- a/src/windows/common/WslClient.cpp
+++ b/src/windows/common/WslClient.cpp
@@ -1596,7 +1596,6 @@ int WslaShell(_In_ std::wstring_view commandLine)
THROW_IF_FAILED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession)));
wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
- wil::com_ptr virtualMachine;
wil::com_ptr session;
if (!rootVhdOverride.empty())
@@ -1623,8 +1622,7 @@ int WslaShell(_In_ std::wstring_view commandLine)
}
else
{
- THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, &session));
- THROW_IF_FAILED(session->GetVirtualMachine(&virtualMachine));
+ THROW_IF_FAILED(userSession->CreateSession(&sessionSettings, WSLASessionFlagsNone, &session));
wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
}
diff --git a/src/windows/wslaservice/exe/WSLAContainer.cpp b/src/windows/wslaservice/exe/WSLAContainer.cpp
index 330c650fd..dc802d940 100644
--- a/src/windows/wslaservice/exe/WSLAContainer.cpp
+++ b/src/windows/wslaservice/exe/WSLAContainer.cpp
@@ -58,12 +58,11 @@ auto ProcessPortMappings(const WSLA_CONTAINER_OPTIONS& options, WSLAVirtualMachi
{
if (e.MappedToHost)
{
- LOG_IF_FAILED_MSG(
- vm.MapPort(e.Family, e.HostPort, e.VmPort, true),
- "Failed to unmap port (family=%i, guestPort=%u, hostPort=%u)",
- e.Family,
- e.VmPort,
- e.HostPort);
+ try
+ {
+ vm.MapPort(e.Family, e.HostPort, e.VmPort, true);
+ }
+ CATCH_LOG();
}
}
});
@@ -116,7 +115,7 @@ auto ProcessPortMappings(const WSLA_CONTAINER_OPTIONS& options, WSLAVirtualMachi
// Map Windows <-> VM ports.
for (auto& e : *mappedPorts)
{
- THROW_IF_FAILED(vm.MapPort(e.Family, e.HostPort, e.VmPort, false));
+ vm.MapPort(e.Family, e.HostPort, e.VmPort, false);
e.MappedToHost = true;
}
@@ -201,12 +200,11 @@ WSLAContainerImpl::~WSLAContainerImpl()
{
WI_ASSERT(e.MappedToHost);
- LOG_IF_FAILED_MSG(
- m_parentVM->MapPort(e.Family, e.HostPort, e.VmPort, true),
- "Failed to delete port mapping (family=%i, guestPort=%u, hostPort=%u)",
- e.Family,
- e.VmPort,
- e.HostPort);
+ try
+ {
+ m_parentVM->MapPort(e.Family, e.HostPort, e.VmPort, true);
+ }
+ CATCH_LOG();
allocatedGuestPorts.insert(e.VmPort);
}
diff --git a/src/windows/wslaservice/exe/WSLAProcessControl.cpp b/src/windows/wslaservice/exe/WSLAProcessControl.cpp
index 4632aa1fc..65e27a28f 100644
--- a/src/windows/wslaservice/exe/WSLAProcessControl.cpp
+++ b/src/windows/wslaservice/exe/WSLAProcessControl.cpp
@@ -181,7 +181,7 @@ void VMProcessControl::Signal(int Signal)
std::lock_guard lock{m_lock};
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), m_vm == nullptr || m_exitEvent.is_signaled());
- THROW_IF_FAILED(m_vm->Signal(m_pid, Signal));
+ m_vm->Signal(m_pid, Signal);
}
void VMProcessControl::ResizeTty(ULONG Rows, ULONG Columns)
diff --git a/src/windows/wslaservice/exe/WSLASession.cpp b/src/windows/wslaservice/exe/WSLASession.cpp
index 08e4d299d..ef1e2318b 100644
--- a/src/windows/wslaservice/exe/WSLASession.cpp
+++ b/src/windows/wslaservice/exe/WSLASession.cpp
@@ -46,7 +46,7 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs
{
WSL_LOG("SessionCreated", TraceLoggingValue(m_displayName.c_str(), "DisplayName"));
- m_virtualMachine = wil::MakeOrThrow(CreateVmSettings(Settings), userSessionImpl.GetUserSid());
+ m_virtualMachine.emplace(CreateVmSettings(Settings), userSessionImpl.GetUserSid());
if (Settings.TerminationCallback != nullptr)
{
@@ -74,7 +74,7 @@ WSLASession::WSLASession(ULONG id, const WSLA_SESSION_SETTINGS& Settings, WSLAUs
{"/usr/bin/dockerd" /*, "--debug"*/}, // TODO: Flag for --debug.
{{"PATH=/bin:/usr/local/sbin:/usr/bin:/usr/sbin:/sbin"}},
common::ProcessFlags::Stdout | common::ProcessFlags::Stderr};
- m_containerdThread = std::thread(&WSLASession::MonitorContainerd, this, launcher.Launch(*m_virtualMachine.Get()));
+ m_containerdThread = std::thread(&WSLASession::MonitorContainerd, this, launcher.Launch(*m_virtualMachine));
// Wait for containerd to be ready before starting the event tracker.
// TODO: Configurable timeout.
@@ -134,34 +134,7 @@ WSLASession::~WSLASession()
{
WSL_LOG("SessionTerminated", TraceLoggingValue(m_displayName.c_str(), "DisplayName"));
- std::lock_guard lock{m_lock};
-
- // Stop the event tracker
- if (m_eventTracker.has_value())
- {
- m_eventTracker->Stop();
- }
-
- // This will delete all containers. Needs to be done before the VM is terminated.
- // TODO: If callers still have references to containers, the instances won't actually be deleted.
- m_containers.clear();
-
- m_sessionTerminatingEvent.SetEvent();
-
- // N.B. The containerd thread can only run if the VM is running.
- if (m_containerdThread.joinable())
- {
- m_containerdThread.join();
- }
-
- if (m_virtualMachine)
- {
- // N.B. containerd has exited by this point, so unmounting the VHD is safe since no container can be running.
- LOG_IF_FAILED(m_virtualMachine->Unmount(c_containerdStorage));
- m_virtualMachine->OnSessionTerminated();
-
- m_virtualMachine.Reset();
- }
+ Terminate();
}
void WSLASession::ConfigureStorage(const WSLA_SESSION_SETTINGS& Settings, PSID UserSid)
@@ -484,7 +457,7 @@ try
containerOptions->Name,
WSLAContainerImpl::Create(
*containerOptions,
- *m_virtualMachine.Get(),
+ *m_virtualMachine,
std::bind(&WSLASession::OnContainerDeleted, this, std::placeholders::_1),
m_eventTracker.value(),
m_dockerClient.value()));
@@ -535,15 +508,6 @@ try
}
CATCH_RETURN();
-HRESULT WSLASession::GetVirtualMachine(IWSLAVirtualMachine** VirtualMachine)
-{
- std::lock_guard lock{m_lock};
- THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine);
-
- THROW_IF_FAILED(m_virtualMachine->QueryInterface(__uuidof(IWSLAVirtualMachine), (void**)VirtualMachine));
- return S_OK;
-}
-
HRESULT WSLASession::CreateRootNamespaceProcess(const WSLA_PROCESS_OPTIONS* Options, IWSLAProcess** Process, int* Errno)
try
{
@@ -555,7 +519,10 @@ try
std::lock_guard lock{m_lock};
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine);
- return m_virtualMachine->CreateLinuxProcess(Options, Process, Errno);
+ auto process = m_virtualMachine->CreateLinuxProcess(*Options, Errno);
+ THROW_IF_FAILED(process.CopyTo(Process));
+
+ return S_OK;
}
CATCH_RETURN();
@@ -563,7 +530,7 @@ void WSLASession::Ext4Format(const std::string& Device)
{
constexpr auto mkfsPath = "/usr/sbin/mkfs.ext4";
ServiceProcessLauncher launcher(mkfsPath, {mkfsPath, Device});
- auto result = launcher.Launch(*m_virtualMachine.Get()).WaitAndCaptureOutput();
+ auto result = launcher.Launch(*m_virtualMachine).WaitAndCaptureOutput();
THROW_HR_IF_MSG(E_FAIL, result.Code != 0, "%hs", launcher.FormatResult(result).c_str());
}
@@ -590,25 +557,80 @@ try
CATCH_RETURN();
void WSLASession::OnUserSessionTerminating()
+{
+ LOG_IF_FAILED(Terminate());
+}
+
+HRESULT WSLASession::Terminate()
+try
{
// m_sessionTerminatingEvent is always valid, so it can be signalled with the lock.
// This allows a session to be unblocked if a stuck operation is holding the lock.
m_sessionTerminatingEvent.SetEvent();
std::lock_guard lock{m_lock};
+
+ // Stop the event tracker
+ if (m_eventTracker.has_value())
+ {
+ m_eventTracker->Stop();
+ }
+
+ // This will delete all containers. Needs to be done before the VM is terminated.
+ m_containers.clear();
+
m_dockerClient.reset();
- m_virtualMachine.Reset();
+
+ // N.B. The containerd thread can only run if the VM is running.
+ if (m_containerdThread.joinable())
+ {
+ m_containerdThread.join();
+ }
+
+ if (m_virtualMachine)
+ {
+ // N.B. containerd has exited by this point, so unmounting the VHD is safe since no container can be running.
+ try
+ {
+ m_virtualMachine->Unmount(c_containerdStorage);
+ }
+ CATCH_LOG();
+
+ m_virtualMachine->OnSessionTerminated();
+ m_virtualMachine.reset();
+ }
+
+ return S_OK;
}
+CATCH_RETURN();
-HRESULT WSLASession::Shutdown(ULONG Timeout)
+HRESULT WSLASession::MountWindowsFolder(LPCWSTR WindowsPath, LPCSTR LinuxPath, BOOL ReadOnly)
try
{
std::lock_guard lock{m_lock};
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine);
- THROW_IF_FAILED(m_virtualMachine->Shutdown(Timeout));
+ return m_virtualMachine->MountWindowsFolder(WindowsPath, LinuxPath, ReadOnly);
+}
+CATCH_RETURN();
- m_virtualMachine.Reset();
+HRESULT WSLASession::UnmountWindowsFolder(LPCSTR LinuxPath)
+try
+{
+ std::lock_guard lock{m_lock};
+ THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine);
+
+ return m_virtualMachine->UnmountWindowsFolder(LinuxPath);
+}
+CATCH_RETURN();
+
+HRESULT WSLASession::MapVmPort(int Family, short WindowsPort, short LinuxPort, BOOL Remove)
+try
+{
+ std::lock_guard lock{m_lock};
+ THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_virtualMachine);
+
+ m_virtualMachine->MapPort(Family, WindowsPort, LinuxPort, Remove != FALSE);
return S_OK;
}
CATCH_RETURN();
@@ -627,4 +649,10 @@ HRESULT WSLASession::GetImplNoRef(_Out_ WSLASession** Session)
// beyond that lifetime.
*Session = this;
return S_OK;
+}
+
+bool WSLASession::Terminated()
+{
+ std::lock_guard lock{m_lock};
+ return !m_virtualMachine;
}
\ No newline at end of file
diff --git a/src/windows/wslaservice/exe/WSLASession.h b/src/windows/wslaservice/exe/WSLASession.h
index 4e1792d38..8cb5f5a54 100644
--- a/src/windows/wslaservice/exe/WSLASession.h
+++ b/src/windows/wslaservice/exe/WSLASession.h
@@ -62,18 +62,24 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
IFACEMETHOD(ListContainers)(_Out_ WSLA_CONTAINER** Images, _Out_ ULONG* Count) override;
// VM management.
- IFACEMETHOD(GetVirtualMachine)(IWSLAVirtualMachine** VirtualMachine) override;
IFACEMETHOD(CreateRootNamespaceProcess)(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** VirtualMachine, _Out_ int* Errno) override;
// Disk management.
IFACEMETHOD(FormatVirtualDisk)(_In_ LPCWSTR Path) override;
- IFACEMETHOD(Shutdown(_In_ ULONG)) override;
+ IFACEMETHOD(Terminate()) override;
IFACEMETHOD(GetImplNoRef)(_Out_ WSLASession** Session) override;
+ // Testing.
+ IFACEMETHOD(MountWindowsFolder)(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly) override;
+ IFACEMETHOD(UnmountWindowsFolder)(_In_ LPCSTR LinuxPath) override;
+ IFACEMETHOD(MapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove) override;
+
void OnUserSessionTerminating();
+ bool Terminated();
+
private:
ULONG m_id = 0;
@@ -88,7 +94,7 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
WSLA_SESSION_SETTINGS m_sessionSettings; // TODO: Revisit to see if we should have session settings as a member or not
std::optional m_dockerClient;
- Microsoft::WRL::ComPtr m_virtualMachine;
+ std::optional m_virtualMachine;
std::optional m_eventTracker;
wil::unique_event m_containerdReadyEvent{wil::EventOptions::ManualReset};
std::thread m_containerdThread;
diff --git a/src/windows/wslaservice/exe/WSLAUserSession.cpp b/src/windows/wslaservice/exe/WSLAUserSession.cpp
index 0e006e79b..7ae9f307b 100644
--- a/src/windows/wslaservice/exe/WSLAUserSession.cpp
+++ b/src/windows/wslaservice/exe/WSLAUserSession.cpp
@@ -34,26 +34,49 @@ PSID WSLAUserSessionImpl::GetUserSid() const
return m_tokenInfo->User.Sid;
}
-HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession)
+HRESULT WSLAUserSessionImpl::CreateSession(const WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, IWSLASession** WslaSession)
+try
{
ULONG id = m_nextSessionId++;
- auto session = wil::MakeOrThrow(id, *Settings, *this);
- Microsoft::WRL::ComPtr weakRef;
- THROW_IF_FAILED(session->GetWeakReference(&weakRef));
+ std::lock_guard lock(m_wslaSessionsLock);
+
+ // Check for an existing session first.
+ auto result = ForEachSession([&](auto& session) -> std::optional {
+ // TODO: ACL check.
+ if (session.DisplayName() == Settings->DisplayName)
+ {
+ RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS), WI_IsFlagClear(Flags, WSLASessionFlagsOpenExisting));
+
+ return session.QueryInterface(__uuidof(IWSLASession), (void**)WslaSession);
+ }
+
+ return std::optional{};
+ });
+ if (result.has_value())
{
- std::lock_guard lock(m_wslaSessionsLock);
- m_sessions.emplace_back(std::move(weakRef));
+ return result.value();
}
- // Client now owns the session.
- // TODO: Add a flag for the client to specify that the session should outlive its process.
+ // No session was found, create a new one.
+ auto session = wil::MakeOrThrow(id, *Settings, *this);
+
+ if (WI_IsFlagSet(Flags, WSLASessionFlagsPersistent))
+ {
+ m_persistentSessions.push_back(session);
+ }
+
+ Microsoft::WRL::ComPtr weakRef;
+ THROW_IF_FAILED(session->GetWeakReference(&weakRef));
+
+ m_sessions.emplace_back(std::move(weakRef));
THROW_IF_FAILED(session.CopyTo(__uuidof(IWSLASession), (void**)WslaSession));
return S_OK;
}
+CATCH_RETURN();
HRESULT WSLAUserSessionImpl::OpenSessionByName(LPCWSTR DisplayName, IWSLASession** Session)
{
@@ -113,13 +136,13 @@ HRESULT wsl::windows::service::wsla::WSLAUserSession::GetVersion(_Out_ WSLA_VERS
return S_OK;
}
-HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession)
+HRESULT wsl::windows::service::wsla::WSLAUserSession::CreateSession(const WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, IWSLASession** WslaSession)
try
{
auto session = m_session.lock();
RETURN_HR_IF(RPC_E_DISCONNECTED, !session);
- return session->CreateSession(Settings, WslaSession);
+ return session->CreateSession(Settings, Flags, WslaSession);
}
CATCH_RETURN();
diff --git a/src/windows/wslaservice/exe/WSLAUserSession.h b/src/windows/wslaservice/exe/WSLAUserSession.h
index b1f07a4b9..74e5aabad 100644
--- a/src/windows/wslaservice/exe/WSLAUserSession.h
+++ b/src/windows/wslaservice/exe/WSLAUserSession.h
@@ -33,7 +33,7 @@ class WSLAUserSessionImpl
PSID GetUserSid() const;
- HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, IWSLASession** WslaSession);
+ HRESULT CreateSession(const WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, IWSLASession** WslaSession);
HRESULT OpenSessionByName(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session);
HRESULT ListSessions(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount);
@@ -57,6 +57,18 @@ class WSLAUserSessionImpl
WSLASession* SessionImpl{};
THROW_IF_FAILED(lockedSession->GetImplNoRef(&SessionImpl));
+ // If the session is terminated, drop its reference so it can be deleted (in case of persistent sessions)
+ if (SessionImpl->Terminated())
+ {
+ auto remove =
+ std::ranges::remove_if(m_persistentSessions, [&](const auto& e) { return SessionImpl->GetId() == e->GetId(); });
+
+ WI_ASSERT(remove.end() - remove.begin() <= 1);
+
+ m_persistentSessions.erase(remove.begin(), remove.end());
+ return true;
+ }
+
if constexpr (std::is_same_v)
{
Routine(*SessionImpl);
@@ -90,6 +102,8 @@ class WSLAUserSessionImpl
std::atomic m_nextSessionId{1};
std::recursive_mutex m_wslaSessionsLock;
+ // Persistent sessions that outlive their creating process.
+ std::vector> m_persistentSessions;
std::vector> m_sessions;
};
@@ -102,7 +116,7 @@ class DECLSPEC_UUID("a9b7a1b9-0671-405c-95f1-e0612cb4ce8f") WSLAUserSession
WSLAUserSession& operator=(const WSLAUserSession&) = delete;
IFACEMETHOD(GetVersion)(_Out_ WSLA_VERSION* Version) override;
- IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, IWSLASession** WslaSession) override;
+ IFACEMETHOD(CreateSession)(const WSLA_SESSION_SETTINGS* WslaSessionSettings, WSLASessionFlags Flags, IWSLASession** WslaSession) override;
IFACEMETHOD(ListSessions)(_Out_ WSLA_SESSION_INFORMATION** Sessions, _Out_ ULONG* SessionsCount) override;
IFACEMETHOD(OpenSession)(_In_ ULONG Id, _Out_ IWSLASession** Session) override;
IFACEMETHOD(OpenSessionByName)(_In_ LPCWSTR DisplayName, _Out_ IWSLASession** Session) override;
diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
index 659c0550d..370db51fc 100644
--- a/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
+++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.cpp
@@ -736,8 +736,7 @@ std::pair WSLAVirtualMachine::AttachDisk(_In_ PCWSTR Path, _
return {Lun, Device};
}
-HRESULT WSLAVirtualMachine::Unmount(_In_ const char* Path)
-try
+void WSLAVirtualMachine::Unmount(_In_ const char* Path)
{
auto [pid, _, subChannel] = Fork(WSLA_FORK::Thread);
@@ -749,10 +748,7 @@ try
// TODO: Return errno to caller
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), response.Result == EINVAL);
THROW_HR_IF(E_FAIL, response.Result != 0);
-
- return S_OK;
}
-CATCH_RETURN()
void WSLAVirtualMachine::DetachDisk(_In_ ULONG Lun)
{
@@ -859,15 +855,6 @@ void WSLAVirtualMachine::OpenLinuxFile(wsl::shared::SocketChannel& Channel, cons
THROW_HR_IF_MSG(E_FAIL, result != 0, "Failed to open %hs (flags: %u), %i", Path, Flags, result);
}
-HRESULT WSLAVirtualMachine::CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno)
-try
-{
- CreateLinuxProcess(*Options, Errno).CopyTo(Process);
-
- return S_OK;
-}
-CATCH_RETURN();
-
Microsoft::WRL::ComPtr WSLAVirtualMachine::CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS& Options, int* Errno, const TPrepareCommandLine& PrepareCommandLine)
{
// N.B This check is there to prevent processes from being started before the VM is done initializing.
@@ -1046,46 +1033,7 @@ int32_t WSLAVirtualMachine::ExpectClosedChannelOrError(wsl::shared::SocketChanne
}
}
-HRESULT WSLAVirtualMachine::WaitPid(LONG Pid, ULONGLONG TimeoutMs, ULONG* State, int* Code)
-try
-{
- auto [pid, _, subChannel] = Fork(WSLA_FORK::Thread);
-
- WSLA_WAITPID message{};
- message.Pid = Pid;
- message.TimeoutMs = TimeoutMs;
-
- const auto& response = subChannel.Transaction(message);
-
- THROW_HR_IF(E_FAIL, response.State == WSLAOpenFlagsUnknown);
-
- *State = response.State;
- *Code = response.Code;
-
- return S_OK;
-}
-CATCH_RETURN();
-
-HRESULT WSLAVirtualMachine::Shutdown(ULONGLONG TimeoutMs)
-try
-{
- std::lock_guard lock(m_lock);
-
- THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running);
-
- WSLA_SHUTDOWN message{};
- m_initChannel.SendMessage(message);
- auto response = m_initChannel.ReceiveMessageOrClosed(static_cast(TimeoutMs));
-
- RETURN_HR_IF(E_UNEXPECTED, response.first != nullptr);
-
- m_running = false;
- return S_OK;
-}
-CATCH_RETURN();
-
-HRESULT WSLAVirtualMachine::Signal(_In_ LONG Pid, _In_ int Signal)
-try
+void WSLAVirtualMachine::Signal(_In_ LONG Pid, _In_ int Signal)
{
std::lock_guard lock(m_lock);
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_running);
@@ -1095,10 +1043,8 @@ try
message.Signal = Signal;
const auto& response = m_initChannel.Transaction(message);
- RETURN_HR_IF(E_FAIL, response.Result != 0);
- return S_OK;
+ THROW_HR_IF(E_FAIL, response.Result != 0);
}
-CATCH_RETURN();
void WSLAVirtualMachine::RegisterCallback(ITerminationCallback* callback)
{
@@ -1192,12 +1138,11 @@ void WSLAVirtualMachine::LaunchPortRelay()
writePipe.release();
}
-HRESULT WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)
-try
+void WSLAVirtualMachine::MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)
{
std::lock_guard lock(m_portRelaylock);
- RETURN_HR_IF(E_ILLEGAL_STATE_CHANGE, !m_portRelayChannelWrite);
+ THROW_HR_IF(E_ILLEGAL_STATE_CHANGE, !m_portRelayChannelWrite);
WSLA_MAP_PORT message;
message.WindowsPort = WindowsPort;
@@ -1213,10 +1158,8 @@ try
THROW_IF_WIN32_BOOL_FALSE(ReadFile(m_portRelayChannelRead.get(), &result, sizeof(result), &bytesTransfered, nullptr));
THROW_HR_IF(E_UNEXPECTED, bytesTransfered != sizeof(result));
-
- return result;
+ THROW_IF_FAILED_MSG(result, "Failed to map port: WindowsPort=%d, LinuxPort=%d, Family=%d, Remove=%d", WindowsPort, LinuxPort, Family, Remove);
}
-CATCH_RETURN();
HRESULT WSLAVirtualMachine::MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)
{
@@ -1332,7 +1275,6 @@ void WSLAVirtualMachine::RemoveShare(_In_ const MountedFolderInfo& MountInfo)
}
HRESULT WSLAVirtualMachine::UnmountWindowsFolder(_In_ LPCSTR LinuxPath)
-try
{
std::lock_guard lock(m_lock);
@@ -1341,7 +1283,7 @@ try
THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_NOT_FOUND), it == m_mountedWindowsFolders.end());
// Unmount the folder from the guest. If the mount is not found, this most likely means that the guest unmounted it.
- auto result = Unmount(LinuxPath);
+ auto result = wil::ResultFromException([&]() { Unmount(LinuxPath); });
THROW_HR_IF(result, FAILED(result) && result != HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
auto mountInfo = it->second;
@@ -1352,7 +1294,6 @@ try
return S_OK;
}
-CATCH_RETURN();
void WSLAVirtualMachine::MountGpuLibraries(_In_ LPCSTR LibrariesMountPoint, _In_ LPCSTR DriversMountpoint)
{
diff --git a/src/windows/wslaservice/exe/WSLAVirtualMachine.h b/src/windows/wslaservice/exe/WSLAVirtualMachine.h
index 6ac41c3d8..8a2087f5f 100644
--- a/src/windows/wslaservice/exe/WSLAVirtualMachine.h
+++ b/src/windows/wslaservice/exe/WSLAVirtualMachine.h
@@ -34,8 +34,7 @@ enum WSLAMountFlags
class WSLAUserSessionImpl;
-class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
- : public Microsoft::WRL::RuntimeClass, IWSLAVirtualMachine, IFastRundown>
+class WSLAVirtualMachine
{
public:
@@ -73,14 +72,12 @@ class DECLSPEC_UUID("0CFC5DC1-B6A7-45FC-8034-3FA9ED73CE30") WSLAVirtualMachine
void Start();
void OnSessionTerminated();
- IFACEMETHOD(CreateLinuxProcess(_In_ const WSLA_PROCESS_OPTIONS* Options, _Out_ IWSLAProcess** Process, _Out_ int* Errno)) override;
- IFACEMETHOD(WaitPid(_In_ LONG Pid, _In_ ULONGLONG TimeoutMs, _Out_ ULONG* State, _Out_ int* Code)) override;
- IFACEMETHOD(Signal(_In_ LONG Pid, _In_ int Signal)) override;
- IFACEMETHOD(Shutdown(ULONGLONG _In_ TimeoutMs)) override;
- IFACEMETHOD(MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove)) override;
- IFACEMETHOD(Unmount(_In_ const char* Path)) override;
- IFACEMETHOD(MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly)) override;
- IFACEMETHOD(UnmountWindowsFolder(_In_ LPCSTR LinuxPath)) override;
+ void MapPort(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort, _In_ BOOL Remove);
+ void Unmount(_In_ const char* Path);
+
+ HRESULT MountWindowsFolder(_In_ LPCWSTR WindowsPath, _In_ LPCSTR LinuxPath, _In_ BOOL ReadOnly);
+ HRESULT UnmountWindowsFolder(_In_ LPCSTR LinuxPath);
+ void Signal(_In_ LONG Pid, _In_ int Signal);
void OnProcessReleased(int Pid);
void RegisterCallback(_In_ ITerminationCallback* callback);
diff --git a/src/windows/wslaservice/inc/wslaservice.idl b/src/windows/wslaservice/inc/wslaservice.idl
index 517b61f64..7722b43a7 100644
--- a/src/windows/wslaservice/inc/wslaservice.idl
+++ b/src/windows/wslaservice/inc/wslaservice.idl
@@ -236,25 +236,6 @@ interface IWSLAProcess : IUnknown
// Note: the SDK can offer a convenience Wait() method, but that doesn't need to be part of the service API.
}
-
-// TODO: Delete once the new API is wired.
-[
- uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8761),
- pointer_default(unique),
- object
-]
-interface IWSLAVirtualMachine : IUnknown
-{
- HRESULT CreateLinuxProcess([in] const struct WSLA_PROCESS_OPTIONS* Options, [out] IWSLAProcess** Process, [out] int* Errno);
- HRESULT WaitPid([in] LONG Pid, [in] ULONGLONG TimeoutMs, [out] ULONG* State, [out] int* Code);
- HRESULT Signal([in] LONG Pid, [in] int Signal);
- HRESULT Shutdown([in] ULONGLONG TimeoutMs);
- HRESULT MapPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
- HRESULT Unmount([in] LPCSTR Path);
- HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
- HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
-}
-
typedef enum _WSLANetworkingMode
{
WSLANetworkingModeNone,
@@ -338,10 +319,12 @@ interface IWSLASession : IUnknown
HRESULT FormatVirtualDisk([in] LPCWSTR Path);
// Terminate the VM and containers.
- HRESULT Shutdown([in] ULONG TimeoutMs);
+ HRESULT Terminate();
- // To be deleted.
- HRESULT GetVirtualMachine([out] IWSLAVirtualMachine **VirtualMachine);
+ // Used only for testing. TODO: Think about moving them to a dedicated testing-only interface.
+ HRESULT MountWindowsFolder([in] LPCWSTR WindowsPath, [in] LPCSTR LinuxPath, [in] BOOL ReadOnly);
+ HRESULT UnmountWindowsFolder([in] LPCSTR LinuxPath);
+ HRESULT MapVmPort([in] int Family, [in] short WindowsPort, [in] short LinuxPort, [in] BOOL Remove);
}
struct WSLA_SESSION_INFORMATION
@@ -351,6 +334,13 @@ struct WSLA_SESSION_INFORMATION
wchar_t DisplayName[256];
};
+typedef enum _WSLASessionFlags
+{
+ WSLASessionFlagsNone = 0,
+ WSLASessionFlagsPersistent = 1, // Session remains active after its COM reference is released.
+ WSLASessionFlagsOpenExisting = 2 // Open an existing session if the name is in use.
+} WSLASessionFlags;
+
[
uuid(82A7ABC8-6B50-43FC-AB96-15FBBE7E8760),
pointer_default(unique),
@@ -361,7 +351,7 @@ interface IWSLAUserSession : IUnknown
HRESULT GetVersion([out] WSLA_VERSION* Version);
// Session managment.
- HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, [out]IWSLASession** Session);
+ HRESULT CreateSession([in] const struct WSLA_SESSION_SETTINGS* Settings, WSLASessionFlags Flags, [out]IWSLASession** Session);
HRESULT ListSessions([out, size_is(, *SessionsCount)] struct WSLA_SESSION_INFORMATION** Sessions, [out] ULONG* SessionsCount);
HRESULT OpenSession([in] ULONG Id, [out]IWSLASession** Session);
HRESULT OpenSessionByName([in] LPCWSTR DisplayName, [out] IWSLASession** Session);
diff --git a/test/windows/WSLATests.cpp b/test/windows/WSLATests.cpp
index 0a49c5d4e..ab90c2636 100644
--- a/test/windows/WSLATests.cpp
+++ b/test/windows/WSLATests.cpp
@@ -87,7 +87,7 @@ class WSLATests
return settings;
}
- wil::com_ptr CreateSession(const WSLA_SESSION_SETTINGS& sessionSettings = GetDefaultSessionSettings())
+ wil::com_ptr CreateSession(const WSLA_SESSION_SETTINGS& sessionSettings = GetDefaultSessionSettings(), WSLASessionFlags Flags = WSLASessionFlagsNone)
{
wil::com_ptr userSession;
VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession)));
@@ -95,7 +95,7 @@ class WSLATests
wil::com_ptr session;
- VERIFY_SUCCEEDED(userSession->CreateSession(&sessionSettings, &session));
+ VERIFY_SUCCEEDED(userSession->CreateSession(&sessionSettings, Flags, &session));
wsl::windows::common::security::ConfigureForCOMImpersonation(session.get());
return session;
@@ -208,8 +208,7 @@ class WSLATests
wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
- wil::com_ptr session;
- VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session));
+ wil::com_ptr session = CreateSession(settings);
// Act: list sessions
{
@@ -226,9 +225,8 @@ class WSLATests
// List multiple sessions.
{
- wil::com_ptr session2;
settings.DisplayName = L"wsla-test-list-2";
- VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &session2));
+ auto session2 = CreateSession(settings);
wil::unique_cotaskmem_array_ptr sessions;
VERIFY_SUCCEEDED(userSession->ListSessions(&sessions, sessions.size_address()));
@@ -261,8 +259,7 @@ class WSLATests
wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
- wil::com_ptr created;
- VERIFY_SUCCEEDED(userSession->CreateSession(&settings, &created));
+ wil::com_ptr created = CreateSession(settings);
// Act: open by the same display name
wil::com_ptr opened;
@@ -409,7 +406,7 @@ class WSLATests
auto session = CreateSession(settings);
auto detach = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
- session->Shutdown(30 * 1000);
+ session.reset();
if (thread.joinable())
{
thread.join();
@@ -420,7 +417,7 @@ class WSLATests
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo DmesgTest > /dev/kmsg"}, 0);
- VERIFY_ARE_EQUAL(session->Shutdown(30 * 1000), S_OK);
+ session.reset();
detach.reset();
auto contentString = std::string(dmesgContent.begin(), dmesgContent.end());
@@ -487,9 +484,7 @@ class WSLATests
auto session = CreateSession(sessionSettings);
- wil::com_ptr vm;
- VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm));
- VERIFY_SUCCEEDED(vm->Shutdown(30 * 1000));
+ session.reset();
auto future = promise.get_future();
auto result = future.wait_for(std::chrono::seconds(30));
auto [reason, details] = future.get();
@@ -760,9 +755,6 @@ class WSLATests
auto session = CreateSession(settings);
- wil::com_ptr vm;
- VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm));
-
auto listen = [&](short port, const char* content, bool ipv6) {
auto cmd = std::format("echo -n '{}' | /usr/bin/socat -dd TCP{}-LISTEN:{},reuseaddr -", content, ipv6 ? "6" : "", port);
auto process = WSLAProcessLauncher("/bin/sh", {"/bin/sh", "-c", cmd}).Launch(*session);
@@ -796,10 +788,10 @@ class WSLATests
};
// Map port
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, false));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, false));
// Validate that the same port can't be bound twice
- VERIFY_ARE_EQUAL(vm->MapPort(AF_INET, 1234, 80, false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS));
+ VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80, false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS));
// Check simple case
listen(80, "port80", false);
@@ -813,34 +805,34 @@ class WSLATests
expectContent(1234, AF_INET, "");
// Add a ipv6 binding
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1234, 80, false));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80, false));
// Validate that ipv6 bindings work as well.
listen(80, "port80ipv6", true);
expectContent(1234, AF_INET6, "port80ipv6");
// Unmap the ipv4 port
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, true));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, true));
// Verify that a proper error is returned if the mapping doesn't exist
- VERIFY_ARE_EQUAL(vm->MapPort(AF_INET, 1234, 80, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
+ VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET, 1234, 80, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
// Unmap the v6 port
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1234, 80, true));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1234, 80, true));
// Map another port as v6 only
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1235, 81, false));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81, false));
listen(81, "port81ipv6", true);
expectContent(1235, AF_INET6, "port81ipv6");
expectNotBound(1235, AF_INET);
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET6, 1235, 81, true));
- VERIFY_ARE_EQUAL(vm->MapPort(AF_INET6, 1235, 81, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET6, 1235, 81, true));
+ VERIFY_ARE_EQUAL(session->MapVmPort(AF_INET6, 1235, 81, true), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
expectNotBound(1235, AF_INET6);
// Create a forking relay and stress test
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, false));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, false));
auto process =
WSLAProcessLauncher{"/usr/bin/socat", {"/usr/bin/socat", "-dd", "TCP-LISTEN:80,fork,reuseaddr", "system:'echo -n OK'"}}
@@ -853,7 +845,7 @@ class WSLATests
expectContent(1234, AF_INET, "OK");
}
- VERIFY_SUCCEEDED(vm->MapPort(AF_INET, 1234, 80, true));
+ VERIFY_SUCCEEDED(session->MapVmPort(AF_INET, 1234, 80, true));
}
TEST_METHOD(StuckVmTermination)
@@ -876,10 +868,6 @@ class WSLATests
auto session = CreateSession(settings);
- wil::com_ptr vm;
- VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm));
- wsl::windows::common::security::ConfigureForCOMImpersonation(vm.get());
-
auto expectedMountOptions = [&](bool readOnly) -> std::string {
if (enableVirtioFs)
{
@@ -898,45 +886,45 @@ class WSLATests
// Validate writeable mount.
{
- VERIFY_SUCCEEDED(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", false));
+ VERIFY_SUCCEEDED(session->MountWindowsFolder(testFolder.c_str(), "/win-path", false));
ExpectMount(session.get(), "/win-path", expectedMountOptions(false));
// Validate that mount can't be stacked on each other
- VERIFY_ARE_EQUAL(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS));
+ VERIFY_ARE_EQUAL(session->MountWindowsFolder(testFolder.c_str(), "/win-path", false), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS));
// Validate that folder is writeable from linux
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo -n content > /win-path/file.txt && sync"}, 0);
VERIFY_ARE_EQUAL(ReadFileContent(testFolder / "file.txt"), L"content");
- VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path"));
+ VERIFY_SUCCEEDED(session->UnmountWindowsFolder("/win-path"));
ExpectMount(session.get(), "/win-path", {});
}
// Validate read-only mount.
{
- VERIFY_SUCCEEDED(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", true));
+ VERIFY_SUCCEEDED(session->MountWindowsFolder(testFolder.c_str(), "/win-path", true));
ExpectMount(session.get(), "/win-path", expectedMountOptions(true));
// Validate that folder is not writeable from linux
ExpectCommandResult(session.get(), {"/bin/sh", "-c", "echo -n content > /win-path/file.txt"}, 1);
- VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path"));
+ VERIFY_SUCCEEDED(session->UnmountWindowsFolder("/win-path"));
ExpectMount(session.get(), "/win-path", {});
}
// Validate various error paths
{
- VERIFY_ARE_EQUAL(vm->MountWindowsFolder(L"relative-path", "/win-path", true), E_INVALIDARG);
- VERIFY_ARE_EQUAL(vm->MountWindowsFolder(L"C:\\does-not-exist", "/win-path", true), HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND));
- VERIFY_ARE_EQUAL(vm->UnmountWindowsFolder("/not-mounted"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
- VERIFY_ARE_EQUAL(vm->UnmountWindowsFolder("/proc"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
+ VERIFY_ARE_EQUAL(session->MountWindowsFolder(L"relative-path", "/win-path", true), E_INVALIDARG);
+ VERIFY_ARE_EQUAL(session->MountWindowsFolder(L"C:\\does-not-exist", "/win-path", true), HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND));
+ VERIFY_ARE_EQUAL(session->UnmountWindowsFolder("/not-mounted"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
+ VERIFY_ARE_EQUAL(session->UnmountWindowsFolder("/proc"), HRESULT_FROM_WIN32(ERROR_NOT_FOUND));
// Validate that folders that are manually unmounted from the guest are handled properly
- VERIFY_SUCCEEDED(vm->MountWindowsFolder(testFolder.c_str(), "/win-path", true));
+ VERIFY_SUCCEEDED(session->MountWindowsFolder(testFolder.c_str(), "/win-path", true));
ExpectMount(session.get(), "/win-path", expectedMountOptions(true));
ExpectCommandResult(session.get(), {"/usr/bin/umount", "/win-path"}, 0);
- VERIFY_SUCCEEDED(vm->UnmountWindowsFolder("/win-path"));
+ VERIFY_SUCCEEDED(session->UnmountWindowsFolder("/win-path"));
}
}
@@ -1000,9 +988,6 @@ class WSLATests
WI_ClearFlag(settings.FeatureFlags, WslaFeatureFlagsGPU);
session = CreateSession(settings);
- wil::com_ptr vm;
- VERIFY_SUCCEEDED(session->GetVirtualMachine(&vm));
-
// Validate that the GPU device is not available.
ExpectMount(session.get(), "/usr/lib/wsl/drivers", {});
ExpectMount(session.get(), "/usr/lib/wsl/lib", {});
@@ -1164,24 +1149,9 @@ class WSLATests
VERIFY_ARE_EQUAL(process.Get().GetStdHandle(3, reinterpret_cast(&dummyHandle)), E_INVALIDARG);
// Validate that the process object correctly handle requests after the VM has terminated.
- VERIFY_SUCCEEDED(session->Shutdown(30 * 1000));
+ session.reset();
VERIFY_ARE_EQUAL(process.Get().Signal(WSLASignalSIGKILL), HRESULT_FROM_WIN32(ERROR_INVALID_STATE));
}
-
- {
-
- // Validate that new processes cannot be created after the VM is terminated.
- const char* executable = "dummy";
- WSLA_PROCESS_OPTIONS options{};
- options.CommandLine = &executable;
- options.Executable = executable;
- options.CommandLineCount = 1;
-
- wil::com_ptr process;
- int error{};
- VERIFY_ARE_EQUAL(session->CreateRootNamespaceProcess(&options, &process, &error), HRESULT_FROM_WIN32(ERROR_INVALID_STATE));
- VERIFY_ARE_EQUAL(error, -1);
- }
}
TEST_METHOD(CrashDumpCollection)
@@ -1266,7 +1236,7 @@ class WSLATests
wsl::core::filesystem::CreateVhd(formatedVhd, 100 * 1024 * 1024, tokenInfo->User.Sid, false, false);
auto cleanup = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&]() {
- LOG_IF_FAILED(session->Shutdown(30 * 1000));
+ session.reset();
LOG_IF_WIN32_BOOL_FALSE(DeleteFileW(formatedVhd));
});
@@ -2371,4 +2341,99 @@ class WSLATests
VERIFY_ARE_EQUAL(wil::ResultFromException([&]() { runTest(input, "", ""); }), E_INVALIDARG);
}
}
+
+ TEST_METHOD(PersistentSession)
+ {
+ WSL2_TEST_ONLY();
+
+ wil::com_ptr userSession;
+ VERIFY_SUCCEEDED(CoCreateInstance(__uuidof(WSLAUserSession), nullptr, CLSCTX_LOCAL_SERVER, IID_PPV_ARGS(&userSession)));
+ wsl::windows::common::security::ConfigureForCOMImpersonation(userSession.get());
+
+ auto expectSessions = [&](const std::vector& expectedSessions) {
+ wil::unique_cotaskmem_array_ptr sessions;
+ VERIFY_SUCCEEDED(userSession->ListSessions(&sessions, sessions.size_address()));
+
+ std::set displayNames;
+ for (const auto& e : sessions)
+ {
+ auto [_, inserted] = displayNames.insert(e.DisplayName);
+
+ VERIFY_IS_TRUE(inserted);
+ }
+
+ for (const auto& e : expectedSessions)
+ {
+ auto it = displayNames.find(e);
+ if (it == displayNames.end())
+ {
+ LogError("Session not found: %ls", e.c_str());
+ VERIFY_FAIL();
+ }
+
+ displayNames.erase(it);
+ }
+
+ for (const auto& e : displayNames)
+ {
+ LogError("Unexpected session found: %ls", e.c_str());
+ VERIFY_FAIL();
+ }
+ };
+
+ auto create = [this](LPCWSTR Name, WSLASessionFlags Flags) {
+ auto settings = GetDefaultSessionSettings();
+ settings.DisplayName = Name;
+ settings.NetworkingMode = WSLANetworkingModeNone;
+
+ return CreateSession(settings, Flags);
+ };
+
+ // Validate that non-persistent sessions are dropped when released
+ {
+ auto session1 = create(L"session-1", WSLASessionFlagsNone);
+ expectSessions({L"session-1"});
+
+ session1.reset();
+ expectSessions({});
+ }
+
+ // Validate that persistent sessions are only dropped when explicitly terminated.
+ {
+ auto session1 = create(L"session-1", WSLASessionFlagsPersistent);
+ expectSessions({L"session-1"});
+
+ session1.reset();
+ expectSessions({L"session-1"});
+ session1 = create(L"session-1", WSLASessionFlagsOpenExisting);
+
+ VERIFY_SUCCEEDED(session1->Terminate());
+ session1.reset();
+ expectSessions({});
+ }
+
+ // Validate that sessions can be reopened by name.
+ {
+ auto session1 = create(L"session-1", WSLASessionFlagsPersistent);
+ expectSessions({L"session-1"});
+
+ session1.reset();
+ expectSessions({L"session-1"});
+
+ auto session1Copy =
+ create(L"session-1", static_cast(WSLASessionFlagsPersistent | WSLASessionFlagsOpenExisting));
+
+ expectSessions({L"session-1"});
+
+ // Verify that name conflicts are correctly handled.
+ auto settings = GetDefaultSessionSettings();
+ settings.DisplayName = L"session-1";
+
+ wil::com_ptr session;
+ VERIFY_ARE_EQUAL(userSession->CreateSession(&settings, WSLASessionFlagsPersistent, &session), HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS));
+
+ VERIFY_SUCCEEDED(session1Copy->Terminate());
+ expectSessions({});
+ }
+ }
};