Skip to content

Commit

Permalink
Register COM server types attached to the callers' threading model (#219
Browse files Browse the repository at this point in the history
)

Co-authored-by: Manuel Leßmann <[email protected]>
  • Loading branch information
marklechtermann and Manuel Leßmann committed Dec 1, 2023
1 parent 9ec3276 commit 5cdfb44
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/dscom/RegistrationServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private static class RegistryValues
/// <exception cref="T:System.ArgumentNullException">The <paramref name="type" /> parameter cannot be created.</exception>
public void RegisterTypeForComClients(Type type, ref Guid g)
{
var genericClassFactory = typeof(ClassFactory<>);
var genericClassFactory = Thread.CurrentThread.GetApartmentState() == ApartmentState.STA ? typeof(STAClassFactory<>) : typeof(ClassFactory<>);
Type[] typeArgs = { type };
var constructedClassFactory = genericClassFactory.MakeGenericType(typeArgs);

Expand All @@ -147,7 +147,7 @@ public int RegisterTypeForComClients(Type type, ComTypes.RegistrationClassContex
?? type.Assembly.GetCustomAttributes<GuidAttribute>().FirstOrDefault()?.Value) ?? throw new ArgumentException($"The given type {type} does not have a valid GUID attribute.");
var guid = new Guid(value);

var genericClassFactory = typeof(ClassFactory<>);
var genericClassFactory = Thread.CurrentThread.GetApartmentState() == ApartmentState.STA ? typeof(STAClassFactory<>) : typeof(ClassFactory<>);
Type[] typeArgs = { type };
var constructedClassFactory = genericClassFactory.MakeGenericType(typeArgs);

Expand Down
97 changes: 97 additions & 0 deletions src/dscom/STAClassFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// The MIT License (MIT)
// Copyright (c) Microsoft Corporation

// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
// associated documentation files (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all copies or substantial
// portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

using System.Runtime.InteropServices;

namespace dSPACE.Runtime.InteropServices;

[ComVisible(true)]
internal sealed class STAClassFactory<T> : StandardOleMarshalObject, IClassFactory where T : new()
{
public void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object instancePointer,
ref Guid riid,
out IntPtr outInterfacePointFromClass)
{
var interfaceType = GetInterfaceFromClassType(typeof(T), ref riid, instancePointer);

object aggregatedObject = new T();
if (instancePointer != null)
{
aggregatedObject = CreateAggregatedObject(instancePointer, aggregatedObject);
}

outInterfacePointFromClass = GetObjectAsInterface(aggregatedObject, interfaceType);
}

public void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock) { }

private static Type GetInterfaceFromClassType(Type classType, ref Guid riid, object outer)
{
if (riid == new Guid(Guids.IID_IUnknown))
{
return typeof(object);
}

if (outer != null)
{
throw new COMException(string.Empty, HRESULT.CLASS_E_NOAGGREGATION);
}

foreach (var i in classType.GetInterfaces())
{
if (i.GUID == riid)
{
return i;
}
}

throw new InvalidCastException();
}

private static IntPtr GetObjectAsInterface(object obj, Type interfaceType)
{
if (interfaceType == typeof(object))
{
return Marshal.GetIUnknownForObject(obj);
}

var interfaceMaybe = Marshal.GetComInterfaceForObject(obj, interfaceType, CustomQueryInterfaceMode.Ignore);
if (interfaceMaybe == IntPtr.Zero)
{
throw new InvalidCastException();
}

return interfaceMaybe;
}

private static object CreateAggregatedObject(object pUnkOuter, object comObject)
{
var outerPtr = Marshal.GetIUnknownForObject(pUnkOuter);

try
{
var innerPtr = Marshal.CreateAggregatedObject(outerPtr, comObject);
return Marshal.GetObjectForIUnknown(innerPtr);
}
finally
{
Marshal.Release(outerPtr);
}
}
}

0 comments on commit 5cdfb44

Please sign in to comment.