feat: complete some more missing typing (#103)

This commit is contained in:
J. Nick Koston
2022-10-11 16:54:23 -10:00
committed by GitHub
parent 95a98f4c44
commit 5787032af7
2 changed files with 131 additions and 88 deletions

View File

@@ -3,7 +3,8 @@ import logging
import socket import socket
import traceback import traceback
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Callable, Optional, Type, Union from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from . import introspection as intr from . import introspection as intr
from ._private.address import get_bus_address, parse_address from ._private.address import get_bus_address, parse_address
@@ -20,7 +21,7 @@ from .constants import (
from .errors import DBusError, InvalidAddressError from .errors import DBusError, InvalidAddressError
from .message import Message from .message import Message
from .proxy_object import BaseProxyObject from .proxy_object import BaseProxyObject
from .service import ServiceInterface from .service import ServiceInterface, _Method
from .signature import Variant from .signature import Variant
from .validators import assert_bus_name_valid, assert_object_path_valid from .validators import assert_bus_name_valid, assert_object_path_valid
@@ -61,23 +62,27 @@ class BaseMessageBus:
bus_type: BusType = BusType.SESSION, bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[Type[BaseProxyObject]] = None, ProxyObject: Optional[Type[BaseProxyObject]] = None,
) -> None: ) -> None:
self.unique_name = None self.unique_name: Optional[str] = None
self._disconnected = False self._disconnected = False
# True if the user disconnected himself, so don't throw errors out of # True if the user disconnected himself, so don't throw errors out of
# the main loop. # the main loop.
self._user_disconnect = False self._user_disconnect = False
self._method_return_handlers = {} self._method_return_handlers: Dict[
int, Callable[[Optional[Message], Optional[Exception]], None]
] = {}
self._serial = 0 self._serial = 0
self._user_message_handlers = [] self._user_message_handlers: List[
Callable[[Message], Union[Message, bool, None]]
] = []
# the key is the name and the value is the unique name of the owner. # the key is the name and the value is the unique name of the owner.
# This cache is kept up to date by the NameOwnerChanged signal and is # This cache is kept up to date by the NameOwnerChanged signal and is
# used to route messages to the correct proxy object. (used for the # used to route messages to the correct proxy object. (used for the
# high level client only) # high level client only)
self._name_owners = {} self._name_owners: Dict[str, str] = {}
# used for the high level service # used for the high level service
self._path_exports = {} self._path_exports: Dict[str, list[ServiceInterface]] = {}
self._bus_address = ( self._bus_address = (
parse_address(bus_address) parse_address(bus_address)
if bus_address if bus_address
@@ -88,12 +93,12 @@ class BaseMessageBus:
self._name_owner_match_rule = "sender='org.freedesktop.DBus',interface='org.freedesktop.DBus',path='/org/freedesktop/DBus',member='NameOwnerChanged'" self._name_owner_match_rule = "sender='org.freedesktop.DBus',interface='org.freedesktop.DBus',path='/org/freedesktop/DBus',member='NameOwnerChanged'"
# _match_rules: the keys are match rules and the values are ref counts # _match_rules: the keys are match rules and the values are ref counts
# (used for the high level client only) # (used for the high level client only)
self._match_rules = {} self._match_rules: Dict[str, int] = {}
self._high_level_client_initialized = False self._high_level_client_initialized = False
self._ProxyObject = ProxyObject self._ProxyObject = ProxyObject
# machine id is lazy loaded # machine id is lazy loaded
self._machine_id = None self._machine_id: Optional[int] = None
self._setup_socket() self._setup_socket()
@@ -211,10 +216,10 @@ class BaseMessageBus:
""" """
BaseMessageBus._check_callback_type(callback) BaseMessageBus._check_callback_type(callback)
def reply_notify(reply, err): def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
try: try:
BaseMessageBus._check_method_return(reply, err, "s") BaseMessageBus._check_method_return(reply, err, "s")
result = intr.Node.parse(reply.body[0]) result = intr.Node.parse(reply.body[0]) # type: ignore[union-attr]
except Exception as e: except Exception as e:
callback(None, e) callback(None, e)
return return
@@ -246,7 +251,12 @@ class BaseMessageBus:
if self._disconnected: if self._disconnected:
return return
def get_properties_callback(interface, result, user_data, e): def get_properties_callback(
interface: ServiceInterface,
result: Any,
user_data: Any,
e: Optional[Exception],
) -> None:
if e is not None: if e is not None:
try: try:
raise e raise e
@@ -272,7 +282,7 @@ class BaseMessageBus:
ServiceInterface._get_all_property_values(interface, get_properties_callback) ServiceInterface._get_all_property_values(interface, get_properties_callback)
def _emit_interface_removed(self, path, removed_interfaces): def _emit_interface_removed(self, path: str, removed_interfaces: List[str]) -> None:
"""Emit the ``org.freedesktop.DBus.ObjectManager.InterfacesRemoved` signal. """Emit the ``org.freedesktop.DBus.ObjectManager.InterfacesRemoved` signal.
This signal is intended to be used to alert clients when This signal is intended to be used to alert clients when
@@ -303,7 +313,7 @@ class BaseMessageBus:
callback: Optional[ callback: Optional[
Callable[[Optional[RequestNameReply], Optional[Exception]], None] Callable[[Optional[RequestNameReply], Optional[Exception]], None]
] = None, ] = None,
): ) -> None:
"""Request that this message bus owns the given name. """Request that this message bus owns the given name.
:param name: The name to request. :param name: The name to request.
@@ -322,38 +332,41 @@ class BaseMessageBus:
if callback is not None: if callback is not None:
BaseMessageBus._check_callback_type(callback) BaseMessageBus._check_callback_type(callback)
def reply_notify(reply, err):
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = RequestNameReply(reply.body[0])
except Exception as e:
callback(None, e)
return
callback(result, None)
if type(flags) is not NameFlag: if type(flags) is not NameFlag:
flags = NameFlag(flags) flags = NameFlag(flags)
self._call( message = Message(
Message( destination="org.freedesktop.DBus",
destination="org.freedesktop.DBus", path="/org/freedesktop/DBus",
path="/org/freedesktop/DBus", interface="org.freedesktop.DBus",
interface="org.freedesktop.DBus", member="RequestName",
member="RequestName", signature="su",
signature="su", body=[name, flags],
body=[name, flags],
),
reply_notify if callback else None,
) )
if callback is None:
self._call(message, None)
return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = RequestNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e:
callback(None, e) # type: ignore[misc]
return
callback(result, None) # type: ignore[misc]
self._call(message, reply_notify)
def release_name( def release_name(
self, self,
name: str, name: str,
callback: Optional[ callback: Optional[
Callable[[Optional[ReleaseNameReply], Optional[Exception]], None] Callable[[Optional[ReleaseNameReply], Optional[Exception]], None]
] = None, ] = None,
): ) -> None:
"""Request that this message bus release the given name. """Request that this message bus release the given name.
:param name: The name to release. :param name: The name to release.
@@ -371,27 +384,30 @@ class BaseMessageBus:
if callback is not None: if callback is not None:
BaseMessageBus._check_callback_type(callback) BaseMessageBus._check_callback_type(callback)
def reply_notify(reply, err): message = Message(
destination="org.freedesktop.DBus",
path="/org/freedesktop/DBus",
interface="org.freedesktop.DBus",
member="ReleaseName",
signature="s",
body=[name],
)
if callback is None:
self._call(message, None)
return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
try: try:
BaseMessageBus._check_method_return(reply, err, "u") BaseMessageBus._check_method_return(reply, err, "u")
result = ReleaseNameReply(reply.body[0]) result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e: except Exception as e:
callback(None, e) callback(None, e) # type: ignore[misc]
return return
callback(result, None) callback(result, None) # type: ignore[misc]
self._call( self._call(message, reply_notify)
Message(
destination="org.freedesktop.DBus",
path="/org/freedesktop/DBus",
interface="org.freedesktop.DBus",
member="ReleaseName",
signature="s",
body=[name],
),
reply_notify if callback else None,
)
def get_proxy_object( def get_proxy_object(
self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element] self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element]
@@ -451,7 +467,7 @@ class BaseMessageBus:
def add_message_handler( def add_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]] self, handler: Callable[[Message], Optional[Union[Message, bool]]]
): ) -> None:
"""Add a custom message handler for incoming messages. """Add a custom message handler for incoming messages.
The handler should be a callable that takes a :class:`Message The handler should be a callable that takes a :class:`Message
@@ -476,7 +492,7 @@ class BaseMessageBus:
def remove_message_handler( def remove_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]] self, handler: Callable[[Message], Optional[Union[Message, bool]]]
): ) -> None:
"""Remove a message handler that was previously added by """Remove a message handler that was previously added by
:func:`add_message_handler() :func:`add_message_handler()
<dbus_fast.message_bus.BaseMessageBus.add_message_handler>`. <dbus_fast.message_bus.BaseMessageBus.add_message_handler>`.
@@ -487,7 +503,7 @@ class BaseMessageBus:
for i, h in enumerate(self._user_message_handlers): for i, h in enumerate(self._user_message_handlers):
if h == handler: if h == handler:
del self._user_message_handlers[i] del self._user_message_handlers[i]
break return
def send(self, msg: Message) -> None: def send(self, msg: Message) -> None:
"""Asynchronously send a message on the message bus. """Asynchronously send a message on the message bus.
@@ -499,7 +515,7 @@ class BaseMessageBus:
'the "send" method must be implemented in the inheriting class' 'the "send" method must be implemented in the inheriting class'
) )
def _finalize(self, err): def _finalize(self, err: Optional[Exception]) -> None:
"""should be called after the socket disconnects with the disconnection """should be called after the socket disconnects with the disconnection
error to clean up resources and put the bus in a disconnected state""" error to clean up resources and put the bus in a disconnected state"""
if self._disconnected: if self._disconnected:
@@ -531,8 +547,14 @@ class BaseMessageBus:
return False return False
def _interface_signal_notify( def _interface_signal_notify(
self, interface, interface_name, member, signature, body, unix_fds=[] self,
): interface: ServiceInterface,
interface_name: str,
member: str,
signature: str,
body: List[Any],
unix_fds: List[int] = [],
) -> None:
path = None path = None
for p, ifaces in self._path_exports.items(): for p, ifaces in self._path_exports.items():
for i in ifaces: for i in ifaces:
@@ -555,7 +577,7 @@ class BaseMessageBus:
) )
) )
def _introspect_export_path(self, path): def _introspect_export_path(self, path: str) -> intr.Node:
assert_object_path_valid(path) assert_object_path_valid(path)
if path in self._path_exports: if path in self._path_exports:
@@ -582,7 +604,7 @@ class BaseMessageBus:
return node return node
def _setup_socket(self): def _setup_socket(self) -> None:
err = None err = None
for transport, options in self._bus_address: for transport, options in self._bus_address:
@@ -635,7 +657,10 @@ class BaseMessageBus:
raise err raise err
def _call( def _call(
self, msg: Message, callback: Callable, check_callback: bool = True self,
msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
check_callback: bool = True,
) -> None: ) -> None:
if check_callback: if check_callback:
BaseMessageBus._check_callback_type(callback) BaseMessageBus._check_callback_type(callback)
@@ -643,10 +668,10 @@ class BaseMessageBus:
if not msg.serial: if not msg.serial:
msg.serial = self.next_serial() msg.serial = self.next_serial()
def reply_notify(reply, err): def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
if reply: if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender self._name_owners[msg.destination] = reply.sender
callback(reply, err) callback(reply, err) # type: ignore[misc]
no_reply_expected = msg.flags & MessageFlag.NO_REPLY_EXPECTED no_reply_expected = msg.flags & MessageFlag.NO_REPLY_EXPECTED
@@ -660,7 +685,7 @@ class BaseMessageBus:
self.send(msg) self.send(msg)
if no_reply_expected: if no_reply_expected:
callback(None, None) callback(None, None) # type: ignore[misc]
@staticmethod @staticmethod
def _check_callback_type(callback: Callable) -> None: def _check_callback_type(callback: Callable) -> None:
@@ -676,9 +701,15 @@ class BaseMessageBus:
raise TypeError(text) raise TypeError(text)
@staticmethod @staticmethod
def _check_method_return(msg: Message, err: Exception, signature: str) -> None: def _check_method_return(
msg: Optional[Message], err: Optional[Exception], signature: str
) -> None:
if err: if err:
raise err raise err
elif msg is None:
raise DBusError(
ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg
)
elif ( elif (
msg.message_type == MessageType.METHOD_RETURN and msg.signature == signature msg.message_type == MessageType.METHOD_RETURN and msg.signature == signature
): ):
@@ -694,21 +725,26 @@ class BaseMessageBus:
bus = self bus = self
class SendReply: class SendReply:
def __enter__(self): def __enter__(self) -> "SendReply":
return self return self
def __call__(self, reply): def __call__(self, reply: Message) -> None:
if msg.flags & MessageFlag.NO_REPLY_EXPECTED: if msg.flags & MessageFlag.NO_REPLY_EXPECTED:
return return
bus.send(reply) bus.send(reply)
def _exit(self, exc_type, exc_value, tb): def _exit(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> bool:
if exc_type is None: if exc_type is None:
return return False
if issubclass(exc_type, DBusError): if issubclass(exc_type, DBusError):
self(exc_value._as_message(msg)) self(exc_value._as_message(msg)) # type: ignore[union-attr]
return True return True
if issubclass(exc_type, Exception): if issubclass(exc_type, Exception):
@@ -721,10 +757,15 @@ class BaseMessageBus:
) )
return True return True
def __exit__(self, exc_type, exc_value, tb): def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> None:
self._exit(exc_type, exc_value, tb) self._exit(exc_type, exc_value, tb)
def send_error(self, exc): def send_error(self, exc: Exception) -> None:
self._exit(exc.__class__, exc, exc.__traceback__) self._exit(exc.__class__, exc, exc.__traceback__)
return SendReply() return SendReply()
@@ -732,9 +773,9 @@ class BaseMessageBus:
def _process_message(self, msg: Message) -> None: def _process_message(self, msg: Message) -> None:
handled = False handled = False
for handler in self._user_message_handlers: for user_handler in self._user_message_handlers:
try: try:
result = handler(msg) result = user_handler(msg)
if result: if result:
if type(result) is Message: if type(result) is Message:
self.send(result) self.send(result)
@@ -799,12 +840,14 @@ class BaseMessageBus:
# An ERROR or a METHOD_RETURN # An ERROR or a METHOD_RETURN
if msg.reply_serial in self._method_return_handlers: if msg.reply_serial in self._method_return_handlers:
if not handled: if not handled:
handler = self._method_return_handlers[msg.reply_serial] return_handler = self._method_return_handlers[msg.reply_serial]
handler(msg, None) return_handler(msg, None)
del self._method_return_handlers[msg.reply_serial] del self._method_return_handlers[msg.reply_serial]
def _make_method_handler(self, interface, method): def _make_method_handler(
def handler(msg, send_reply): self, interface: ServiceInterface, method: _Method
) -> Callable[[Message, Callable[[Message], None]], None]:
def handler(msg: Message, send_reply: Callable[[Message], None]) -> None:
args = ServiceInterface._msg_body_to_args(msg) args = ServiceInterface._msg_body_to_args(msg)
result = method.fn(interface, *args) result = method.fn(interface, *args)
body, fds = ServiceInterface._fn_result_to_body( body, fds = ServiceInterface._fn_result_to_body(
@@ -817,7 +860,7 @@ class BaseMessageBus:
def _find_message_handler( def _find_message_handler(
self, msg: Message self, msg: Message
) -> Optional[Callable[[Message, Callable], None]]: ) -> Optional[Callable[[Message, Callable], None]]:
handler = None handler: Optional[Callable[[Message, Callable], None]] = None
if ( if (
msg.interface == "org.freedesktop.DBus.Introspectable" msg.interface == "org.freedesktop.DBus.Introspectable"
@@ -840,7 +883,7 @@ class BaseMessageBus:
): ):
handler = self._default_get_managed_objects_handler handler = self._default_get_managed_objects_handler
else: elif msg.path:
for interface in self._path_exports.get(msg.path, []): for interface in self._path_exports.get(msg.path, []):
for method in ServiceInterface._get_methods(interface): for method in ServiceInterface._get_methods(interface):
if method.disabled: if method.disabled:

View File

@@ -327,9 +327,9 @@ class ServiceInterface:
def __init__(self, name: str): def __init__(self, name: str):
# TODO cannot be overridden by a dbus member # TODO cannot be overridden by a dbus member
self.name = name self.name = name
self.__methods = [] self.__methods: List[_Method] = []
self.__properties = [] self.__properties: List[_Property] = []
self.__signals = [] self.__signals: List[_Signal] = []
self.__buses = set() self.__buses = set()
for name, member in inspect.getmembers(type(self)): for name, member in inspect.getmembers(type(self)):
@@ -425,27 +425,27 @@ class ServiceInterface:
) )
@staticmethod @staticmethod
def _get_properties(interface): def _get_properties(interface: "ServiceInterface") -> List[_Property]:
return interface.__properties return interface.__properties
@staticmethod @staticmethod
def _get_methods(interface): def _get_methods(interface: "ServiceInterface") -> List[_Method]:
return interface.__methods return interface.__methods
@staticmethod @staticmethod
def _get_signals(interface): def _get_signals(interface: "ServiceInterface") -> List[_Signal]:
return interface.__signals return interface.__signals
@staticmethod @staticmethod
def _get_buses(interface): def _get_buses(interface: "ServiceInterface"):
return interface.__buses return interface.__buses
@staticmethod @staticmethod
def _add_bus(interface, bus): def _add_bus(interface: "ServiceInterface", bus) -> None:
interface.__buses.add(bus) interface.__buses.add(bus)
@staticmethod @staticmethod
def _remove_bus(interface, bus): def _remove_bus(interface: "ServiceInterface", bus) -> None:
interface.__buses.remove(bus) interface.__buses.remove(bus)
@staticmethod @staticmethod