From 996659e1b5fefeda7eb01259714a4a17fc224b9f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 5 Mar 2025 13:28:00 -1000 Subject: [PATCH] feat: refactor service bus handler lookup to avoid linear searches (#400) --- src/dbus_fast/constants.py | 2 +- src/dbus_fast/introspection.py | 2 +- src/dbus_fast/message_bus.pxd | 11 +- src/dbus_fast/message_bus.py | 343 +++++++++++----------- src/dbus_fast/send_reply.py | 19 +- src/dbus_fast/service.pxd | 7 +- src/dbus_fast/service.py | 155 ++++++---- tests/service/test_export.py | 19 +- tests/service/test_standard_interfaces.py | 19 +- 9 files changed, 315 insertions(+), 262 deletions(-) diff --git a/src/dbus_fast/constants.py b/src/dbus_fast/constants.py index 4078403..ff8691b 100644 --- a/src/dbus_fast/constants.py +++ b/src/dbus_fast/constants.py @@ -37,7 +37,7 @@ class MessageFlag(IntFlag): ALLOW_INTERACTIVE_AUTHORIZATION = 4 @cached_property - def value(self) -> str: + def value(self) -> int: """Return the value.""" return self._value_ diff --git a/src/dbus_fast/introspection.py b/src/dbus_fast/introspection.py index 36dbb50..42dc329 100644 --- a/src/dbus_fast/introspection.py +++ b/src/dbus_fast/introspection.py @@ -51,7 +51,7 @@ class Arg: def __init__( self, signature: Union[SignatureType, str], - direction: Optional[list[ArgDirection]] = None, + direction: Optional[ArgDirection] = None, name: Optional[str] = None, annotations: Optional[dict[str, str]] = None, ): diff --git a/src/dbus_fast/message_bus.pxd b/src/dbus_fast/message_bus.pxd index 9c88836..f7a0a4a 100644 --- a/src/dbus_fast/message_bus.pxd +++ b/src/dbus_fast/message_bus.pxd @@ -4,6 +4,7 @@ from ._private.address cimport get_bus_address, parse_address from .message cimport Message from .service cimport ServiceInterface, _Method +cdef bint TYPE_CHECKING cdef object MessageType cdef object DBusError @@ -39,24 +40,26 @@ cdef class BaseMessageBus: cdef public object _high_level_client_initialized cdef public object _ProxyObject cdef public object _machine_id - cdef public object _negotiate_unix_fd + cdef public bint _negotiate_unix_fd cdef public object _sock cdef public object _stream cdef public object _fd - cpdef _process_message(self, Message msg) + cpdef void _process_message(self, Message msg) + + @cython.locals(exported_service_interface=ServiceInterface) + cpdef export(self, str path, ServiceInterface interface) @cython.locals( methods=cython.list, method=_Method, interface=ServiceInterface, - interfaces=cython.list, + interfaces=dict, ) cdef _find_message_handler(self, Message msg) cdef _setup_socket(self) - @cython.locals(no_reply_expected=bint) cpdef _call(self, Message msg, object callback) cpdef next_serial(self) diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index d0aa795..6626047 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -1,10 +1,11 @@ +from __future__ import annotations import inspect import logging import socket import traceback import xml.etree.ElementTree as ET from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, TYPE_CHECKING from . import introspection as intr from ._private.address import get_bus_address, parse_address @@ -22,7 +23,7 @@ from .errors import DBusError, InvalidAddressError from .message import Message from .proxy_object import BaseProxyObject from .send_reply import SendReply -from .service import ServiceInterface, _Method +from .service import ServiceInterface, _Method, _Property, HandlerType from .signature import Variant from .validators import assert_bus_name_valid, assert_object_path_valid @@ -119,12 +120,12 @@ class BaseMessageBus: def __init__( self, - bus_address: Optional[str] = None, + bus_address: str | None = None, bus_type: BusType = BusType.SESSION, - ProxyObject: Optional[type[BaseProxyObject]] = None, + ProxyObject: type[BaseProxyObject] | None = None, negotiate_unix_fd: bool = False, ) -> None: - self.unique_name: Optional[str] = None + self.unique_name: str | None = None self._disconnected = False self._negotiate_unix_fd = negotiate_unix_fd @@ -133,11 +134,11 @@ class BaseMessageBus: self._user_disconnect = False self._method_return_handlers: dict[ - int, Callable[[Optional[Message], Optional[Exception]], None] + int, Callable[[Message | None, Exception | None], None] ] = {} self._serial = 0 self._user_message_handlers: list[ - Callable[[Message], Union[Message, bool, None]] + Callable[[Message], Message | bool | None] ] = [] # the key is the name and the value is the unique name of the owner. # This cache is kept up to date by the NameOwnerChanged signal and is @@ -145,7 +146,7 @@ class BaseMessageBus: # high level client only) self._name_owners: dict[str, str] = {} # used for the high level service - self._path_exports: dict[str, list[ServiceInterface]] = {} + self._path_exports: dict[str, dict[str, ServiceInterface]] = {} self._bus_address = ( parse_address(bus_address) if bus_address @@ -161,10 +162,10 @@ class BaseMessageBus: self._ProxyObject = ProxyObject # machine id is lazy loaded - self._machine_id: Optional[int] = None - self._sock: Optional[socket.socket] = None - self._fd: Optional[int] = None - self._stream: Optional[Any] = None + self._machine_id: int | None = None + self._sock: socket.socket | None = None + self._fd: int | None = None + self._stream: Any | None = None self._setup_socket() @@ -193,20 +194,18 @@ class BaseMessageBus: raise TypeError("interface must be a ServiceInterface") if path not in self._path_exports: - self._path_exports[path] = [] + self._path_exports[path] = {} + elif interface.name in self._path_exports[path]: + raise ValueError( + f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"' + ) - for f in self._path_exports[path]: - if f.name == interface.name: - raise ValueError( - f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"' - ) - - self._path_exports[path].append(interface) + self._path_exports[path][interface.name] = interface ServiceInterface._add_bus(interface, self, self._make_method_handler) self._emit_interface_added(path, interface) def unexport( - self, path: str, interface: Optional[Union[ServiceInterface, str]] = None + self, path: str, interface: ServiceInterface | str | None = None ) -> None: """Unexport the path or service interface to make it no longer available to clients. @@ -222,45 +221,42 @@ class BaseMessageBus: - :class:`InvalidObjectPathError ` - If the given object path is not valid. """ assert_object_path_valid(path) - if type(interface) not in [str, type(None)] and not isinstance( - interface, ServiceInterface - ): - raise TypeError("interface must be a ServiceInterface or interface name") - - if path not in self._path_exports: - return - - exports = self._path_exports[path] - - if type(interface) is str: - try: - interface = next(iface for iface in exports if iface.name == interface) - except StopIteration: - return - - removed_interfaces = [] + interface_name: str | None if interface is None: - del self._path_exports[path] - for iface in filter(lambda e: not self._has_interface(e), exports): - removed_interfaces.append(iface.name) - ServiceInterface._remove_bus(iface, self) + interface_name = None + elif type(interface) is str: + interface_name = interface + elif isinstance(interface, ServiceInterface): + interface_name = interface.name else: - for i, iface in enumerate(exports): - if iface is interface: - removed_interfaces.append(iface.name) - del self._path_exports[path][i] - if not self._path_exports[path]: - del self._path_exports[path] - if not self._has_interface(iface): - ServiceInterface._remove_bus(iface, self) - break - self._emit_interface_removed(path, removed_interfaces) + raise TypeError( + f"interface must be a ServiceInterface or interface name not {type(interface)}" + ) + + if (interfaces := self._path_exports.get(path)) is None: + return + removed_interface_names: list[str] = [] + + if interface_name is not None: + if (removed_interface := interfaces.pop(interface_name, None)) is None: + return + removed_interface_names.append(interface_name) + if not interfaces: + del self._path_exports[path] + ServiceInterface._remove_bus(removed_interface, self) + else: + del self._path_exports[path] + for removed_interface in interfaces.values(): + removed_interface_names.append(removed_interface.name) + ServiceInterface._remove_bus(removed_interface, self) + + self._emit_interface_removed(path, removed_interface_names) def introspect( self, bus_name: str, path: str, - callback: Callable[[Optional[intr.Node], Optional[Exception]], None], + callback: Callable[[intr.Node | None, Exception | None], None], check_callback_type: bool = True, validate_property_names: bool = True, ) -> None: @@ -289,12 +285,13 @@ class BaseMessageBus: if check_callback_type: BaseMessageBus._check_callback_type(callback) - def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: + def reply_notify(reply: Message | None, err: Exception | None) -> None: try: BaseMessageBus._check_method_return(reply, err, "s") result = intr.Node.parse( - reply.body[0], validate_property_names=validate_property_names - ) # type: ignore[union-attr] + reply.body[0], # type: ignore[union-attr] + validate_property_names=validate_property_names, + ) except Exception as e: callback(None, e) return @@ -330,7 +327,7 @@ class BaseMessageBus: interface: ServiceInterface, result: Any, user_data: Any, - e: Optional[Exception], + e: Exception | None, ) -> None: if e is not None: try: @@ -385,9 +382,8 @@ class BaseMessageBus: self, name: str, flags: NameFlag = NameFlag.NONE, - callback: Optional[ - Callable[[Optional[RequestNameReply], Optional[Exception]], None] - ] = None, + callback: None + | (Callable[[RequestNameReply | None, Exception | None], None]) = None, check_callback_type: bool = True, ) -> None: """Request that this message bus owns the given name. @@ -424,24 +420,23 @@ class BaseMessageBus: self._call(message, None) return - def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: + def reply_notify(reply: Message | None, err: Exception | None) -> None: try: BaseMessageBus._check_method_return(reply, err, "u") result = RequestNameReply(reply.body[0]) # type: ignore[union-attr] except Exception as e: - callback(None, e) # type: ignore[misc] + callback(None, e) return - callback(result, None) # type: ignore[misc] + callback(result, None) self._call(message, reply_notify) def release_name( self, name: str, - callback: Optional[ - Callable[[Optional[ReleaseNameReply], Optional[Exception]], None] - ] = None, + callback: None + | (Callable[[ReleaseNameReply | None, Exception | None], None]) = None, check_callback_type: bool = True, ) -> None: """Request that this message bus release the given name. @@ -474,20 +469,20 @@ class BaseMessageBus: self._call(message, None) return - def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: + def reply_notify(reply: Message | None, err: Exception | None) -> None: try: BaseMessageBus._check_method_return(reply, err, "u") result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr] except Exception as e: - callback(None, e) # type: ignore[misc] + callback(None, e) return - callback(result, None) # type: ignore[misc] + callback(result, None) self._call(message, reply_notify) def get_proxy_object( - self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element] + self, bus_name: str, path: str, introspection: intr.Node | str | ET.Element ) -> BaseProxyObject: """Get a proxy object for the path exported on the bus that owns the name. The object is expected to export the interfaces and nodes @@ -526,10 +521,11 @@ class BaseMessageBus: All pending and future calls will error with a connection error. """ self._user_disconnect = True - try: - self._sock.shutdown(socket.SHUT_RDWR) - except Exception: - logging.warning("could not shut down socket", exc_info=True) + if self._sock: + try: + self._sock.shutdown(socket.SHUT_RDWR) + except Exception: + logging.warning("could not shut down socket", exc_info=True) def next_serial(self) -> int: """Get the next serial for this bus. This can be used as the ``serial`` @@ -543,7 +539,7 @@ class BaseMessageBus: return self._serial def add_message_handler( - self, handler: Callable[[Message], Optional[Union[Message, bool]]] + self, handler: Callable[[Message], Message | bool | None] ) -> None: """Add a custom message handler for incoming messages. @@ -568,7 +564,7 @@ class BaseMessageBus: self._user_message_handlers.append(handler) def remove_message_handler( - self, handler: Callable[[Message], Optional[Union[Message, bool]]] + self, handler: Callable[[Message], Message | bool | None] ) -> None: """Remove a message handler that was previously added by :func:`add_message_handler() @@ -592,7 +588,7 @@ class BaseMessageBus: 'the "send" method must be implemented in the inheriting class' ) - def _finalize(self, err: Optional[Exception]) -> None: + def _finalize(self, err: Exception | None) -> None: """should be called after the socket disconnects with the disconnection error to clean up resources and put the bus in a disconnected state""" if self._disconnected: @@ -615,14 +611,6 @@ class BaseMessageBus: self._user_message_handlers.clear() - def _has_interface(self, interface: ServiceInterface) -> bool: - for _, exports in self._path_exports.items(): - for iface in exports: - if iface is interface: - return True - - return False - def _interface_signal_notify( self, interface: ServiceInterface, @@ -632,9 +620,9 @@ class BaseMessageBus: body: list[Any], unix_fds: list[int] = [], ) -> None: - path = None + path: str | None = None for p, ifaces in self._path_exports.items(): - for i in ifaces: + for i in ifaces.values(): if i is interface: path = p @@ -657,9 +645,9 @@ class BaseMessageBus: def _introspect_export_path(self, path: str) -> intr.Node: assert_object_path_valid(path) - if path in self._path_exports: + if (interfaces := self._path_exports.get(path)) is not None: node = intr.Node.default(path) - for interface in self._path_exports[path]: + for interface in interfaces.values(): node.interfaces.append(interface.introspect()) else: node = intr.Node(path) @@ -687,7 +675,7 @@ class BaseMessageBus: err = None for transport, options in self._bus_address: - filename = None + filename: bytes | str | None = None ip_addr = "" ip_port = 0 @@ -738,9 +726,9 @@ class BaseMessageBus: def _reply_notify( self, msg: Message, - callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]], - reply: Optional[Message], - err: Optional[Exception], + callback: Callable[[Message | None, Exception | None], None], + reply: Message | None, + err: Exception | None, ) -> None: """Callback on reply.""" if reply and msg.destination and reply.sender: @@ -750,24 +738,23 @@ class BaseMessageBus: def _call( self, msg: Message, - callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]], + callback: Callable[[Message | None, Exception | None], None] | None, ) -> None: if not msg.serial: msg.serial = self.next_serial() - reply_expected = _expects_reply(msg) # Make sure the return reply handler is installed # before sending the message to avoid a race condition # where the reply is lost in case the backend can # send it right away. - if reply_expected: + if (reply_expected := _expects_reply(msg)) and callback is not None: self._method_return_handlers[msg.serial] = partial( self._reply_notify, msg, callback ) self.send(msg) - if not reply_expected: + if not reply_expected and callback is not None: callback(None, None) @staticmethod @@ -785,7 +772,7 @@ class BaseMessageBus: @staticmethod def _check_method_return( - msg: Optional[Message], err: Optional[Exception], signature: str + msg: Message | None, err: Exception | None, signature: str ) -> None: if err: raise err @@ -809,8 +796,7 @@ class BaseMessageBus: handled = False for user_handler in self._user_message_handlers: try: - result = user_handler(msg) - if result: + if result := user_handler(msg): if type(result) is Message: self.send(result) handled = True @@ -842,8 +828,8 @@ class BaseMessageBus: and msg.path == "/org/freedesktop/DBus" and msg.interface == "org.freedesktop.DBus" ): - [name, old_owner, new_owner] = msg.body - if new_owner: + name = msg.body[0] + if new_owner := msg.body[2]: self._name_owners[name] = new_owner elif name in self._name_owners: del self._name_owners[name] @@ -852,9 +838,9 @@ class BaseMessageBus: if msg.message_type is MESSAGE_TYPE_CALL: if not handled: handler = self._find_message_handler(msg) - if _expects_reply(msg) is False: + if not _expects_reply(msg): if handler: - handler(msg, BLOCK_UNEXPECTED_REPLY) + handler(msg, BLOCK_UNEXPECTED_REPLY) # type: ignore[arg-type] else: _LOGGER.error( '"%s.%s" with signature "%s" could not be found', @@ -891,17 +877,17 @@ class BaseMessageBus: interface: ServiceInterface, method: _Method, msg: Message, - send_reply: Callable[[Message], None], + send_reply: SendReply, ) -> None: """This is the callback that will be called when a method call is.""" args = ServiceInterface._c_msg_body_to_args(msg) if msg.unix_fds else msg.body result = method.fn(interface, *args) if send_reply is BLOCK_UNEXPECTED_REPLY or _expects_reply(msg) is False: return - body, fds = ServiceInterface._c_fn_result_to_body( + body_fds = ServiceInterface._c_fn_result_to_body( result, - signature_tree=method.out_signature_tree, - replace_fds=self._negotiate_unix_fd, + method.out_signature_tree, + self._negotiate_unix_fd, ) send_reply( Message( @@ -909,19 +895,20 @@ class BaseMessageBus: reply_serial=msg.serial, destination=msg.sender, signature=method.out_signature, - body=body, - unix_fds=fds, + body=body_fds[0], + unix_fds=body_fds[1], ) ) def _make_method_handler( self, interface: ServiceInterface, method: _Method - ) -> Callable[[Message, Callable[[Message], None]], None]: + ) -> HandlerType: return partial(self._callback_method_handler, interface, method) - def _find_message_handler( - self, msg: _Message - ) -> Optional[Callable[[Message, Callable[[Message], None]], None]]: + def _find_message_handler(self, msg: _Message) -> HandlerType | None: + if TYPE_CHECKING: + assert msg.interface is not None + if "org.freedesktop.DBus." in msg.interface: if ( msg.interface == "org.freedesktop.DBus.Introspectable" @@ -945,54 +932,52 @@ class BaseMessageBus: ): return self._default_get_managed_objects_handler - msg_path = msg.path - if msg_path: - interfaces = self._path_exports.get(msg_path) - if not interfaces: - return None - for interface in interfaces: - methods = ServiceInterface._c_get_methods(interface) - for method in methods: - if method.disabled: - continue - - if ( - msg.interface == interface.name - and msg.member == method.name - and msg.signature == method.in_signature - ): - return ServiceInterface._c_get_handler(interface, method, self) - + if ( + msg.path is not None + and msg.member is not None + and (interfaces := self._path_exports.get(msg.path)) is not None + and (interface := interfaces.get(msg.interface)) is not None + and ( + handler := ServiceInterface._get_enabled_handler_by_name_signature( + interface, self, msg.member, msg.signature + ) + ) + is not None + ): + return handler return None - def _default_introspect_handler( - self, msg: Message, send_reply: Callable[[Message], None] - ) -> None: + def _default_introspect_handler(self, msg: Message, send_reply: SendReply) -> None: + if TYPE_CHECKING: + assert msg.path is not None introspection = self._introspect_export_path(msg.path).tostring() send_reply(Message.new_method_return(msg, "s", [introspection])) - def _default_ping_handler( - self, msg: Message, send_reply: Callable[[Message], None] - ) -> None: + def _default_ping_handler(self, msg: Message, send_reply: SendReply) -> None: send_reply(Message.new_method_return(msg)) + def _send_machine_id_reply(self, msg: Message, send_reply: SendReply) -> None: + send_reply(Message.new_method_return(msg, "s", [self._machine_id])) + def _default_get_machine_id_handler( - self, msg: Message, send_reply: Callable[[Message], None] + self, msg: Message, send_reply: SendReply ) -> None: if self._machine_id: - send_reply(Message.new_method_return(msg, "s", self._machine_id)) + self._send_machine_id_reply(msg, send_reply) return - def reply_handler(reply, err): - if err: + def reply_handler(reply: Message | None, err: Exception | None) -> None: + if err or reply is None: # the bus has been disconnected, cannot send a reply return if reply.message_type == MessageType.METHOD_RETURN: self._machine_id = reply.body[0] - send_reply(Message.new_method_return(msg, "s", [self._machine_id])) - elif reply.message_type == MessageType.ERROR: - send_reply(Message.new_error(msg, reply.error_name, reply.body)) + self._send_machine_id_reply(msg, send_reply) + elif ( + reply.message_type == MessageType.ERROR and reply.error_name is not None + ): + send_reply(Message.new_error(msg, reply.error_name, str(reply.body))) else: send_reply( Message.new_error(msg, ErrorType.FAILED, "could not get machine_id") @@ -1009,13 +994,12 @@ class BaseMessageBus: ) def _default_get_managed_objects_handler( - self, msg: Message, send_reply: Callable[[Message], None] + self, msg: Message, send_reply: SendReply ) -> None: - result = {} result_signature = "a{oa{sa{sv}}}" error_handled = False - def is_result_complete(): + def is_result_complete() -> bool: if not result: return True for n, interfaces in result.items(): @@ -1025,6 +1009,9 @@ class BaseMessageBus: return True + if TYPE_CHECKING: + assert msg.path is not None + nodes = [ node for node in self._path_exports @@ -1032,16 +1019,18 @@ class BaseMessageBus: ] # first build up the result object to know when it's complete - for node in nodes: - result[node] = {} - for interface in self._path_exports[node]: - result[node][interface.name] = None + result: dict[str, dict[str, Any]] = { + node: {interface: None for interface in self._path_exports[node]} + for node in nodes + } if is_result_complete(): send_reply(Message.new_method_return(msg, result_signature, [result])) return - def get_all_properties_callback(interface, values, node, err): + def get_all_properties_callback( + interface: ServiceInterface, values: Any, node: str, err: Exception | None + ) -> None: nonlocal error_handled if err is not None: if not error_handled: @@ -1055,14 +1044,12 @@ class BaseMessageBus: send_reply(Message.new_method_return(msg, result_signature, [result])) for node in nodes: - for interface in self._path_exports[node]: + for interface in self._path_exports[node].values(): ServiceInterface._get_all_property_values( interface, get_all_properties_callback, node ) - def _default_properties_handler( - self, msg: Message, send_reply: Callable[[Message], None] - ) -> None: + def _default_properties_handler(self, msg: Message, send_reply: SendReply) -> None: methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"} if msg.member not in methods or methods[msg.member] != msg.signature: raise DBusError( @@ -1082,12 +1069,7 @@ class BaseMessageBus: ErrorType.UNKNOWN_OBJECT, f'no interfaces at path: "{msg.path}"' ) - match = [ - iface - for iface in self._path_exports[msg.path] - if iface.name == interface_name - ] - if not match: + if (interface := self._path_exports[msg.path].get(interface_name)) is None: if interface_name in [ "org.freedesktop.DBus.Properties", "org.freedesktop.DBus.Introspectable", @@ -1111,7 +1093,6 @@ class BaseMessageBus: f'could not find an interface "{interface_name}" at path: "{msg.path}"', ) - interface = match[0] properties = ServiceInterface._get_properties(interface) if msg.member == "Get" or msg.member == "Set": @@ -1135,7 +1116,12 @@ class BaseMessageBus: "the property does not have read access", ) - def get_property_callback(interface, prop, prop_value, err): + def get_property_callback( + interface: ServiceInterface, + prop: _Property, + prop_value: Any, + err: Exception | None, + ) -> None: try: if err is not None: send_reply.send_error(err) @@ -1173,7 +1159,9 @@ class BaseMessageBus: ) assert prop.prop_setter - def set_property_callback(interface, prop, err): + def set_property_callback( + interface: ServiceInterface, prop: _Property, err: Exception | None + ) -> None: if err is not None: send_reply.send_error(err) return @@ -1188,7 +1176,12 @@ class BaseMessageBus: elif msg.member == "GetAll": - def get_all_properties_callback(interface, values, user_data, err): + def get_all_properties_callback( + interface: ServiceInterface, + values: Any, + user_data: Any, + err: Exception | None, + ) -> None: if err is not None: send_reply.send_error(err) return @@ -1212,12 +1205,12 @@ class BaseMessageBus: return self._high_level_client_initialized = True - def add_match_notify(msg, err): + def add_match_notify(msg: Message | None, err: Exception | None) -> None: if err: logging.error( f'add match request failed. match="{self._name_owner_match_rule}", {err}' ) - elif msg.message_type == MessageType.ERROR: + elif msg is not None and msg.message_type == MessageType.ERROR: logging.error( f'add match request failed. match="{self._name_owner_match_rule}", {msg.body[0]}' ) @@ -1234,7 +1227,7 @@ class BaseMessageBus: add_match_notify, ) - def _add_match_rule(self, match_rule): + def _add_match_rule(self, match_rule: str) -> None: """Add a match rule. Match rules added by this function are refcounted and must be removed by _remove_match_rule(). This is for use in the high level client only.""" @@ -1247,10 +1240,10 @@ class BaseMessageBus: self._match_rules[match_rule] = 1 - def add_match_notify(msg: Message, err: Optional[Exception]) -> None: + def add_match_notify(msg: Message | None, err: Exception | None) -> None: if err: logging.error(f'add match request failed. match="{match_rule}", {err}') - elif msg.message_type == MessageType.ERROR: + elif msg is not None and msg.message_type == MessageType.ERROR: logging.error( f'add match request failed. match="{match_rule}", {msg.body[0]}' ) @@ -1267,7 +1260,7 @@ class BaseMessageBus: add_match_notify, ) - def _remove_match_rule(self, match_rule): + def _remove_match_rule(self, match_rule: str) -> None: """Remove a match rule added with _add_match_rule(). This is for use in the high level client only.""" if match_rule == self._name_owner_match_rule: @@ -1280,7 +1273,7 @@ class BaseMessageBus: del self._match_rules[match_rule] - def remove_match_notify(msg, err): + def remove_match_notify(msg: Message | None, err: Exception | None) -> None: if self._disconnected: return @@ -1288,7 +1281,7 @@ class BaseMessageBus: logging.error( f'remove match request failed. match="{match_rule}", {err}' ) - elif msg.message_type == MessageType.ERROR: + elif msg is not None and msg.message_type == MessageType.ERROR: logging.error( f'remove match request failed. match="{match_rule}", {msg.body[0]}' ) diff --git a/src/dbus_fast/send_reply.py b/src/dbus_fast/send_reply.py index 00d4440..f7189b5 100644 --- a/src/dbus_fast/send_reply.py +++ b/src/dbus_fast/send_reply.py @@ -1,6 +1,7 @@ +from __future__ import annotations import traceback from types import TracebackType -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from .constants import ErrorType from .errors import DBusError @@ -15,12 +16,12 @@ class SendReply: __slots__ = ("_bus", "_msg") - def __init__(self, bus: "BaseMessageBus", msg: Message) -> None: + def __init__(self, bus: BaseMessageBus, msg: Message) -> None: """Create a new reply context manager.""" self._bus = bus self._msg = msg - def __enter__(self): + def __enter__(self) -> SendReply: return self def __call__(self, reply: Message) -> None: @@ -28,9 +29,9 @@ class SendReply: def _exit( self, - exc_type: Optional[type[Exception]], - exc_value: Optional[Exception], - tb: Optional[TracebackType], + exc_type: type[Exception] | None, + exc_value: Exception | None, + tb: TracebackType | None, ) -> bool: if exc_value: if isinstance(exc_value, DBusError): @@ -49,9 +50,9 @@ class SendReply: def __exit__( self, - exc_type: Optional[type[Exception]], - exc_value: Optional[Exception], - tb: Optional[TracebackType], + exc_type: type[Exception] | None, + exc_value: Exception | None, + tb: TracebackType | None, ) -> bool: return self._exit(exc_type, exc_value, tb) diff --git a/src/dbus_fast/service.pxd b/src/dbus_fast/service.pxd index 6a2d637..f1903d6 100644 --- a/src/dbus_fast/service.pxd +++ b/src/dbus_fast/service.pxd @@ -33,12 +33,11 @@ cdef class ServiceInterface: cdef list __signals cdef set __buses cdef dict __handlers + cdef dict __enabled_handlers_by_name_signature + @cython.locals(handlers=dict,in_signature=str,method=_Method) @staticmethod - cdef list _c_get_methods(ServiceInterface interface) - - @staticmethod - cdef object _c_get_handler(ServiceInterface interface, _Method method, object bus) + cdef object _get_enabled_handler_by_name_signature(ServiceInterface interface, object bus, object name, object signature) @staticmethod cdef list _c_msg_body_to_args(Message msg) diff --git a/src/dbus_fast/service.py b/src/dbus_fast/service.py index b31120a..5f733cf 100644 --- a/src/dbus_fast/service.py +++ b/src/dbus_fast/service.py @@ -1,8 +1,9 @@ +from __future__ import annotations import asyncio import copy import inspect from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Protocol from . import introspection as intr from ._private.util import ( @@ -20,19 +21,30 @@ from .signature import ( Variant, get_signature_tree, ) +from .send_reply import SendReply if TYPE_CHECKING: from .message_bus import BaseMessageBus +str_ = str + +HandlerType = Callable[[Message, SendReply], None] + + +class _MethodCallbackProtocol(Protocol): + def __call__(self, interface: ServiceInterface, *args: Any) -> Any: ... + class _Method: - def __init__(self, fn, name: str, disabled=False): + def __init__( + self, fn: _MethodCallbackProtocol, name: str, disabled: bool = False + ) -> None: in_signature = "" out_signature = "" inspection = inspect.signature(fn) - in_args = [] + in_args: list[intr.Arg] = [] for i, param in enumerate(inspection.parameters.values()): if i == 0: # first is self @@ -45,7 +57,7 @@ class _Method: in_args.append(intr.Arg(annotation, intr.ArgDirection.IN, param.name)) in_signature += annotation - out_args = [] + out_args: list[intr.Arg] = [] out_signature = parse_annotation(inspection.return_annotation) if out_signature: for type_ in get_signature_tree(out_signature).types: @@ -61,7 +73,7 @@ class _Method: self.out_signature_tree = get_signature_tree(out_signature) -def method(name: Optional[str] = None, disabled: bool = False): +def method(name: str | None = None, disabled: bool = False) -> Callable: """A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus service method. The parameters and return value must each be annotated with a signature @@ -99,9 +111,9 @@ def method(name: Optional[str] = None, disabled: bool = False): if type(disabled) is not bool: raise TypeError("disabled must be a bool") - def decorator(fn): + def decorator(fn: Callable) -> Callable: @wraps(fn) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> None: fn(*args, **kwargs) fn_name = name if name else fn.__name__ @@ -113,7 +125,7 @@ def method(name: Optional[str] = None, disabled: bool = False): class _Signal: - def __init__(self, fn, name, disabled=False): + def __init__(self, fn: Callable, name: str, disabled: bool = False) -> None: inspection = inspect.signature(fn) args = [] @@ -138,7 +150,7 @@ class _Signal: self.introspection = intr.Signal(self.name, args) -def signal(name: Optional[str] = None, disabled: bool = False): +def signal(name: str | None = None, disabled: bool = False) -> Callable: """A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus signal. The signal is broadcast on the bus when the decorated class method is @@ -173,12 +185,12 @@ def signal(name: Optional[str] = None, disabled: bool = False): if type(disabled) is not bool: raise TypeError("disabled must be a bool") - def decorator(fn): + def decorator(fn: Callable) -> Callable: fn_name = name if name else fn.__name__ signal = _Signal(fn, fn_name, disabled) @wraps(fn) - def wrapped(self, *args, **kwargs): + def wrapped(self, *args: Any, **kwargs: Any) -> Any: if signal.disabled: raise SignalDisabledError("Tried to call a disabled signal") result = fn(self, *args, **kwargs) @@ -259,9 +271,9 @@ class _Property(property): def dbus_property( access: PropertyAccess = PropertyAccess.READWRITE, - name: Optional[str] = None, + name: str | None = None, disabled: bool = False, -): +) -> Callable: """A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus property. The class method must be a Python getter method with a return annotation @@ -306,7 +318,7 @@ def dbus_property( if type(disabled) is not bool: raise TypeError("disabled must be a bool") - def decorator(fn): + def decorator(fn: Callable) -> _Property: options = {"name": name, "access": access, "disabled": disabled} return _Property(fn, options=options) @@ -314,7 +326,7 @@ def dbus_property( def _real_fn_result_to_body( - result: Optional[Any], + result: Any | None, signature_tree: SignatureTree, replace_fds: bool, ) -> tuple[list[Any], list[int]]: @@ -334,7 +346,8 @@ def _real_fn_result_to_body( if out_len != len(final_result): raise SignatureBodyMismatchError( - f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}" + f"Signature and function return mismatch, expected " + f"{len(signature_tree.types)} arguments but got {len(result)}" # type: ignore[arg-type] ) if not replace_fds: @@ -365,12 +378,12 @@ class ServiceInterface: self.__methods: list[_Method] = [] self.__properties: list[_Property] = [] self.__signals: list[_Signal] = [] - self.__buses = set() - self.__handlers: dict[ - BaseMessageBus, - dict[_Method, Callable[[Message, Callable[[Message], None]], None]], + self.__buses: set[BaseMessageBus] = set() + self.__handlers: dict[BaseMessageBus, dict[_Method, HandlerType]] = {} + # Map of methods by bus of name -> method, handler + self.__handlers_by_name_signature: dict[ + BaseMessageBus, dict[str, tuple[_Method, HandlerType]] ] = {} - for name, member in inspect.getmembers(type(self)): member_dict = getattr(member, "__dict__", {}) if type(member) is _Property: @@ -405,7 +418,7 @@ class ServiceInterface: def emit_properties_changed( self, changed_properties: dict[str, Any], invalidated_properties: list[str] = [] - ): + ) -> None: """Emit the ``org.freedesktop.DBus.Properties.PropertiesChanged`` signal. This signal is intended to be used to alert clients when a property of @@ -464,58 +477,59 @@ class ServiceInterface: ) @staticmethod - def _get_properties(interface: "ServiceInterface") -> list[_Property]: + def _get_properties(interface: ServiceInterface) -> list[_Property]: return interface.__properties @staticmethod - def _get_methods(interface: "ServiceInterface") -> list[_Method]: + def _get_methods(interface: ServiceInterface) -> list[_Method]: return interface.__methods @staticmethod - def _c_get_methods(interface: "ServiceInterface") -> list[_Method]: - # _c_get_methods is used by the C code to get the methods for an - # interface - # https://github.com/cython/cython/issues/3327 - return interface.__methods - - @staticmethod - def _get_signals(interface: "ServiceInterface") -> list[_Signal]: + def _get_signals(interface: ServiceInterface) -> list[_Signal]: return interface.__signals @staticmethod - def _get_buses(interface: "ServiceInterface") -> set["BaseMessageBus"]: + def _get_buses(interface: ServiceInterface) -> set[BaseMessageBus]: return interface.__buses @staticmethod def _get_handler( - interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus" - ) -> Callable[[Message, Callable[[Message], None]], None]: + interface: ServiceInterface, method: _Method, bus: BaseMessageBus + ) -> HandlerType: return interface.__handlers[bus][method] @staticmethod - def _c_get_handler( - interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus" - ) -> Callable[[Message, Callable[[Message], None]], None]: - # _c_get_handler is used by the C code to get the handler for a method - # https://github.com/cython/cython/issues/3327 - return interface.__handlers[bus][method] + def _get_enabled_handler_by_name_signature( + interface: ServiceInterface, + bus: BaseMessageBus, + name: str_, + signature: str_, + ) -> HandlerType | None: + handlers = interface.__handlers_by_name_signature[bus] + if (method_handler := handlers.get(name)) is None: + return None + method = method_handler[0] + if method.disabled: + return None + return method_handler[1] if method.in_signature == signature else None @staticmethod def _add_bus( - interface: "ServiceInterface", - bus: "BaseMessageBus", - maker: Callable[ - ["ServiceInterface", _Method], - Callable[[Message, Callable[[Message], None]], None], - ], + interface: ServiceInterface, + bus: BaseMessageBus, + maker: Callable[[ServiceInterface, _Method], HandlerType], ) -> None: interface.__buses.add(bus) interface.__handlers[bus] = { method: maker(interface, method) for method in interface.__methods } + interface.__handlers_by_name_signature[bus] = { + method.name: (method, handler) + for method, handler in interface.__handlers[bus].items() + } @staticmethod - def _remove_bus(interface: "ServiceInterface", bus: "BaseMessageBus") -> None: + def _remove_bus(interface: ServiceInterface, bus: BaseMessageBus) -> None: interface.__buses.remove(bus) del interface.__handlers[bus] @@ -538,7 +552,7 @@ class ServiceInterface: @staticmethod def _fn_result_to_body( - result: Optional[Any], + result: Any | None, signature_tree: SignatureTree, replace_fds: bool = True, ) -> tuple[list[Any], list[int]]: @@ -546,7 +560,7 @@ class ServiceInterface: @staticmethod def _c_fn_result_to_body( - result: Optional[Any], + result: Any | None, signature_tree: SignatureTree, replace_fds: bool, ) -> tuple[list[Any], list[int]]: @@ -558,7 +572,7 @@ class ServiceInterface: @staticmethod def _handle_signal( - interface: "ServiceInterface", signal: _Signal, result: Optional[Any] + interface: ServiceInterface, signal: _Signal, result: Any | None ) -> None: body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree) for bus in ServiceInterface._get_buses(interface): @@ -567,15 +581,19 @@ class ServiceInterface: ) @staticmethod - def _get_property_value(interface: "ServiceInterface", prop: _Property, callback): + def _get_property_value( + interface: ServiceInterface, + prop: _Property, + callback: Callable[[ServiceInterface, _Property, Any, Exception | None], None], + ) -> None: # XXX MUST CHECK TYPE RETURNED BY GETTER try: if asyncio.iscoroutinefunction(prop.prop_getter): - task = asyncio.ensure_future(prop.prop_getter(interface)) + task: asyncio.Task = asyncio.ensure_future(prop.prop_getter(interface)) - def get_property_callback(task): + def get_property_callback(task_: asyncio.Task) -> None: try: - result = task.result() + result = task_.result() except Exception as e: callback(interface, prop, None, e) return @@ -592,15 +610,22 @@ class ServiceInterface: callback(interface, prop, None, e) @staticmethod - def _set_property_value(interface: "ServiceInterface", prop, value, callback): + def _set_property_value( + interface: ServiceInterface, + prop: _Property, + value: Any, + callback: Callable[[ServiceInterface, _Property, Exception | None], None], + ) -> None: # XXX MUST CHECK TYPE TO SET try: if asyncio.iscoroutinefunction(prop.prop_setter): - task = asyncio.ensure_future(prop.prop_setter(interface, value)) + task: asyncio.Task = asyncio.ensure_future( + prop.prop_setter(interface, value) + ) - def set_property_callback(task): + def set_property_callback(task_: asyncio.Task) -> None: try: - task.result() + task_.result() except Exception as e: callback(interface, prop, e) return @@ -617,9 +642,11 @@ class ServiceInterface: @staticmethod def _get_all_property_values( - interface: "ServiceInterface", callback, user_data=None - ): - result = {} + interface: ServiceInterface, + callback: Callable[[ServiceInterface, Any, Any, Exception | None], None], + user_data: Any | None = None, + ) -> None: + result: dict[str, Variant | None] = {} result_error = None for prop in ServiceInterface._get_properties(interface): @@ -632,10 +659,10 @@ class ServiceInterface: return def get_property_callback( - interface: "ServiceInterface", + interface: ServiceInterface, prop: _Property, value: Any, - e: Optional[Exception], + e: Exception | None, ) -> None: nonlocal result_error if e is not None: diff --git a/tests/service/test_export.py b/tests/service/test_export.py index 31fbe3b..74bcb4d 100644 --- a/tests/service/test_export.py +++ b/tests/service/test_export.py @@ -28,9 +28,14 @@ async def test_export_unexport(): bus = await MessageBus().connect() bus.export(export_path, interface) + + with pytest.raises(ValueError): + # Already exported + bus.export(export_path, interface) + assert export_path in bus._path_exports assert len(bus._path_exports[export_path]) == 1 - assert bus._path_exports[export_path][0] is interface + assert bus._path_exports[export_path][interface.name] is interface assert len(ServiceInterface._get_buses(interface)) == 1 bus.export(export_path2, interface2) @@ -60,11 +65,23 @@ async def test_export_unexport(): assert not bus._path_exports assert not ServiceInterface._get_buses(interface) + # test unexporting by ServiceInterface + bus.export(export_path, interface) + bus.unexport(export_path, interface) + assert not bus._path_exports + assert not ServiceInterface._get_buses(interface) + + with pytest.raises(TypeError): + bus.unexport(export_path, object()) + node = bus._introspect_export_path("/path/doesnt/exist") assert type(node) is intr.Node assert not node.interfaces assert not node.nodes + # Should to nothing + bus.unexport("/path/doesnt/exist", interface) + bus.disconnect() diff --git a/tests/service/test_standard_interfaces.py b/tests/service/test_standard_interfaces.py index 731d673..37e338e 100644 --- a/tests/service/test_standard_interfaces.py +++ b/tests/service/test_standard_interfaces.py @@ -149,6 +149,19 @@ async def test_peer_interface(): assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0] assert reply.signature == "s" + reply2 = await bus2.call( + Message( + destination=bus1.unique_name, + path="/path/doesnt/exist", + interface="org.freedesktop.DBus.Peer", + member="GetMachineId", + signature="", + ) + ) + + assert reply2.message_type == MessageType.METHOD_RETURN, reply.body[0] + assert reply2.signature == "s" + bus1.disconnect() bus2.disconnect() @@ -213,9 +226,9 @@ async def test_object_manager(): ) ) - assert reply_root.signature == "a{oa{sa{sv}}}" - assert reply_level1.signature == "a{oa{sa{sv}}}" - assert reply_level2.signature == "a{oa{sa{sv}}}" + assert reply_root.signature == "a{oa{sa{sv}}}", reply_root + assert reply_level1.signature == "a{oa{sa{sv}}}", reply_level1 + assert reply_level2.signature == "a{oa{sa{sv}}}", reply_level2 assert reply_level2.body == [{}] assert reply_level1.body == [expected_reply]