From 5787032af7cae1ffffd1561390cdb02053776345 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 11 Oct 2022 16:54:23 -1000 Subject: [PATCH] feat: complete some more missing typing (#103) --- src/dbus_fast/message_bus.py | 201 +++++++++++++++++++++-------------- src/dbus_fast/service.py | 18 ++-- 2 files changed, 131 insertions(+), 88 deletions(-) diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index 1619814..a48695f 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -3,7 +3,8 @@ import logging import socket import traceback import xml.etree.ElementTree as ET -from typing import Callable, Optional, Type, Union +from types import TracebackType +from typing import Any, Callable, Dict, List, Optional, Type, Union from . import introspection as intr from ._private.address import get_bus_address, parse_address @@ -20,7 +21,7 @@ from .constants import ( from .errors import DBusError, InvalidAddressError from .message import Message from .proxy_object import BaseProxyObject -from .service import ServiceInterface +from .service import ServiceInterface, _Method from .signature import Variant from .validators import assert_bus_name_valid, assert_object_path_valid @@ -61,23 +62,27 @@ class BaseMessageBus: bus_type: BusType = BusType.SESSION, ProxyObject: Optional[Type[BaseProxyObject]] = None, ) -> None: - self.unique_name = None + self.unique_name: Optional[str] = None self._disconnected = False # True if the user disconnected himself, so don't throw errors out of # the main loop. self._user_disconnect = False - self._method_return_handlers = {} + self._method_return_handlers: Dict[ + int, Callable[[Optional[Message], Optional[Exception]], None] + ] = {} self._serial = 0 - self._user_message_handlers = [] + self._user_message_handlers: List[ + Callable[[Message], Union[Message, bool, None]] + ] = [] # the key is the name and the value is the unique name of the owner. # This cache is kept up to date by the NameOwnerChanged signal and is # used to route messages to the correct proxy object. (used for the # high level client only) - self._name_owners = {} + self._name_owners: Dict[str, str] = {} # used for the high level service - self._path_exports = {} + self._path_exports: Dict[str, list[ServiceInterface]] = {} self._bus_address = ( parse_address(bus_address) if bus_address @@ -88,12 +93,12 @@ class BaseMessageBus: self._name_owner_match_rule = "sender='org.freedesktop.DBus',interface='org.freedesktop.DBus',path='/org/freedesktop/DBus',member='NameOwnerChanged'" # _match_rules: the keys are match rules and the values are ref counts # (used for the high level client only) - self._match_rules = {} + self._match_rules: Dict[str, int] = {} self._high_level_client_initialized = False self._ProxyObject = ProxyObject # machine id is lazy loaded - self._machine_id = None + self._machine_id: Optional[int] = None self._setup_socket() @@ -211,10 +216,10 @@ class BaseMessageBus: """ BaseMessageBus._check_callback_type(callback) - def reply_notify(reply, err): + def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: try: BaseMessageBus._check_method_return(reply, err, "s") - result = intr.Node.parse(reply.body[0]) + result = intr.Node.parse(reply.body[0]) # type: ignore[union-attr] except Exception as e: callback(None, e) return @@ -246,7 +251,12 @@ class BaseMessageBus: if self._disconnected: return - def get_properties_callback(interface, result, user_data, e): + def get_properties_callback( + interface: ServiceInterface, + result: Any, + user_data: Any, + e: Optional[Exception], + ) -> None: if e is not None: try: raise e @@ -272,7 +282,7 @@ class BaseMessageBus: ServiceInterface._get_all_property_values(interface, get_properties_callback) - def _emit_interface_removed(self, path, removed_interfaces): + def _emit_interface_removed(self, path: str, removed_interfaces: List[str]) -> None: """Emit the ``org.freedesktop.DBus.ObjectManager.InterfacesRemoved` signal. This signal is intended to be used to alert clients when @@ -303,7 +313,7 @@ class BaseMessageBus: callback: Optional[ Callable[[Optional[RequestNameReply], Optional[Exception]], None] ] = None, - ): + ) -> None: """Request that this message bus owns the given name. :param name: The name to request. @@ -322,38 +332,41 @@ class BaseMessageBus: if callback is not None: BaseMessageBus._check_callback_type(callback) - def reply_notify(reply, err): - try: - BaseMessageBus._check_method_return(reply, err, "u") - result = RequestNameReply(reply.body[0]) - except Exception as e: - callback(None, e) - return - - callback(result, None) - if type(flags) is not NameFlag: flags = NameFlag(flags) - self._call( - Message( - destination="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - interface="org.freedesktop.DBus", - member="RequestName", - signature="su", - body=[name, flags], - ), - reply_notify if callback else None, + message = Message( + destination="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + interface="org.freedesktop.DBus", + member="RequestName", + signature="su", + body=[name, flags], ) + if callback is None: + self._call(message, None) + return + + def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: + try: + BaseMessageBus._check_method_return(reply, err, "u") + result = RequestNameReply(reply.body[0]) # type: ignore[union-attr] + except Exception as e: + callback(None, e) # type: ignore[misc] + return + + callback(result, None) # type: ignore[misc] + + self._call(message, reply_notify) + def release_name( self, name: str, callback: Optional[ Callable[[Optional[ReleaseNameReply], Optional[Exception]], None] ] = None, - ): + ) -> None: """Request that this message bus release the given name. :param name: The name to release. @@ -371,27 +384,30 @@ class BaseMessageBus: if callback is not None: BaseMessageBus._check_callback_type(callback) - def reply_notify(reply, err): + message = Message( + destination="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + interface="org.freedesktop.DBus", + member="ReleaseName", + signature="s", + body=[name], + ) + + if callback is None: + self._call(message, None) + return + + def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: try: BaseMessageBus._check_method_return(reply, err, "u") - result = ReleaseNameReply(reply.body[0]) + result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr] except Exception as e: - callback(None, e) + callback(None, e) # type: ignore[misc] return - callback(result, None) + callback(result, None) # type: ignore[misc] - self._call( - Message( - destination="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - interface="org.freedesktop.DBus", - member="ReleaseName", - signature="s", - body=[name], - ), - reply_notify if callback else None, - ) + self._call(message, reply_notify) def get_proxy_object( self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element] @@ -451,7 +467,7 @@ class BaseMessageBus: def add_message_handler( self, handler: Callable[[Message], Optional[Union[Message, bool]]] - ): + ) -> None: """Add a custom message handler for incoming messages. The handler should be a callable that takes a :class:`Message @@ -476,7 +492,7 @@ class BaseMessageBus: def remove_message_handler( self, handler: Callable[[Message], Optional[Union[Message, bool]]] - ): + ) -> None: """Remove a message handler that was previously added by :func:`add_message_handler() `. @@ -487,7 +503,7 @@ class BaseMessageBus: for i, h in enumerate(self._user_message_handlers): if h == handler: del self._user_message_handlers[i] - break + return def send(self, msg: Message) -> None: """Asynchronously send a message on the message bus. @@ -499,7 +515,7 @@ class BaseMessageBus: 'the "send" method must be implemented in the inheriting class' ) - def _finalize(self, err): + def _finalize(self, err: Optional[Exception]) -> None: """should be called after the socket disconnects with the disconnection error to clean up resources and put the bus in a disconnected state""" if self._disconnected: @@ -531,8 +547,14 @@ class BaseMessageBus: return False def _interface_signal_notify( - self, interface, interface_name, member, signature, body, unix_fds=[] - ): + self, + interface: ServiceInterface, + interface_name: str, + member: str, + signature: str, + body: List[Any], + unix_fds: List[int] = [], + ) -> None: path = None for p, ifaces in self._path_exports.items(): for i in ifaces: @@ -555,7 +577,7 @@ class BaseMessageBus: ) ) - def _introspect_export_path(self, path): + def _introspect_export_path(self, path: str) -> intr.Node: assert_object_path_valid(path) if path in self._path_exports: @@ -582,7 +604,7 @@ class BaseMessageBus: return node - def _setup_socket(self): + def _setup_socket(self) -> None: err = None for transport, options in self._bus_address: @@ -635,7 +657,10 @@ class BaseMessageBus: raise err def _call( - self, msg: Message, callback: Callable, check_callback: bool = True + self, + msg: Message, + callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]], + check_callback: bool = True, ) -> None: if check_callback: BaseMessageBus._check_callback_type(callback) @@ -643,10 +668,10 @@ class BaseMessageBus: if not msg.serial: msg.serial = self.next_serial() - def reply_notify(reply, err): - if reply: + def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: + if reply and msg.destination and reply.sender: self._name_owners[msg.destination] = reply.sender - callback(reply, err) + callback(reply, err) # type: ignore[misc] no_reply_expected = msg.flags & MessageFlag.NO_REPLY_EXPECTED @@ -660,7 +685,7 @@ class BaseMessageBus: self.send(msg) if no_reply_expected: - callback(None, None) + callback(None, None) # type: ignore[misc] @staticmethod def _check_callback_type(callback: Callable) -> None: @@ -676,9 +701,15 @@ class BaseMessageBus: raise TypeError(text) @staticmethod - def _check_method_return(msg: Message, err: Exception, signature: str) -> None: + def _check_method_return( + msg: Optional[Message], err: Optional[Exception], signature: str + ) -> None: if err: raise err + elif msg is None: + raise DBusError( + ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg + ) elif ( msg.message_type == MessageType.METHOD_RETURN and msg.signature == signature ): @@ -694,21 +725,26 @@ class BaseMessageBus: bus = self class SendReply: - def __enter__(self): + def __enter__(self) -> "SendReply": return self - def __call__(self, reply): + def __call__(self, reply: Message) -> None: if msg.flags & MessageFlag.NO_REPLY_EXPECTED: return bus.send(reply) - def _exit(self, exc_type, exc_value, tb): + def _exit( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + tb: Optional[TracebackType], + ) -> bool: if exc_type is None: - return + return False if issubclass(exc_type, DBusError): - self(exc_value._as_message(msg)) + self(exc_value._as_message(msg)) # type: ignore[union-attr] return True if issubclass(exc_type, Exception): @@ -721,10 +757,15 @@ class BaseMessageBus: ) return True - def __exit__(self, exc_type, exc_value, tb): + def __exit__( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + tb: Optional[TracebackType], + ) -> None: self._exit(exc_type, exc_value, tb) - def send_error(self, exc): + def send_error(self, exc: Exception) -> None: self._exit(exc.__class__, exc, exc.__traceback__) return SendReply() @@ -732,9 +773,9 @@ class BaseMessageBus: def _process_message(self, msg: Message) -> None: handled = False - for handler in self._user_message_handlers: + for user_handler in self._user_message_handlers: try: - result = handler(msg) + result = user_handler(msg) if result: if type(result) is Message: self.send(result) @@ -799,12 +840,14 @@ class BaseMessageBus: # An ERROR or a METHOD_RETURN if msg.reply_serial in self._method_return_handlers: if not handled: - handler = self._method_return_handlers[msg.reply_serial] - handler(msg, None) + return_handler = self._method_return_handlers[msg.reply_serial] + return_handler(msg, None) del self._method_return_handlers[msg.reply_serial] - def _make_method_handler(self, interface, method): - def handler(msg, send_reply): + def _make_method_handler( + self, interface: ServiceInterface, method: _Method + ) -> Callable[[Message, Callable[[Message], None]], None]: + def handler(msg: Message, send_reply: Callable[[Message], None]) -> None: args = ServiceInterface._msg_body_to_args(msg) result = method.fn(interface, *args) body, fds = ServiceInterface._fn_result_to_body( @@ -817,7 +860,7 @@ class BaseMessageBus: def _find_message_handler( self, msg: Message ) -> Optional[Callable[[Message, Callable], None]]: - handler = None + handler: Optional[Callable[[Message, Callable], None]] = None if ( msg.interface == "org.freedesktop.DBus.Introspectable" @@ -840,7 +883,7 @@ class BaseMessageBus: ): handler = self._default_get_managed_objects_handler - else: + elif msg.path: for interface in self._path_exports.get(msg.path, []): for method in ServiceInterface._get_methods(interface): if method.disabled: diff --git a/src/dbus_fast/service.py b/src/dbus_fast/service.py index 20ca015..7c29e14 100644 --- a/src/dbus_fast/service.py +++ b/src/dbus_fast/service.py @@ -327,9 +327,9 @@ class ServiceInterface: def __init__(self, name: str): # TODO cannot be overridden by a dbus member self.name = name - self.__methods = [] - self.__properties = [] - self.__signals = [] + self.__methods: List[_Method] = [] + self.__properties: List[_Property] = [] + self.__signals: List[_Signal] = [] self.__buses = set() for name, member in inspect.getmembers(type(self)): @@ -425,27 +425,27 @@ class ServiceInterface: ) @staticmethod - def _get_properties(interface): + def _get_properties(interface: "ServiceInterface") -> List[_Property]: return interface.__properties @staticmethod - def _get_methods(interface): + def _get_methods(interface: "ServiceInterface") -> List[_Method]: return interface.__methods @staticmethod - def _get_signals(interface): + def _get_signals(interface: "ServiceInterface") -> List[_Signal]: return interface.__signals @staticmethod - def _get_buses(interface): + def _get_buses(interface: "ServiceInterface"): return interface.__buses @staticmethod - def _add_bus(interface, bus): + def _add_bus(interface: "ServiceInterface", bus) -> None: interface.__buses.add(bus) @staticmethod - def _remove_bus(interface, bus): + def _remove_bus(interface: "ServiceInterface", bus) -> None: interface.__buses.remove(bus) @staticmethod