From 5cdfb447a51aee2f378b463b26221a67c27b2111 Mon Sep 17 00:00:00 2001 From: Mark Lechtermann Date: Fri, 1 Dec 2023 08:11:49 +0100 Subject: [PATCH] Register COM server types attached to the callers' threading model (#219) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Manuel Leßmann --- src/dscom/RegistrationServices.cs | 4 +- src/dscom/STAClassFactory.cs | 97 +++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 src/dscom/STAClassFactory.cs diff --git a/src/dscom/RegistrationServices.cs b/src/dscom/RegistrationServices.cs index da086ce..5f951d6 100644 --- a/src/dscom/RegistrationServices.cs +++ b/src/dscom/RegistrationServices.cs @@ -121,7 +121,7 @@ private static class RegistryValues /// The parameter cannot be created. public void RegisterTypeForComClients(Type type, ref Guid g) { - var genericClassFactory = typeof(ClassFactory<>); + var genericClassFactory = Thread.CurrentThread.GetApartmentState() == ApartmentState.STA ? typeof(STAClassFactory<>) : typeof(ClassFactory<>); Type[] typeArgs = { type }; var constructedClassFactory = genericClassFactory.MakeGenericType(typeArgs); @@ -147,7 +147,7 @@ public int RegisterTypeForComClients(Type type, ComTypes.RegistrationClassContex ?? type.Assembly.GetCustomAttributes().FirstOrDefault()?.Value) ?? throw new ArgumentException($"The given type {type} does not have a valid GUID attribute."); var guid = new Guid(value); - var genericClassFactory = typeof(ClassFactory<>); + var genericClassFactory = Thread.CurrentThread.GetApartmentState() == ApartmentState.STA ? typeof(STAClassFactory<>) : typeof(ClassFactory<>); Type[] typeArgs = { type }; var constructedClassFactory = genericClassFactory.MakeGenericType(typeArgs); diff --git a/src/dscom/STAClassFactory.cs b/src/dscom/STAClassFactory.cs new file mode 100644 index 0000000..d18cd89 --- /dev/null +++ b/src/dscom/STAClassFactory.cs @@ -0,0 +1,97 @@ +// The MIT License (MIT) +// Copyright (c) Microsoft Corporation + +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +// associated documentation files (the "Software"), to deal in the Software without restriction, +// including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all copies or substantial +// portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT +// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +using System.Runtime.InteropServices; + +namespace dSPACE.Runtime.InteropServices; + +[ComVisible(true)] +internal sealed class STAClassFactory : StandardOleMarshalObject, IClassFactory where T : new() +{ + public void CreateInstance( + [MarshalAs(UnmanagedType.Interface)] object instancePointer, + ref Guid riid, + out IntPtr outInterfacePointFromClass) + { + var interfaceType = GetInterfaceFromClassType(typeof(T), ref riid, instancePointer); + + object aggregatedObject = new T(); + if (instancePointer != null) + { + aggregatedObject = CreateAggregatedObject(instancePointer, aggregatedObject); + } + + outInterfacePointFromClass = GetObjectAsInterface(aggregatedObject, interfaceType); + } + + public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock) { } + + private static Type GetInterfaceFromClassType(Type classType, ref Guid riid, object outer) + { + if (riid == new Guid(Guids.IID_IUnknown)) + { + return typeof(object); + } + + if (outer != null) + { + throw new COMException(string.Empty, HRESULT.CLASS_E_NOAGGREGATION); + } + + foreach (var i in classType.GetInterfaces()) + { + if (i.GUID == riid) + { + return i; + } + } + + throw new InvalidCastException(); + } + + private static IntPtr GetObjectAsInterface(object obj, Type interfaceType) + { + if (interfaceType == typeof(object)) + { + return Marshal.GetIUnknownForObject(obj); + } + + var interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore); + if (interfaceMaybe == IntPtr.Zero) + { + throw new InvalidCastException(); + } + + return interfaceMaybe; + } + + private static object CreateAggregatedObject(object pUnkOuter, object comObject) + { + var outerPtr = Marshal.GetIUnknownForObject(pUnkOuter); + + try + { + var innerPtr = Marshal.CreateAggregatedObject(outerPtr, comObject); + return Marshal.GetObjectForIUnknown(innerPtr); + } + finally + { + Marshal.Release(outerPtr); + } + } +}