diff --git a/comtypes/server/register.py b/comtypes/server/register.py index e0ea0dc2..1d17cd43 100644 --- a/comtypes/server/register.py +++ b/comtypes/server/register.py @@ -36,18 +36,24 @@ python mycomobj.py /nodebug """ +import ctypes import logging import os import sys import winreg -from ctypes import WinError, c_ulong, c_wchar_p, create_string_buffer, sizeof, windll +from ctypes import WinError, windll from typing import Iterator, Tuple import comtypes import comtypes.server.inprocserver from comtypes.hresult import * from comtypes.server import w_getopt -from comtypes.typeinfo import REGKIND_REGISTER, LoadTypeLibEx, UnRegisterTypeLib +from comtypes.typeinfo import ( + REGKIND_REGISTER, + GetModuleFileName, + LoadTypeLibEx, + UnRegisterTypeLib, +) _debug = logging.getLogger(__name__).debug @@ -67,7 +73,7 @@ def _non_zero(retval, func, args): SHDeleteKey = windll.shlwapi.SHDeleteKeyW SHDeleteKey.errcheck = _non_zero -SHDeleteKey.argtypes = c_ulong, c_wchar_p +SHDeleteKey.argtypes = ctypes.c_ulong, ctypes.c_wchar_p Set = set @@ -219,9 +225,7 @@ def _get_serverdll(): """Return the pathname of the dll hosting the COM object.""" handle = getattr(sys, "frozendllhandle", None) if handle is not None: - buf = create_string_buffer(260) - windll.kernel32.GetModuleFileNameA(handle, buf, sizeof(buf)) - return buf[:] + return GetModuleFileName(handle, 260) import _ctypes return _ctypes.__file__ diff --git a/comtypes/test/test_server_register.py b/comtypes/test/test_server_register.py index 83fc78b2..e252dfae 100644 --- a/comtypes/test/test_server_register.py +++ b/comtypes/test/test_server_register.py @@ -1,5 +1,4 @@ import _ctypes -import ctypes import os import sys import unittest as ut @@ -194,18 +193,16 @@ class Test_get_serverdll(ut.TestCase): def test_nonfrozen(self): self.assertEqual(_ctypes.__file__, _get_serverdll()) - def test_frozen(self): - with mock.patch.object(register, "sys") as _sys: - with mock.patch.object(register, "windll") as _windll: - handle = 1234 - _sys.frozendllhandle = handle - self.assertEqual(b"\x00" * 260, _get_serverdll()) - GetModuleFileName = _windll.kernel32.GetModuleFileNameA - (((hModule, lpFilename, nSize), _),) = GetModuleFileName.call_args_list - self.assertEqual(handle, hModule) - buf_type = type(ctypes.create_string_buffer(260)) - self.assertIsInstance(lpFilename, buf_type) - self.assertEqual(260, nSize) + @mock.patch.object(register, "GetModuleFileName") + @mock.patch.object(register, "sys") + def test_frozen(self, _sys, GetModuleFileName): + handle, dll_path = 1234, r"path\to\frozendll" + _sys.frozendllhandle = handle + GetModuleFileName.return_value = dll_path + self.assertEqual(r"path\to\frozendll", _get_serverdll()) + (((hmodule, maxsize), _),) = GetModuleFileName.call_args_list + self.assertEqual(handle, hmodule) + self.assertEqual(260, maxsize) class Test_NonFrozen_RegistryEntries(ut.TestCase): diff --git a/comtypes/test/test_typeinfo.py b/comtypes/test/test_typeinfo.py index eaad4c73..1426e350 100644 --- a/comtypes/test/test_typeinfo.py +++ b/comtypes/test/test_typeinfo.py @@ -1,15 +1,15 @@ -import os +import ctypes +import sys import unittest -from ctypes import POINTER, byref + from comtypes import GUID, COMError -from comtypes.automation import DISPATCH_METHOD from comtypes.typeinfo import ( - LoadTypeLibEx, + TKIND_DISPATCH, + TKIND_INTERFACE, + GetModuleFileName, LoadRegTypeLib, + LoadTypeLibEx, QueryPathOfRegTypeLib, - TKIND_INTERFACE, - TKIND_DISPATCH, - TKIND_ENUM, ) @@ -94,5 +94,17 @@ def test_TypeInfo(self): self.assertEqual(guid, ti.GetTypeAttr().guid) +class Test_GetModuleFileName(unittest.TestCase): + def test_null_handler(self): + self.assertEqual(GetModuleFileName(None, 260), sys.executable) + + def test_loaded_module_handle(self): + import _ctypes + + dll_path = _ctypes.__file__ + hmodule = ctypes.WinDLL(dll_path)._handle + self.assertEqual(GetModuleFileName(hmodule, 260), dll_path) + + if __name__ == "__main__": unittest.main() diff --git a/comtypes/typeinfo.py b/comtypes/typeinfo.py index 077a492c..fb17000c 100644 --- a/comtypes/typeinfo.py +++ b/comtypes/typeinfo.py @@ -3,19 +3,48 @@ # generated by 'xml2py' # flags '..\tools\windows.xml -m comtypes -m comtypes.automation -w -r .*TypeLibEx -r .*TypeLib -o typeinfo.py' # then hacked manually +import ctypes import sys -from typing import Any, overload, TypeVar, TYPE_CHECKING -from typing import List, Type, Tuple -from typing import Optional, Union as _UnionT -from typing import Callable, Sequence import weakref - -import ctypes from ctypes import HRESULT, POINTER, _Pointer, byref, c_int, c_void_p, c_wchar_p -from ctypes.wintypes import DWORD, LONG, UINT, ULONG, WCHAR, WORD, INT, SHORT, USHORT -from comtypes import BSTR, _CData, COMMETHOD, GUID, IID, IUnknown, STDMETHOD -from comtypes.automation import DISPID, LCID, SCODE -from comtypes.automation import DISPPARAMS, EXCEPINFO, VARIANT, VARIANTARG, VARTYPE +from ctypes.wintypes import ( + DWORD, + HMODULE, + INT, + LONG, + LPWSTR, + SHORT, + UINT, + ULONG, + USHORT, + WCHAR, + WORD, +) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + overload, +) +from typing import Union as _UnionT + +from comtypes import BSTR, COMMETHOD, GUID, IID, STDMETHOD, IUnknown, _CData +from comtypes.automation import ( + DISPID, + DISPPARAMS, + EXCEPINFO, + LCID, + SCODE, + VARIANT, + VARIANTARG, + VARTYPE, +) if TYPE_CHECKING: from comtypes import hints # type: ignore @@ -666,6 +695,22 @@ def QueryPathOfRegTypeLib( return pathname.value.split("\0")[0] +_GetModuleFileNameW = ctypes.windll.kernel32.GetModuleFileNameW +_GetModuleFileNameW.argtypes = HMODULE, LPWSTR, DWORD +_GetModuleFileNameW.restype = DWORD + + +def GetModuleFileName(handle: Optional[int], maxsize: int) -> str: + """Returns the fullpath of the loaded module specified by the handle. + If the handle is NULL, returns the executable file path of the current process. + + https://learn.microsoft.com/ja-jp/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryw + """ + buf = ctypes.create_unicode_buffer(maxsize) + length = _GetModuleFileNameW(handle, buf, maxsize) + return buf.value[:length] + + ################################################################ # Structures