diff --git a/src/dbus_fast/_private/util.py b/src/dbus_fast/_private/util.py index 38c6a53..72e3b89 100644 --- a/src/dbus_fast/_private/util.py +++ b/src/dbus_fast/_private/util.py @@ -1,6 +1,6 @@ import ast import inspect -from typing import Any, List, Union +from typing import Any, List, Tuple, Union from ..signature import SignatureTree, Variant, get_signature_tree @@ -50,7 +50,7 @@ def signature_contains_type( def replace_fds_with_idx( signature: Union[str, SignatureTree], body: List[Any] -) -> (List[Any], List[int]): +) -> Tuple[List[Any], List[int]]: """Take the high level body format and convert it into the low level body format. Type 'h' refers directly to the fd in the body. Replace that with an index and return the corresponding list of unix fds that can be set on diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 090d0cb..c05dc0a 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -160,15 +160,16 @@ class MessageBus(BaseMessageBus): :vartype connected: bool """ + __slots__ = ("_loop", "_auth", "_writer", "_disconnect_future") + def __init__( self, bus_address: str = None, bus_type: BusType = BusType.SESSION, auth: Authenticator = None, - negotiate_unix_fd=False, - ): - super().__init__(bus_address, bus_type, ProxyObject) - self._negotiate_unix_fd = negotiate_unix_fd + negotiate_unix_fd: bool = False, + ) -> None: + super().__init__(bus_address, bus_type, ProxyObject, negotiate_unix_fd) self._loop = asyncio.get_running_loop() self._writer = _MessageWriter(self) diff --git a/src/dbus_fast/message_bus.pxd b/src/dbus_fast/message_bus.pxd index 47723f7..db7ef4d 100644 --- a/src/dbus_fast/message_bus.pxd +++ b/src/dbus_fast/message_bus.pxd @@ -6,6 +6,7 @@ from .message cimport Message cdef object MessageType cdef object DBusError cdef object MessageFlag +cdef object ServiceInterface cdef object MESSAGE_TYPE_CALL cdef object MESSAGE_TYPE_SIGNAL @@ -17,6 +18,7 @@ cdef class BaseMessageBus: cdef public object _user_disconnect cdef public object _method_return_handlers cdef public object _serial + cdef public object _path_exports cdef public cython.list _user_message_handlers cdef public object _name_owners cdef public object _bus_address @@ -25,5 +27,11 @@ cdef class BaseMessageBus: cdef public object _high_level_client_initialized cdef public object _ProxyObject cdef public object _machine_id + cdef public object _negotiate_unix_fd + cdef public object _sock + cdef public object _stream + cdef public object _fd cpdef _process_message(self, Message msg) + + cdef _find_message_handler(self, Message msg) diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index fdbc559..c6b2663 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -67,12 +67,17 @@ class BaseMessageBus: "_serial", "_user_message_handlers", "_name_owners", + "_path_exports", "_bus_address", "_name_owner_match_rule", "_match_rules", "_high_level_client_initialized", "_ProxyObject", "_machine_id", + "_negotiate_unix_fd", + "_sock", + "_stream", + "_fd", ) def __init__( @@ -80,9 +85,11 @@ class BaseMessageBus: bus_address: Optional[str] = None, bus_type: BusType = BusType.SESSION, ProxyObject: Optional[Type[BaseProxyObject]] = None, + negotiate_unix_fd: bool = False, ) -> None: self.unique_name: Optional[str] = None self._disconnected = False + self._negotiate_unix_fd = negotiate_unix_fd # True if the user disconnected himself, so don't throw errors out of # the main loop. @@ -870,14 +877,16 @@ class BaseMessageBus: args = ServiceInterface._msg_body_to_args(msg) result = method.fn(interface, *args) body, fds = ServiceInterface._fn_result_to_body( - result, signature_tree=method.out_signature_tree + result, + signature_tree=method.out_signature_tree, + replace_fds=self._negotiate_unix_fd, ) send_reply(Message.new_method_return(msg, method.out_signature, body, fds)) return handler def _find_message_handler( - self, msg: Message + self, msg ) -> Optional[Callable[[Message, Callable], None]]: handler: Optional[Callable[[Message, Callable], None]] = None diff --git a/src/dbus_fast/service.py b/src/dbus_fast/service.py index 9114699..7f5b474 100644 --- a/src/dbus_fast/service.py +++ b/src/dbus_fast/service.py @@ -8,7 +8,9 @@ from typing import ( Callable, Dict, List, + Optional, Set, + Tuple, no_type_check_decorator, ) @@ -22,7 +24,12 @@ from ._private.util import ( from .constants import PropertyAccess from .errors import SignalDisabledError from .message import Message -from .signature import SignatureBodyMismatchError, Variant, get_signature_tree +from .signature import ( + SignatureBodyMismatchError, + SignatureTree, + Variant, + get_signature_tree, +) if TYPE_CHECKING: from .message_bus import BaseMessageBus @@ -482,19 +489,23 @@ class ServiceInterface: del interface.__handlers[bus] @staticmethod - def _msg_body_to_args(msg): - if signature_contains_type(msg.signature_tree, msg.body, "h"): - # XXX: This deep copy could be expensive if messages are very - # large. We could optimize this by only copying what we change - # here. - return replace_idx_with_fds( - msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds - ) - else: + def _msg_body_to_args(msg: Message) -> List[Any]: + if not msg.unix_fds or not signature_contains_type( + msg.signature_tree, msg.body, "h" + ): return msg.body + # XXX: This deep copy could be expensive if messages are very + # large. We could optimize this by only copying what we change + # here. + return replace_idx_with_fds( + msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds + ) + @staticmethod - def _fn_result_to_body(result, signature_tree): + def _fn_result_to_body( + result: List[Any], signature_tree: SignatureTree, replace_fds: bool = True + ) -> Tuple[List[Any], List[int]]: """The high level interfaces may return single values which may be wrapped in a list to be a message body. Also they may return fds directly for type 'h' which need to be put into an external list.""" @@ -515,10 +526,14 @@ class ServiceInterface: f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}" ) + if not replace_fds: + return result, [] return replace_fds_with_idx(signature_tree, result) @staticmethod - def _handle_signal(interface, signal, result): + def _handle_signal( + interface: "ServiceInterface", signal: _Signal, result: List[Any] + ) -> None: body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree) for bus in ServiceInterface._get_buses(interface): bus._interface_signal_notify( @@ -526,7 +541,7 @@ class ServiceInterface: ) @staticmethod - def _get_property_value(interface, prop, callback): + def _get_property_value(interface: "ServiceInterface", prop: _Property, callback): # XXX MUST CHECK TYPE RETURNED BY GETTER try: if asyncio.iscoroutinefunction(prop.prop_getter): @@ -551,7 +566,7 @@ class ServiceInterface: callback(interface, prop, None, e) @staticmethod - def _set_property_value(interface, prop, value, callback): + def _set_property_value(interface: "ServiceInterface", prop, value, callback): # XXX MUST CHECK TYPE TO SET try: if asyncio.iscoroutinefunction(prop.prop_setter): @@ -575,7 +590,9 @@ class ServiceInterface: callback(interface, prop, e) @staticmethod - def _get_all_property_values(interface, callback, user_data=None): + def _get_all_property_values( + interface: "ServiceInterface", callback, user_data=None + ): result = {} result_error = None @@ -588,7 +605,12 @@ class ServiceInterface: callback(interface, result, user_data, None) return - def get_property_callback(interface, prop, value, e): + def get_property_callback( + interface: "ServiceInterface", + prop: _Property, + value: Any, + e: Optional[Exception], + ) -> None: nonlocal result_error if e is not None: result_error = e