Skip to content

Commit

Permalink
Replace GetModuleFileNameA with GetModuleFileNameW to prevent a `…
Browse files Browse the repository at this point in the history
…TypeError`. (#733)

* Add `GetModuleFileNameW` to `typeinfo`.

* Improve referring `ctypes` in `server.register`.

* Fix the frozen dll path problem.

* Rename from `...W` to `GetModuleFileName` and small fixes.
  • Loading branch information
junkmd authored Jan 7, 2025
1 parent fc2792e commit 781e8e2
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 36 deletions.
16 changes: 10 additions & 6 deletions comtypes/server/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,24 @@
python mycomobj.py /nodebug
"""

import ctypes
import logging
import os
import sys
import winreg
from ctypes import WinError, c_ulong, c_wchar_p, create_string_buffer, sizeof, windll
from ctypes import WinError, windll
from typing import Iterator, Tuple

import comtypes
import comtypes.server.inprocserver
from comtypes.hresult import *
from comtypes.server import w_getopt
from comtypes.typeinfo import REGKIND_REGISTER, LoadTypeLibEx, UnRegisterTypeLib
from comtypes.typeinfo import (
REGKIND_REGISTER,
GetModuleFileName,
LoadTypeLibEx,
UnRegisterTypeLib,
)

_debug = logging.getLogger(__name__).debug

Expand All @@ -67,7 +73,7 @@ def _non_zero(retval, func, args):

SHDeleteKey = windll.shlwapi.SHDeleteKeyW
SHDeleteKey.errcheck = _non_zero
SHDeleteKey.argtypes = c_ulong, c_wchar_p
SHDeleteKey.argtypes = ctypes.c_ulong, ctypes.c_wchar_p

Set = set

Expand Down Expand Up @@ -219,9 +225,7 @@ def _get_serverdll():
"""Return the pathname of the dll hosting the COM object."""
handle = getattr(sys, "frozendllhandle", None)
if handle is not None:
buf = create_string_buffer(260)
windll.kernel32.GetModuleFileNameA(handle, buf, sizeof(buf))
return buf[:]
return GetModuleFileName(handle, 260)
import _ctypes

return _ctypes.__file__
Expand Down
23 changes: 10 additions & 13 deletions comtypes/test/test_server_register.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import _ctypes
import ctypes
import os
import sys
import unittest as ut
Expand Down Expand Up @@ -194,18 +193,16 @@ class Test_get_serverdll(ut.TestCase):
def test_nonfrozen(self):
self.assertEqual(_ctypes.__file__, _get_serverdll())

def test_frozen(self):
with mock.patch.object(register, "sys") as _sys:
with mock.patch.object(register, "windll") as _windll:
handle = 1234
_sys.frozendllhandle = handle
self.assertEqual(b"\x00" * 260, _get_serverdll())
GetModuleFileName = _windll.kernel32.GetModuleFileNameA
(((hModule, lpFilename, nSize), _),) = GetModuleFileName.call_args_list
self.assertEqual(handle, hModule)
buf_type = type(ctypes.create_string_buffer(260))
self.assertIsInstance(lpFilename, buf_type)
self.assertEqual(260, nSize)
@mock.patch.object(register, "GetModuleFileName")
@mock.patch.object(register, "sys")
def test_frozen(self, _sys, GetModuleFileName):
handle, dll_path = 1234, r"path\to\frozendll"
_sys.frozendllhandle = handle
GetModuleFileName.return_value = dll_path
self.assertEqual(r"path\to\frozendll", _get_serverdll())
(((hmodule, maxsize), _),) = GetModuleFileName.call_args_list
self.assertEqual(handle, hmodule)
self.assertEqual(260, maxsize)


class Test_NonFrozen_RegistryEntries(ut.TestCase):
Expand Down
26 changes: 19 additions & 7 deletions comtypes/test/test_typeinfo.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
import ctypes
import sys
import unittest
from ctypes import POINTER, byref

from comtypes import GUID, COMError
from comtypes.automation import DISPATCH_METHOD
from comtypes.typeinfo import (
LoadTypeLibEx,
TKIND_DISPATCH,
TKIND_INTERFACE,
GetModuleFileName,
LoadRegTypeLib,
LoadTypeLibEx,
QueryPathOfRegTypeLib,
TKIND_INTERFACE,
TKIND_DISPATCH,
TKIND_ENUM,
)


Expand Down Expand Up @@ -94,5 +94,17 @@ def test_TypeInfo(self):
self.assertEqual(guid, ti.GetTypeAttr().guid)


class Test_GetModuleFileName(unittest.TestCase):
def test_null_handler(self):
self.assertEqual(GetModuleFileName(None, 260), sys.executable)

def test_loaded_module_handle(self):
import _ctypes

dll_path = _ctypes.__file__
hmodule = ctypes.WinDLL(dll_path)._handle
self.assertEqual(GetModuleFileName(hmodule, 260), dll_path)


if __name__ == "__main__":
unittest.main()
65 changes: 55 additions & 10 deletions comtypes/typeinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,48 @@
# generated by 'xml2py'
# flags '..\tools\windows.xml -m comtypes -m comtypes.automation -w -r .*TypeLibEx -r .*TypeLib -o typeinfo.py'
# then hacked manually
import ctypes
import sys
from typing import Any, overload, TypeVar, TYPE_CHECKING
from typing import List, Type, Tuple
from typing import Optional, Union as _UnionT
from typing import Callable, Sequence
import weakref

import ctypes
from ctypes import HRESULT, POINTER, _Pointer, byref, c_int, c_void_p, c_wchar_p
from ctypes.wintypes import DWORD, LONG, UINT, ULONG, WCHAR, WORD, INT, SHORT, USHORT
from comtypes import BSTR, _CData, COMMETHOD, GUID, IID, IUnknown, STDMETHOD
from comtypes.automation import DISPID, LCID, SCODE
from comtypes.automation import DISPPARAMS, EXCEPINFO, VARIANT, VARIANTARG, VARTYPE
from ctypes.wintypes import (
DWORD,
HMODULE,
INT,
LONG,
LPWSTR,
SHORT,
UINT,
ULONG,
USHORT,
WCHAR,
WORD,
)
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
overload,
)
from typing import Union as _UnionT

from comtypes import BSTR, COMMETHOD, GUID, IID, STDMETHOD, IUnknown, _CData
from comtypes.automation import (
DISPID,
DISPPARAMS,
EXCEPINFO,
LCID,
SCODE,
VARIANT,
VARIANTARG,
VARTYPE,
)

if TYPE_CHECKING:
from comtypes import hints # type: ignore
Expand Down Expand Up @@ -666,6 +695,22 @@ def QueryPathOfRegTypeLib(
return pathname.value.split("\0")[0]


_GetModuleFileNameW = ctypes.windll.kernel32.GetModuleFileNameW
_GetModuleFileNameW.argtypes = HMODULE, LPWSTR, DWORD
_GetModuleFileNameW.restype = DWORD


def GetModuleFileName(handle: Optional[int], maxsize: int) -> str:
"""Returns the fullpath of the loaded module specified by the handle.
If the handle is NULL, returns the executable file path of the current process.
https://learn.microsoft.com/ja-jp/windows/win32/api/libloaderapi/nf-libloaderapi-loadlibraryw
"""
buf = ctypes.create_unicode_buffer(maxsize)
length = _GetModuleFileNameW(handle, buf, maxsize)
return buf.value[:length]


################################################################
# Structures

Expand Down

0 comments on commit 781e8e2

Please sign in to comment.