Skip to content

Commit

Permalink
Add type annotations to comtypes.automation (#374)
Browse files Browse the repository at this point in the history
* add type annotations

* improve `VARIANT.__repr__`
  • Loading branch information
junkmd authored Nov 13, 2022
1 parent 5f5429f commit db0931d
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
71 changes: 67 additions & 4 deletions comtypes/automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from ctypes import *
from ctypes import _Pointer
from _ctypes import CopyComPointer
from comtypes import IUnknown, GUID, IID, STDMETHOD, BSTR, COMMETHOD, COMError
from comtypes import (
BSTR, COMError, COMMETHOD, GUID, IID, IUnknown, STDMETHOD, TYPE_CHECKING,
)
from comtypes.hresult import *
import comtypes.patcher
import comtypes
Expand All @@ -19,6 +21,12 @@ class _safearray(object):

from ctypes.wintypes import DWORD, LONG, UINT, VARIANT_BOOL, WCHAR, WORD

if TYPE_CHECKING:
from typing import (
Any, Callable, ClassVar, List, Optional, Tuple, Union as _UnionT,
)
from comtypes import hints


if sys.version_info >= (3, 0):
int_types = (int, )
Expand Down Expand Up @@ -149,6 +157,13 @@ def as_decimal(self):
# The VARIANT structure is a good candidate for implementation in a C
# helper extension. At least the get/set methods.
class tagVARIANT(Structure):
if TYPE_CHECKING:
vt = hints.AnnoField() # type: int
_ = hints.AnnoField() # type: U_VARIANT1.__tagVARIANT.U_VARIANT2
null = hints.AnnoField() # type: ClassVar[VARIANT]
empty = hints.AnnoField() # type: ClassVar[VARIANT]
missing = hints.AnnoField() # type: ClassVar[VARIANT]

class U_VARIANT1(Union):
class __tagVARIANT(Structure):
# The C Header file defn of VARIANT is much more complicated, but
Expand Down Expand Up @@ -205,6 +220,12 @@ def __del__(self):
def __repr__(self):
if self.vt & VT_BYREF:
return "VARIANT(vt=0x%x, byref(%r))" % (self.vt, self[0])
elif self is type(self).null:
return "VARIANT.null"
elif self is type(self).empty:
return "VARIANT.empty"
elif self is type(self).missing:
return "VARIANT.missing"
return "VARIANT(vt=0x%x, %r)" % (self.vt, self.value)

@classmethod
Expand Down Expand Up @@ -650,6 +671,17 @@ def Next(self, celt):


class tagEXCEPINFO(Structure):
if TYPE_CHECKING:
wCode = hints.AnnoField() # type: int
wReserved = hints.AnnoField() # type: int
bstrSource = hints.AnnoField() # type: str
bstrDescription = hints.AnnoField() # type: str
bstrHelpFile = hints.AnnoField() # type: str
dwHelpContext = hints.AnnoField() # type: int
pvReserved = hints.AnnoField() # type: Optional[int]
pfnDeferredFillIn = hints.AnnoField() # type: Optional[int]
scode = hints.AnnoField() # type: int

def __repr__(self):
return "<EXCEPINFO %s>" % \
((self.wCode, self.bstrSource, self.bstrDescription, self.bstrHelpFile, self.dwHelpContext,
Expand All @@ -669,6 +701,11 @@ def __repr__(self):
EXCEPINFO = tagEXCEPINFO

class tagDISPPARAMS(Structure):
if TYPE_CHECKING:
rgvarg = hints.AnnoField() # type: Array[VARIANT]
rgdispidNamedArgs = hints.AnnoField() # type: _Pointer[DISPID]
cArgs = hints.AnnoField() # type: int
cNamedArgs = hints.AnnoField() # type: int
_fields_ = [
# C:/Programme/gccxml/bin/Vc71/PlatformSDK/oaidl.h 696
('rgvarg', POINTER(VARIANTARG)),
Expand All @@ -691,17 +728,39 @@ def __del__(self):
DISPID_DESTRUCTOR = -7
DISPID_COLLECT = -8


if TYPE_CHECKING:
RawGetIDsOfNamesFunc = Callable[
[_byref_type, Array[c_wchar_p], int, int, Array[DISPID]], int,
]
RawInvokeFunc = Callable[
[
int, _byref_type, int, int, # dispIdMember, riid, lcid, wFlags
_UnionT[_byref_type, DISPPARAMS], # *pDispParams
_UnionT[_byref_type, VARIANT], # pVarResult
_UnionT[_byref_type, EXCEPINFO, None], # pExcepInfo
_UnionT[_byref_type, c_uint], # puArgErr
],
int
]

class IDispatch(IUnknown):
if TYPE_CHECKING:
_disp_methods_ = hints.AnnoField() # type: ClassVar[List[comtypes._DispMemberSpec]]
_GetTypeInfo = hints.AnnoField() # type: Callable[[int, int], IUnknown]
__com_GetIDsOfNames = hints.AnnoField() # type: RawGetIDsOfNamesFunc
__com_Invoke = hints.AnnoField() # type: RawInvokeFunc

_iid_ = GUID("{00020400-0000-0000-C000-000000000046}")
_methods_ = [
COMMETHOD([], HRESULT, 'GetTypeInfoCount',
(['out'], POINTER(UINT) ) ),
COMMETHOD([], HRESULT, 'GetTypeInfo',
(['in'], UINT, 'index'),
(['in'], LCID, 'lcid', 0),
## Normally, we would declare this parameter in this way:
## (['out'], POINTER(POINTER(ITypeInfo)) ) ),
## but we cannot import comtypes.typeinfo at the top level (recursive imports!).
# Normally, we would declare this parameter in this way:
# (['out'], POINTER(POINTER(ITypeInfo)) ) ),
# but we cannot import comtypes.typeinfo at the top level (recursive imports!).
(['out'], POINTER(POINTER(IUnknown)) ) ),
STDMETHOD(HRESULT, 'GetIDsOfNames', [POINTER(IID), POINTER(c_wchar_p),
UINT, LCID, POINTER(DISPID)]),
Expand All @@ -711,12 +770,14 @@ class IDispatch(IUnknown):
]

def GetTypeInfo(self, index, lcid=0):
# type: (int, int) -> hints.ITypeInfo
"""Return type information. Index 0 specifies typeinfo for IDispatch"""
import comtypes.typeinfo
result = self._GetTypeInfo(index, lcid)
return result.QueryInterface(comtypes.typeinfo.ITypeInfo)

def GetIDsOfNames(self, *names, **kw):
# type: (str, Any) -> List[int]
"""Map string names to integer ids."""
lcid = kw.pop("lcid", 0)
assert not kw
Expand All @@ -726,6 +787,7 @@ def GetIDsOfNames(self, *names, **kw):
return ids[:]

def _invoke(self, memid, invkind, lcid, *args):
# type: (int, int, int, Any) -> Any
var = VARIANT()
argerr = c_uint()
dp = DISPPARAMS()
Expand All @@ -747,6 +809,7 @@ def _invoke(self, memid, invkind, lcid, *args):
return var._get_value(dynamic=True)

def Invoke(self, dispid, *args, **kw):
# type: (int, Any, Any) -> Any
"""Invoke a method or property."""

# Memory management in Dispatch::Invoke calls:
Expand Down
2 changes: 2 additions & 0 deletions comtypes/hints.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ from typing import (
)

# symbols those what might occur recursive imports in runtime.
from comtypes.automation import IDispatch as IDispatch, VARIANT as VARIANT
from comtypes.server import IClassFactory as IClassFactory
from comtypes.typeinfo import ITypeInfo as ITypeInfo


def AnnoField() -> Any:
Expand Down
7 changes: 7 additions & 0 deletions comtypes/test/test_variant.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ def test_byref(self):
variable.value = 96
self.assertEqual(v[0], 96)

def test_repr(self):
self.assertEqual(repr(VARIANT(c_int(42))), "VARIANT(vt=0x3, 42)")
self.assertEqual(repr(VARIANT(byref(c_int(42)))), "VARIANT(vt=0x4003, byref(42))")
self.assertEqual(repr(VARIANT.empty), "VARIANT.empty")
self.assertEqual(repr(VARIANT.null), "VARIANT.null")
self.assertEqual(repr(VARIANT.missing), "VARIANT.missing")


class ArrayTest(unittest.TestCase):
def test_double(self):
Expand Down

0 comments on commit db0931d

Please sign in to comment.