Skip to content

Commit

Permalink
allow passing **kw to dispmethods
Browse files Browse the repository at this point in the history
  • Loading branch information
junkmd committed Feb 5, 2024
1 parent 9c464a7 commit 2a47b9e
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 45 deletions.
22 changes: 12 additions & 10 deletions comtypes/_memberspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,46 +454,48 @@ def fset(obj, value):

# Should the funcs/mths we create have restype and/or argtypes attributes?
def _make_disp_method(self, m: _DispMemberSpec) -> Callable[..., Any]:
memid = m.memid
if "propget" in m.idlflags:

def getfunc(obj, *args, **kw):
return obj.Invoke(
memid, _invkind=2, *args, **kw
m.memid, _invkind=2, _argspec=m.argspec, *args, **kw
) # DISPATCH_PROPERTYGET

return getfunc
elif "propput" in m.idlflags:

def putfunc(obj, *args, **kw):
return obj.Invoke(
memid, _invkind=4, *args, **kw
m.memid, _invkind=4, _argspec=m.argspec, *args, **kw
) # DISPATCH_PROPERTYPUT

return putfunc
elif "propputref" in m.idlflags:

def putreffunc(obj, *args, **kw):
return obj.Invoke(
memid, _invkind=8, *args, **kw
m.memid, _invkind=8, _argspec=m.argspec, *args, **kw
) # DISPATCH_PROPERTYPUTREF

return putreffunc
# a first attempt to make use of the restype. Still, support for
# named arguments and default argument values should be added.
# a first attempt to make use of the restype.
if hasattr(m.restype, "__com_interface__"):
interface = m.restype.__com_interface__ # type: ignore

def comitffunc(obj, *args, **kw):
result = obj.Invoke(memid, _invkind=1, *args, **kw)
result = obj.Invoke(
m.memid, _invkind=1, _argspec=m.argspec, *args, **kw
)
if result is None:
return
return result.QueryInterface(interface)

return comitffunc

def func(obj, *args, **kw):
return obj.Invoke(memid, _invkind=1, *args, **kw) # DISPATCH_METHOD
return obj.Invoke(
m.memid, _invkind=1, _argspec=m.argspec, *args, **kw
) # DISPATCH_METHOD

return func

Expand Down Expand Up @@ -526,10 +528,10 @@ def __getitem__(self, index):
else:
return self.fget(self.instance, index)

def __call__(self, *args):
def __call__(self, *args, **kw):
if self.fget is None:
raise TypeError("object is not callable")
return self.fget(self.instance, *args)
return self.fget(self.instance, *args, **kw)

def __setitem__(self, index, value):
if self.fset is None:
Expand Down
155 changes: 121 additions & 34 deletions comtypes/automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import array
import datetime
import decimal
import itertools
import sys
from ctypes import *
from ctypes import _Pointer
Expand All @@ -11,7 +12,9 @@
Any,
Callable,
ClassVar,
Container,
List,
Mapping,
Optional,
overload,
Sequence,
Expand All @@ -21,6 +24,7 @@
)

from comtypes import BSTR, COMError, COMMETHOD, GUID, IID, IUnknown, STDMETHOD
from comtypes._memberspec import _resolve_argspec
from comtypes.hresult import *
import comtypes.patcher
import comtypes
Expand Down Expand Up @@ -875,9 +879,8 @@ def Invoke(self, dispid: int, *args: Any, **kw: Any) -> Any:
# For comtypes this is handled in DISPPARAMS.__del__ and VARIANT.__del__.
_invkind = kw.pop("_invkind", 1) # DISPATCH_METHOD
_lcid = kw.pop("_lcid", 0)
if kw:
raise ValueError("named parameters not yet implemented")
dp = DispParamsGenerator(_invkind).generate(*args)
_argspec = kw.pop("_argspec", ())
dp = DispParamsGenerator(_invkind, _argspec).generate(*args, **kw)
result = VARIANT()
excepinfo = EXCEPINFO()
argerr = c_uint()
Expand Down Expand Up @@ -928,45 +931,38 @@ def Invoke(self, dispid: int, *args: Any, **kw: Any) -> Any:


class DispParamsGenerator(object):
__slots__ = ("invkind",)
__slots__ = ("invkind", "argspec")

def __init__(self, invkind: int) -> None:
def __init__(
self, invkind: int, argspec: Sequence["hints._ArgSpecElmType"]
) -> None:
self.invkind = invkind
self.argspec = argspec

def generate(self, *args: Any) -> DISPPARAMS:
def generate(self, *args: Any, **kw: Any) -> DISPPARAMS:
"""Generate `DISPPARAMS` for passing to `IDispatch::Invoke`.
Examples:
>>> _get_rgvarg = lambda dp: [dp.rgvarg[i] for i in range(dp.cArgs)]
>>> dp = DispParamsGenerator(DISPATCH_METHOD).generate(9)
>>> _get_rgvarg(dp), bool(dp.rgdispidNamedArgs), dp.cArgs, dp.cNamedArgs
([VARIANT(vt=0x3, 9)], False, 1, 0)
>>> dp = DispParamsGenerator(DISPATCH_PROPERTYGET).generate('foo', 3.14)
>>> _get_rgvarg(dp), bool(dp.rgdispidNamedArgs), dp.cArgs, dp.cNamedArgs
([VARIANT(vt=0x5, 3.14), VARIANT(vt=0x8, 'foo')], False, 2, 0)
>>> dp = DispParamsGenerator(DISPATCH_PROPERTYPUT).generate(8)
>>> _get_rgvarg(dp), dp.rgdispidNamedArgs.contents, dp.cArgs, dp.cNamedArgs
([VARIANT(vt=0x3, 8)], c_long(-3), 1, 1)
>>> dp = DispParamsGenerator(DISPATCH_PROPERTYPUTREF).generate(7, 'bar')
>>> _get_rgvarg(dp), dp.rgdispidNamedArgs.contents, dp.cArgs, dp.cNamedArgs
([VARIANT(vt=0x8, 'bar'), VARIANT(vt=0x3, 7)], c_long(-3), 2, 1)
>>> gen = DispParamsGenerator(DISPATCH_METHOD)
>>> _get_rgvarg(gen.generate())
[]
>>> _get_rgvarg(gen.generate(4))
[VARIANT(vt=0x3, 4)]
>>> _get_rgvarg(gen.generate(4, 3.14))
[VARIANT(vt=0x5, 3.14), VARIANT(vt=0x3, 4)]
>>> _get_rgvarg(gen.generate(4, 3.14, 'foo'))
[VARIANT(vt=0x8, 'foo'), VARIANT(vt=0x5, 3.14), VARIANT(vt=0x3, 4)]
Notes:
The following would be occured only when `**kw` is passed.
- Check the required arguments specified by the `argspec` are satisfied.
- Complement non-passed optional arguments with their default values
from the `argspec`.
"""
array = (VARIANT * len(args))()
for i, a in enumerate(args[::-1]):
if kw:
new_args = self._resolve_kwargs(*args, **kw)
else:
# Argument validation based on `argspec` is not triggered unless `**kw`
# is passed, because...
# - for backward compatibility with `1.2.0` and earlier.
# - there might be unexpected `argspec` in the real world.
# - `IDispatch.Invoke` might be called as a public method and `_argspec`
# is not passed.
new_args = args
array = (VARIANT * len(new_args))()
for i, a in enumerate(new_args[::-1]):
array[i].value = a
dp = DISPPARAMS()
dp.cArgs = len(args)
dp.cArgs = len(new_args)
if self.invkind in (DISPATCH_PROPERTYPUT, DISPATCH_PROPERTYPUTREF): # propput
dp.cNamedArgs = 1
dp.rgvarg = array
Expand All @@ -976,6 +972,97 @@ def generate(self, *args: Any) -> DISPPARAMS:
dp.rgvarg = array
return dp

def _resolve_kwargs(self, *args: Any, **kw: Any) -> Sequence[Any]:
pfs, _ = _resolve_argspec(self.argspec)
arg_names, arg_defaults = self._resolve_paramflags(pfs)
self._validate_unexpected(kw, arg_names, arg_defaults)
new_args, used_names = [], set()
for name in itertools.chain(arg_names, arg_defaults):
if not args and not kw:
break
if name in kw:
if args or name in used_names:
raise TypeError(f"got multiple values for argument {name!r}")
new_args.append(kw.pop(name))
used_names.add(name)
elif args:
new_args.append(args[0])
used_names.add(name)
args = args[1:]
elif name in arg_defaults:
new_args.append(arg_defaults[name])
used_names.add(name)
else:
continue
self._validate_missings(arg_names, used_names)
if args or kw:
# messages should be...
# - takes 0 positional arguments but 1 was given
# - takes 1 positional argument but N were given
# - takes L to M positional arguments but N were given
#
# `kw` resolution is only called when `**kw` is passed to `generate`.
# And `TypeError: got multiple values` is raised when there are
# multiple arguments.
# This conditional branch is for edge cases that may arise in the future.
raise TypeError # too many arguments
return new_args

def _resolve_paramflags(
self, pfs: Sequence["hints._ParamFlagType"]
) -> Tuple[Sequence[str], Mapping[str, Any]]:
arg_names, arg_defaults = [], {}
for p in pfs:
if len(p) == 2:
if arg_defaults:
raise ValueError("unexpected ordered params")
_, name = p
if name is None:
raise ValueError("unnamed argument")
arg_names.append(name)
else:
_, name, defval = p
if name is None:
raise ValueError("unnamed argument")
arg_defaults[name] = defval
return arg_names, arg_defaults

def _validate_unexpected(
self,
kw: Mapping[str, Any],
arg_names: Sequence[str],
arg_defaults: Mapping[str, Any],
) -> None:
for name in kw:
if name not in arg_names and name not in arg_defaults:
raise TypeError(f"got an unexpected keyword argument {name!r}")

def _validate_excessive(
self,
args: Sequence[Any],
kw: Mapping[str, Any],
arg_names: Sequence[str],
) -> None:
len_required_positionals = len(set(arg_names) - set(kw.keys()))
print(arg_names, kw.keys(), set(arg_names) - set(kw.keys()))
len_args = len(args)
if len_args > len_required_positionals:
raise TypeError

def _validate_missings(
self, arg_names: Sequence[str], used_names: Container[str]
) -> None:
mis = [n for n in arg_names if n not in used_names]
if not mis:
return
if len(mis) == 1:
head = "missing 1 required positional argument"
tail = repr(mis[0])
else:
head = f"missing {len(mis)} required positional arguments"
tail = ", ".join(map(repr, mis[:-1])) + f" and {mis[-1]!r}"
raise TypeError(f"{head}: {tail}")


################################################################
# safearrays
Expand Down
5 changes: 4 additions & 1 deletion comtypes/hints.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@
from comtypes.automation import IDispatch as IDispatch, VARIANT as VARIANT
from comtypes.server import IClassFactory as IClassFactory
from comtypes.typeinfo import ITypeInfo as ITypeInfo
from comtypes._memberspec import _ArgSpecElmType as _ArgSpecElmType
from comtypes._memberspec import (
_ArgSpecElmType as _ArgSpecElmType,
_ParamFlagType as _ParamFlagType,
)
Loading

0 comments on commit 2a47b9e

Please sign in to comment.