Skip to content

Commit

Permalink
Add more type annotations to errorinfo. (#820)
Browse files Browse the repository at this point in the history
* Add type hints to `errorinfo.py` for supressing type errors in `test_errorinfo`.

* Add type hints to `errorinfo.ReportError` and `errorinfo.ReportException`.

* Add type hints to `errorinfo.ICreateErrorInfo`.

* Add type hints to `errorinfo.CreateErrorInfo`.

* Add type hints to `errorinfo.SetErrorInfo`.
  • Loading branch information
junkmd authored Feb 24, 2025
1 parent 8c49d8f commit e803526
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
52 changes: 39 additions & 13 deletions comtypes/errorinfo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
from ctypes import POINTER, OleDLL, byref, c_wchar_p
from ctypes.wintypes import DWORD, ULONG
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from typing import Union as _UnionT

from comtypes import BSTR, COMMETHOD, GUID, HRESULT, IUnknown
from comtypes.hresult import DISP_E_EXCEPTION, S_OK
Expand All @@ -21,6 +22,13 @@ class ICreateErrorInfo(IUnknown):
COMMETHOD([], HRESULT, "SetHelpFile", (["in"], LPCOLESTR, "szHelpFile")),
COMMETHOD([], HRESULT, "SetHelpContext", (["in"], DWORD, "dwHelpContext")),
]
if TYPE_CHECKING:

def SetGUID(self, rguid: GUID) -> hints.Hresult: ...
def SetSource(self, szSource: str) -> hints.Hresult: ...
def SetDescription(self, szDescription: str) -> hints.Hresult: ...
def SetHelpFile(self, szHelpFile: str) -> hints.Hresult: ...
def SetHelpContext(self, dwHelpContext: int) -> hints.Hresult: ...


class IErrorInfo(IUnknown):
Expand All @@ -38,6 +46,13 @@ class IErrorInfo(IUnknown):
[], HRESULT, "GetHelpContext", (["out"], POINTER(DWORD), "pdwHelpContext")
),
]
if TYPE_CHECKING:

def GetGUID(self) -> GUID: ...
def GetSource(self) -> str: ...
def GetDescription(self) -> str: ...
def GetHelpFile(self) -> str: ...
def GetHelpContext(self) -> int: ...


class ISupportErrorInfo(IUnknown):
Expand Down Expand Up @@ -68,28 +83,35 @@ def InterfaceSupportsErrorInfo(self, riid: GUID) -> hints.Hresult: ...
_SetErrorInfo.restype = HRESULT


def CreateErrorInfo():
def CreateErrorInfo() -> ICreateErrorInfo:
cei = POINTER(ICreateErrorInfo)()
_CreateErrorInfo(byref(cei))
return cei
return cei # type: ignore


def GetErrorInfo():
def GetErrorInfo() -> Optional[IErrorInfo]:
"""Get the error information for the current thread."""
errinfo = POINTER(IErrorInfo)()
if S_OK == _GetErrorInfo(0, byref(errinfo)):
return errinfo
return errinfo # type: ignore
return None


def SetErrorInfo(errinfo):
def SetErrorInfo(errinfo: _UnionT[IErrorInfo, ICreateErrorInfo]) -> "hints.Hresult":
"""Set error information for the current thread."""
# ICreateErrorInfo can QueryInterface with IErrorInfo, so both types are
# accepted, thanks to the magic of from_param.
return _SetErrorInfo(0, errinfo)


def ReportError(
text, iid, clsid=None, helpfile=None, helpcontext=0, hresult=DISP_E_EXCEPTION
):
text: str,
iid: GUID,
clsid: _UnionT[None, str, GUID] = None,
helpfile: Optional[str] = None,
helpcontext: Optional[int] = 0,
hresult: int = DISP_E_EXCEPTION,
) -> int:
"""Report a COM error. Returns the passed in hresult value."""
ei = CreateErrorInfo()
ei.SetDescription(text)
Expand All @@ -106,16 +128,20 @@ def ReportError(
except WindowsError:
pass
else:
ei.SetSource(
progid
) # progid for the class or application that created the error
# progid for the class or application that created the error
ei.SetSource(progid)
SetErrorInfo(ei)
return hresult


def ReportException(
hresult, iid, clsid=None, helpfile=None, helpcontext=None, stacklevel=None
):
hresult: int,
iid: GUID,
clsid: _UnionT[None, str, GUID] = None,
helpfile: Optional[str] = None,
helpcontext: Optional[int] = None,
stacklevel: Optional[int] = None,
) -> int:
"""Report a COM exception. Returns the passed in hresult value."""
typ, value, tb = sys.exc_info()
if stacklevel is not None:
Expand Down
10 changes: 7 additions & 3 deletions comtypes/test/test_errorinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_error_has_been_set(self):
self.assertEqual(errcode, hr)
pei = errorinfo.GetErrorInfo()
self.assertIsNotNone(pei)
assert pei is not None # for static type guard
self.assertEqual(shelllink.IShellLinkW._iid_, pei.GetGUID())
self.assertEqual("lnkfile", pei.GetSource())
self.assertEqual(errmsg, pei.GetDescription())
Expand All @@ -37,6 +38,7 @@ def test_without_optional_args(self):
self.assertEqual(hres.DISP_E_EXCEPTION, hr)
pei = errorinfo.GetErrorInfo()
self.assertIsNotNone(pei)
assert pei is not None # for static type guard
self.assertEqual(shelllink.IShellLinkW._iid_, pei.GetGUID())
self.assertIsNone(pei.GetSource())
self.assertEqual(errmsg, pei.GetDescription())
Expand Down Expand Up @@ -68,6 +70,7 @@ def test_without_stacklevel(self):
self.assertEqual(hres.E_UNEXPECTED, hr)
pei = errorinfo.GetErrorInfo()
self.assertIsNotNone(pei)
assert pei is not None # for static type guard
self.assertEqual(shelllink.IShellLinkW._iid_, pei.GetGUID())
self.assertIsNone(pei.GetSource())
self.assertEqual("<class 'RuntimeError'>: for testing", pei.GetDescription())
Expand All @@ -82,15 +85,16 @@ def test_with_stacklevel(self):
for slv, text in [
# XXX: If the codebase changes, the line where functions or
# methods are defined will change, meaning this test is brittle.
(0, f"{stem} ({__name__}, line 93)"),
(1, f"{stem} ({__name__}, line 53)"),
(2, f"{stem} ({__name__}, line 57)"),
(0, f"{stem} ({__name__}, line 96)"),
(1, f"{stem} ({__name__}, line 55)"),
(2, f"{stem} ({__name__}, line 59)"),
]:
with self.subTest(text=text):
try:
raise_runtime_error()
except RuntimeError:
errorinfo.ReportException(hres.E_UNEXPECTED, iid, stacklevel=slv)
pei = errorinfo.GetErrorInfo()
assert pei is not None # for static type guard
self.assertEqual(text, pei.GetDescription())
self.assertIsNone(errorinfo.GetErrorInfo())

0 comments on commit e803526

Please sign in to comment.