diff --git a/pydbus/proxy_method.py b/pydbus/proxy_method.py index 3e6e6ee..e61b148 100644 --- a/pydbus/proxy_method.py +++ b/pydbus/proxy_method.py @@ -2,6 +2,7 @@ from .generic import bound_method from .identifier import filter_identifier from .timeout import timeout_to_glib +from . import unixfd try: from inspect import Signature, Parameter @@ -69,10 +70,23 @@ def __call__(self, instance, *args, **kwargs): raise TypeError(self.__qualname__ + " got an unexpected keyword argument '{}'".format(kwarg)) timeout = kwargs.get("timeout", None) - ret = instance._bus.con.call_sync( - instance._bus_name, instance._path, - self._iface_name, self.__name__, GLib.Variant(self._sinargs, args), GLib.VariantType.new(self._soutargs), - 0, timeout_to_glib(timeout), None).unpack() + if unixfd.is_supported(instance._bus.con): + fd_list = unixfd.make_fd_list( + args, + [arg[1] for arg in self._inargs]) + ret, fd_list = instance._bus.con.call_with_unix_fd_list_sync( + instance._bus_name, instance._path, + self._iface_name, self.__name__, GLib.Variant(self._sinargs, args), GLib.VariantType.new(self._soutargs), + 0, timeout_to_glib(timeout), fd_list, None) + ret = unixfd.extract( + ret.unpack(), + self._outargs, + fd_list) + else: + ret = instance._bus.con.call_sync( + instance._bus_name, instance._path, + self._iface_name, self.__name__, GLib.Variant(self._sinargs, args), GLib.VariantType.new(self._soutargs), + 0, timeout_to_glib(timeout), None) if len(self._outargs) == 0: return None diff --git a/pydbus/registration.py b/pydbus/registration.py index f531539..ee894ae 100644 --- a/pydbus/registration.py +++ b/pydbus/registration.py @@ -5,6 +5,7 @@ from .exitable import ExitableWithAliases from functools import partial from .method_call_context import MethodCallContext +from . import unixfd import logging try: @@ -18,10 +19,12 @@ class ObjectWrapper(ExitableWithAliases("unwrap")): def __init__(self, object, interfaces): self.object = object + self.inargs = {} self.outargs = {} for iface in interfaces: for method in iface.methods: self.outargs[iface.name + "." + method.name] = [arg.signature for arg in method.out_args] + self.inargs[iface.name + "." + method.name] = [arg.signature for arg in method.in_args] self.readable_properties = {} self.writable_properties = {} @@ -54,6 +57,7 @@ def onPropertiesChanged(iface, changed, invalidated): def call_method(self, connection, sender, object_path, interface_name, method_name, parameters, invocation): try: try: + inargs = self.inargs[interface_name + "." + method_name] outargs = self.outargs[interface_name + "." + method_name] method = getattr(self.object, method_name) except KeyError: @@ -61,12 +65,15 @@ def call_method(self, connection, sender, object_path, interface_name, method_na if method_name == "Get": method = self.Get outargs = ["v"] + inargs = ["ss"] elif method_name == "GetAll": method = self.GetAll outargs = ["a{sv}"] + inargs = ["s"] elif method_name == "Set": method = self.Set outargs = [] + inargs = ["ssv"] else: raise else: @@ -78,14 +85,23 @@ def call_method(self, connection, sender, object_path, interface_name, method_na if "dbus_context" in sig.parameters and sig.parameters["dbus_context"].kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY): kwargs["dbus_context"] = MethodCallContext(invocation) + if unixfd.is_supported(connection): + parameters = unixfd.extract( + parameters, + inargs, + invocation.get_message().get_unix_fd_list()) + result = method(*parameters, **kwargs) if len(outargs) == 0: invocation.return_value(None) - elif len(outargs) == 1: - invocation.return_value(GLib.Variant("(" + "".join(outargs) + ")", (result,))) else: - invocation.return_value(GLib.Variant("(" + "".join(outargs) + ")", result)) + if len(outargs) == 1: + result = (result, ) + if unixfd.is_supported(connection): + invocation.return_value_with_unix_fd_list(GLib.Variant("(" + "".join(outargs) + ")", result), unixfd.make_fd_list(result, outargs, steal=True)) + else: + invocation.return_value(GLib.Variant("(" + "".join(outargs) + ")", result)) except Exception as e: logger = logging.getLogger(__name__) @@ -151,6 +167,5 @@ def register_object(self, path, object, node_info): node_info = [Gio.DBusNodeInfo.new_for_xml(ni) for ni in node_info] interfaces = sum((ni.interfaces for ni in node_info), []) - wrapper = ObjectWrapper(object, interfaces) return ObjectRegistration(self, path, interfaces, wrapper, own_wrapper=True) diff --git a/pydbus/unixfd.py b/pydbus/unixfd.py new file mode 100644 index 0000000..61e26a6 --- /dev/null +++ b/pydbus/unixfd.py @@ -0,0 +1,53 @@ +from gi.repository import Gio + +# signature type code +TYPE_FD = "h" + +def is_supported(conn): + """ + Check if the message bus supports passing of Unix file descriptors. + """ + return conn.get_capabilities() & Gio.DBusCapabilityFlags.UNIX_FD_PASSING + + +def extract(params, signature, fd_list): + """ + Extract any file descriptors from a UnixFDList (e.g. after + receiving from D-Bus) to a parameter list. + Receiver must call os.dup on any fd it decides to keep/use. + """ + if not fd_list: + return params + return [fd_list.get(0) + if arg == TYPE_FD + else val + for val, arg + in zip(params, signature)] + + +def make_fd_list(params, signature, steal=False): + """ + Embed any unix file descriptors in a parameter list into a + UnixFDList (for D-Bus-dispatch). + If steal is true, the responsibility for closing the file + descriptors are transferred to the UnixFDList object. + If steal is false, the file descriptors will be duplicated + and the caller must close the original file descriptors. + """ + if not any(arg + for arg in signature + if arg == TYPE_FD): + return None + + fds = [param + for param, arg + in zip(params, signature) + if arg == TYPE_FD] + + if steal: + return Gio.UnixFDList.new_from_array(fds) + + fd_list = Gio.UnixFDList() + for fd in fds: + fd_list.append(fd) + return fd_list diff --git a/tests/run.sh b/tests/run.sh index 8d93644..8e0b44b 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -15,4 +15,5 @@ then "$PYTHON" $TESTS_DIR/publish.py "$PYTHON" $TESTS_DIR/publish_properties.py "$PYTHON" $TESTS_DIR/publish_multiface.py + "$PYTHON" $TESTS_DIR/unixfd.py fi diff --git a/tests/unixfd.py b/tests/unixfd.py new file mode 100644 index 0000000..72bfdf3 --- /dev/null +++ b/tests/unixfd.py @@ -0,0 +1,60 @@ +from pydbus import SessionBus +from gi.repository import GLib +from threading import Thread +import sys +import os + +loop = GLib.MainLoop() + + +with open(__file__) as f: + contents = f.read() + + +class TestObject(object): + """ + + + + + + + + + """ + def Hello(self, in_fd): + with os.fdopen(in_fd) as in_file: + in_file.seek(0) + assert(contents == in_file.read()) + print("Received fd as in parameter ok") + with open(__file__) as out_file: + assert(contents == out_file.read()) + return os.dup(out_file.fileno()) + +bus = SessionBus() + + +with bus.publish("baz.bar.Foo", TestObject()): + remote = bus.get("baz.bar.Foo") + + def thread_func(): + with open(__file__) as in_file: + assert(contents == in_file.read()) + out_fd = remote.Hello(in_file.fileno()) + with os.fdopen(out_fd) as out_file: + out_file.seek(0) + assert(contents == out_file.read()) + print("Received fd as out argument ok") + loop.quit() + + thread = Thread(target=thread_func) + thread.daemon = True + + def handle_timeout(): + exit("ERROR: Timeout.") + + GLib.timeout_add_seconds(2, handle_timeout) + + thread.start() + loop.run() + thread.join()