feat: complete some more missing typing (#103)

This commit is contained in:
J. Nick Koston 2022-10-11 16:54:23 -10:00 committed by GitHub
parent 95a98f4c44
commit 5787032af7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 131 additions and 88 deletions

View File

@ -3,7 +3,8 @@ import logging
import socket
import traceback
import xml.etree.ElementTree as ET
from typing import Callable, Optional, Type, Union
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from . import introspection as intr
from ._private.address import get_bus_address, parse_address
@ -20,7 +21,7 @@ from .constants import (
from .errors import DBusError, InvalidAddressError
from .message import Message
from .proxy_object import BaseProxyObject
from .service import ServiceInterface
from .service import ServiceInterface, _Method
from .signature import Variant
from .validators import assert_bus_name_valid, assert_object_path_valid
@ -61,23 +62,27 @@ class BaseMessageBus:
bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[Type[BaseProxyObject]] = None,
) -> None:
self.unique_name = None
self.unique_name: Optional[str] = None
self._disconnected = False
# True if the user disconnected himself, so don't throw errors out of
# the main loop.
self._user_disconnect = False
self._method_return_handlers = {}
self._method_return_handlers: Dict[
int, Callable[[Optional[Message], Optional[Exception]], None]
] = {}
self._serial = 0
self._user_message_handlers = []
self._user_message_handlers: List[
Callable[[Message], Union[Message, bool, None]]
] = []
# the key is the name and the value is the unique name of the owner.
# This cache is kept up to date by the NameOwnerChanged signal and is
# used to route messages to the correct proxy object. (used for the
# high level client only)
self._name_owners = {}
self._name_owners: Dict[str, str] = {}
# used for the high level service
self._path_exports = {}
self._path_exports: Dict[str, list[ServiceInterface]] = {}
self._bus_address = (
parse_address(bus_address)
if bus_address
@ -88,12 +93,12 @@ class BaseMessageBus:
self._name_owner_match_rule = "sender='org.freedesktop.DBus',interface='org.freedesktop.DBus',path='/org/freedesktop/DBus',member='NameOwnerChanged'"
# _match_rules: the keys are match rules and the values are ref counts
# (used for the high level client only)
self._match_rules = {}
self._match_rules: Dict[str, int] = {}
self._high_level_client_initialized = False
self._ProxyObject = ProxyObject
# machine id is lazy loaded
self._machine_id = None
self._machine_id: Optional[int] = None
self._setup_socket()
@ -211,10 +216,10 @@ class BaseMessageBus:
"""
BaseMessageBus._check_callback_type(callback)
def reply_notify(reply, err):
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "s")
result = intr.Node.parse(reply.body[0])
result = intr.Node.parse(reply.body[0]) # type: ignore[union-attr]
except Exception as e:
callback(None, e)
return
@ -246,7 +251,12 @@ class BaseMessageBus:
if self._disconnected:
return
def get_properties_callback(interface, result, user_data, e):
def get_properties_callback(
interface: ServiceInterface,
result: Any,
user_data: Any,
e: Optional[Exception],
) -> None:
if e is not None:
try:
raise e
@ -272,7 +282,7 @@ class BaseMessageBus:
ServiceInterface._get_all_property_values(interface, get_properties_callback)
def _emit_interface_removed(self, path, removed_interfaces):
def _emit_interface_removed(self, path: str, removed_interfaces: List[str]) -> None:
"""Emit the ``org.freedesktop.DBus.ObjectManager.InterfacesRemoved` signal.
This signal is intended to be used to alert clients when
@ -303,7 +313,7 @@ class BaseMessageBus:
callback: Optional[
Callable[[Optional[RequestNameReply], Optional[Exception]], None]
] = None,
):
) -> None:
"""Request that this message bus owns the given name.
:param name: The name to request.
@ -322,38 +332,41 @@ class BaseMessageBus:
if callback is not None:
BaseMessageBus._check_callback_type(callback)
def reply_notify(reply, err):
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = RequestNameReply(reply.body[0])
except Exception as e:
callback(None, e)
return
callback(result, None)
if type(flags) is not NameFlag:
flags = NameFlag(flags)
self._call(
Message(
destination="org.freedesktop.DBus",
path="/org/freedesktop/DBus",
interface="org.freedesktop.DBus",
member="RequestName",
signature="su",
body=[name, flags],
),
reply_notify if callback else None,
message = Message(
destination="org.freedesktop.DBus",
path="/org/freedesktop/DBus",
interface="org.freedesktop.DBus",
member="RequestName",
signature="su",
body=[name, flags],
)
if callback is None:
self._call(message, None)
return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = RequestNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e:
callback(None, e) # type: ignore[misc]
return
callback(result, None) # type: ignore[misc]
self._call(message, reply_notify)
def release_name(
self,
name: str,
callback: Optional[
Callable[[Optional[ReleaseNameReply], Optional[Exception]], None]
] = None,
):
) -> None:
"""Request that this message bus release the given name.
:param name: The name to release.
@ -371,27 +384,30 @@ class BaseMessageBus:
if callback is not None:
BaseMessageBus._check_callback_type(callback)
def reply_notify(reply, err):
message = Message(
destination="org.freedesktop.DBus",
path="/org/freedesktop/DBus",
interface="org.freedesktop.DBus",
member="ReleaseName",
signature="s",
body=[name],
)
if callback is None:
self._call(message, None)
return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = ReleaseNameReply(reply.body[0])
result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e:
callback(None, e)
callback(None, e) # type: ignore[misc]
return
callback(result, None)
callback(result, None) # type: ignore[misc]
self._call(
Message(
destination="org.freedesktop.DBus",
path="/org/freedesktop/DBus",
interface="org.freedesktop.DBus",
member="ReleaseName",
signature="s",
body=[name],
),
reply_notify if callback else None,
)
self._call(message, reply_notify)
def get_proxy_object(
self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element]
@ -451,7 +467,7 @@ class BaseMessageBus:
def add_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]]
):
) -> None:
"""Add a custom message handler for incoming messages.
The handler should be a callable that takes a :class:`Message
@ -476,7 +492,7 @@ class BaseMessageBus:
def remove_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]]
):
) -> None:
"""Remove a message handler that was previously added by
:func:`add_message_handler()
<dbus_fast.message_bus.BaseMessageBus.add_message_handler>`.
@ -487,7 +503,7 @@ class BaseMessageBus:
for i, h in enumerate(self._user_message_handlers):
if h == handler:
del self._user_message_handlers[i]
break
return
def send(self, msg: Message) -> None:
"""Asynchronously send a message on the message bus.
@ -499,7 +515,7 @@ class BaseMessageBus:
'the "send" method must be implemented in the inheriting class'
)
def _finalize(self, err):
def _finalize(self, err: Optional[Exception]) -> None:
"""should be called after the socket disconnects with the disconnection
error to clean up resources and put the bus in a disconnected state"""
if self._disconnected:
@ -531,8 +547,14 @@ class BaseMessageBus:
return False
def _interface_signal_notify(
self, interface, interface_name, member, signature, body, unix_fds=[]
):
self,
interface: ServiceInterface,
interface_name: str,
member: str,
signature: str,
body: List[Any],
unix_fds: List[int] = [],
) -> None:
path = None
for p, ifaces in self._path_exports.items():
for i in ifaces:
@ -555,7 +577,7 @@ class BaseMessageBus:
)
)
def _introspect_export_path(self, path):
def _introspect_export_path(self, path: str) -> intr.Node:
assert_object_path_valid(path)
if path in self._path_exports:
@ -582,7 +604,7 @@ class BaseMessageBus:
return node
def _setup_socket(self):
def _setup_socket(self) -> None:
err = None
for transport, options in self._bus_address:
@ -635,7 +657,10 @@ class BaseMessageBus:
raise err
def _call(
self, msg: Message, callback: Callable, check_callback: bool = True
self,
msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
check_callback: bool = True,
) -> None:
if check_callback:
BaseMessageBus._check_callback_type(callback)
@ -643,10 +668,10 @@ class BaseMessageBus:
if not msg.serial:
msg.serial = self.next_serial()
def reply_notify(reply, err):
if reply:
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err)
callback(reply, err) # type: ignore[misc]
no_reply_expected = msg.flags & MessageFlag.NO_REPLY_EXPECTED
@ -660,7 +685,7 @@ class BaseMessageBus:
self.send(msg)
if no_reply_expected:
callback(None, None)
callback(None, None) # type: ignore[misc]
@staticmethod
def _check_callback_type(callback: Callable) -> None:
@ -676,9 +701,15 @@ class BaseMessageBus:
raise TypeError(text)
@staticmethod
def _check_method_return(msg: Message, err: Exception, signature: str) -> None:
def _check_method_return(
msg: Optional[Message], err: Optional[Exception], signature: str
) -> None:
if err:
raise err
elif msg is None:
raise DBusError(
ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg
)
elif (
msg.message_type == MessageType.METHOD_RETURN and msg.signature == signature
):
@ -694,21 +725,26 @@ class BaseMessageBus:
bus = self
class SendReply:
def __enter__(self):
def __enter__(self) -> "SendReply":
return self
def __call__(self, reply):
def __call__(self, reply: Message) -> None:
if msg.flags & MessageFlag.NO_REPLY_EXPECTED:
return
bus.send(reply)
def _exit(self, exc_type, exc_value, tb):
def _exit(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> bool:
if exc_type is None:
return
return False
if issubclass(exc_type, DBusError):
self(exc_value._as_message(msg))
self(exc_value._as_message(msg)) # type: ignore[union-attr]
return True
if issubclass(exc_type, Exception):
@ -721,10 +757,15 @@ class BaseMessageBus:
)
return True
def __exit__(self, exc_type, exc_value, tb):
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> None:
self._exit(exc_type, exc_value, tb)
def send_error(self, exc):
def send_error(self, exc: Exception) -> None:
self._exit(exc.__class__, exc, exc.__traceback__)
return SendReply()
@ -732,9 +773,9 @@ class BaseMessageBus:
def _process_message(self, msg: Message) -> None:
handled = False
for handler in self._user_message_handlers:
for user_handler in self._user_message_handlers:
try:
result = handler(msg)
result = user_handler(msg)
if result:
if type(result) is Message:
self.send(result)
@ -799,12 +840,14 @@ class BaseMessageBus:
# An ERROR or a METHOD_RETURN
if msg.reply_serial in self._method_return_handlers:
if not handled:
handler = self._method_return_handlers[msg.reply_serial]
handler(msg, None)
return_handler = self._method_return_handlers[msg.reply_serial]
return_handler(msg, None)
del self._method_return_handlers[msg.reply_serial]
def _make_method_handler(self, interface, method):
def handler(msg, send_reply):
def _make_method_handler(
self, interface: ServiceInterface, method: _Method
) -> Callable[[Message, Callable[[Message], None]], None]:
def handler(msg: Message, send_reply: Callable[[Message], None]) -> None:
args = ServiceInterface._msg_body_to_args(msg)
result = method.fn(interface, *args)
body, fds = ServiceInterface._fn_result_to_body(
@ -817,7 +860,7 @@ class BaseMessageBus:
def _find_message_handler(
self, msg: Message
) -> Optional[Callable[[Message, Callable], None]]:
handler = None
handler: Optional[Callable[[Message, Callable], None]] = None
if (
msg.interface == "org.freedesktop.DBus.Introspectable"
@ -840,7 +883,7 @@ class BaseMessageBus:
):
handler = self._default_get_managed_objects_handler
else:
elif msg.path:
for interface in self._path_exports.get(msg.path, []):
for method in ServiceInterface._get_methods(interface):
if method.disabled:

View File

@ -327,9 +327,9 @@ class ServiceInterface:
def __init__(self, name: str):
# TODO cannot be overridden by a dbus member
self.name = name
self.__methods = []
self.__properties = []
self.__signals = []
self.__methods: List[_Method] = []
self.__properties: List[_Property] = []
self.__signals: List[_Signal] = []
self.__buses = set()
for name, member in inspect.getmembers(type(self)):
@ -425,27 +425,27 @@ class ServiceInterface:
)
@staticmethod
def _get_properties(interface):
def _get_properties(interface: "ServiceInterface") -> List[_Property]:
return interface.__properties
@staticmethod
def _get_methods(interface):
def _get_methods(interface: "ServiceInterface") -> List[_Method]:
return interface.__methods
@staticmethod
def _get_signals(interface):
def _get_signals(interface: "ServiceInterface") -> List[_Signal]:
return interface.__signals
@staticmethod
def _get_buses(interface):
def _get_buses(interface: "ServiceInterface"):
return interface.__buses
@staticmethod
def _add_bus(interface, bus):
def _add_bus(interface: "ServiceInterface", bus) -> None:
interface.__buses.add(bus)
@staticmethod
def _remove_bus(interface, bus):
def _remove_bus(interface: "ServiceInterface", bus) -> None:
interface.__buses.remove(bus)
@staticmethod