feat: refactor service bus handler lookup to avoid linear searches (#400)

This commit is contained in:
J. Nick Koston 2025-03-05 13:28:00 -10:00 committed by GitHub
parent 640e1f8d87
commit 996659e1b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 315 additions and 262 deletions

View File

@ -37,7 +37,7 @@ class MessageFlag(IntFlag):
ALLOW_INTERACTIVE_AUTHORIZATION = 4
@cached_property
def value(self) -> str:
def value(self) -> int:
"""Return the value."""
return self._value_

View File

@ -51,7 +51,7 @@ class Arg:
def __init__(
self,
signature: Union[SignatureType, str],
direction: Optional[list[ArgDirection]] = None,
direction: Optional[ArgDirection] = None,
name: Optional[str] = None,
annotations: Optional[dict[str, str]] = None,
):

View File

@ -4,6 +4,7 @@ from ._private.address cimport get_bus_address, parse_address
from .message cimport Message
from .service cimport ServiceInterface, _Method
cdef bint TYPE_CHECKING
cdef object MessageType
cdef object DBusError
@ -39,24 +40,26 @@ cdef class BaseMessageBus:
cdef public object _high_level_client_initialized
cdef public object _ProxyObject
cdef public object _machine_id
cdef public object _negotiate_unix_fd
cdef public bint _negotiate_unix_fd
cdef public object _sock
cdef public object _stream
cdef public object _fd
cpdef _process_message(self, Message msg)
cpdef void _process_message(self, Message msg)
@cython.locals(exported_service_interface=ServiceInterface)
cpdef export(self, str path, ServiceInterface interface)
@cython.locals(
methods=cython.list,
method=_Method,
interface=ServiceInterface,
interfaces=cython.list,
interfaces=dict,
)
cdef _find_message_handler(self, Message msg)
cdef _setup_socket(self)
@cython.locals(no_reply_expected=bint)
cpdef _call(self, Message msg, object callback)
cpdef next_serial(self)

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import inspect
import logging
import socket
import traceback
import xml.etree.ElementTree as ET
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, TYPE_CHECKING
from . import introspection as intr
from ._private.address import get_bus_address, parse_address
@ -22,7 +23,7 @@ from .errors import DBusError, InvalidAddressError
from .message import Message
from .proxy_object import BaseProxyObject
from .send_reply import SendReply
from .service import ServiceInterface, _Method
from .service import ServiceInterface, _Method, _Property, HandlerType
from .signature import Variant
from .validators import assert_bus_name_valid, assert_object_path_valid
@ -119,12 +120,12 @@ class BaseMessageBus:
def __init__(
self,
bus_address: Optional[str] = None,
bus_address: str | None = None,
bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[type[BaseProxyObject]] = None,
ProxyObject: type[BaseProxyObject] | None = None,
negotiate_unix_fd: bool = False,
) -> None:
self.unique_name: Optional[str] = None
self.unique_name: str | None = None
self._disconnected = False
self._negotiate_unix_fd = negotiate_unix_fd
@ -133,11 +134,11 @@ class BaseMessageBus:
self._user_disconnect = False
self._method_return_handlers: dict[
int, Callable[[Optional[Message], Optional[Exception]], None]
int, Callable[[Message | None, Exception | None], None]
] = {}
self._serial = 0
self._user_message_handlers: list[
Callable[[Message], Union[Message, bool, None]]
Callable[[Message], Message | bool | None]
] = []
# the key is the name and the value is the unique name of the owner.
# This cache is kept up to date by the NameOwnerChanged signal and is
@ -145,7 +146,7 @@ class BaseMessageBus:
# high level client only)
self._name_owners: dict[str, str] = {}
# used for the high level service
self._path_exports: dict[str, list[ServiceInterface]] = {}
self._path_exports: dict[str, dict[str, ServiceInterface]] = {}
self._bus_address = (
parse_address(bus_address)
if bus_address
@ -161,10 +162,10 @@ class BaseMessageBus:
self._ProxyObject = ProxyObject
# machine id is lazy loaded
self._machine_id: Optional[int] = None
self._sock: Optional[socket.socket] = None
self._fd: Optional[int] = None
self._stream: Optional[Any] = None
self._machine_id: int | None = None
self._sock: socket.socket | None = None
self._fd: int | None = None
self._stream: Any | None = None
self._setup_socket()
@ -193,20 +194,18 @@ class BaseMessageBus:
raise TypeError("interface must be a ServiceInterface")
if path not in self._path_exports:
self._path_exports[path] = []
self._path_exports[path] = {}
elif interface.name in self._path_exports[path]:
raise ValueError(
f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"'
)
for f in self._path_exports[path]:
if f.name == interface.name:
raise ValueError(
f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"'
)
self._path_exports[path].append(interface)
self._path_exports[path][interface.name] = interface
ServiceInterface._add_bus(interface, self, self._make_method_handler)
self._emit_interface_added(path, interface)
def unexport(
self, path: str, interface: Optional[Union[ServiceInterface, str]] = None
self, path: str, interface: ServiceInterface | str | None = None
) -> None:
"""Unexport the path or service interface to make it no longer
available to clients.
@ -222,45 +221,42 @@ class BaseMessageBus:
- :class:`InvalidObjectPathError <dbus_fast.InvalidObjectPathError>` - If the given object path is not valid.
"""
assert_object_path_valid(path)
if type(interface) not in [str, type(None)] and not isinstance(
interface, ServiceInterface
):
raise TypeError("interface must be a ServiceInterface or interface name")
if path not in self._path_exports:
return
exports = self._path_exports[path]
if type(interface) is str:
try:
interface = next(iface for iface in exports if iface.name == interface)
except StopIteration:
return
removed_interfaces = []
interface_name: str | None
if interface is None:
del self._path_exports[path]
for iface in filter(lambda e: not self._has_interface(e), exports):
removed_interfaces.append(iface.name)
ServiceInterface._remove_bus(iface, self)
interface_name = None
elif type(interface) is str:
interface_name = interface
elif isinstance(interface, ServiceInterface):
interface_name = interface.name
else:
for i, iface in enumerate(exports):
if iface is interface:
removed_interfaces.append(iface.name)
del self._path_exports[path][i]
if not self._path_exports[path]:
del self._path_exports[path]
if not self._has_interface(iface):
ServiceInterface._remove_bus(iface, self)
break
self._emit_interface_removed(path, removed_interfaces)
raise TypeError(
f"interface must be a ServiceInterface or interface name not {type(interface)}"
)
if (interfaces := self._path_exports.get(path)) is None:
return
removed_interface_names: list[str] = []
if interface_name is not None:
if (removed_interface := interfaces.pop(interface_name, None)) is None:
return
removed_interface_names.append(interface_name)
if not interfaces:
del self._path_exports[path]
ServiceInterface._remove_bus(removed_interface, self)
else:
del self._path_exports[path]
for removed_interface in interfaces.values():
removed_interface_names.append(removed_interface.name)
ServiceInterface._remove_bus(removed_interface, self)
self._emit_interface_removed(path, removed_interface_names)
def introspect(
self,
bus_name: str,
path: str,
callback: Callable[[Optional[intr.Node], Optional[Exception]], None],
callback: Callable[[intr.Node | None, Exception | None], None],
check_callback_type: bool = True,
validate_property_names: bool = True,
) -> None:
@ -289,12 +285,13 @@ class BaseMessageBus:
if check_callback_type:
BaseMessageBus._check_callback_type(callback)
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
def reply_notify(reply: Message | None, err: Exception | None) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "s")
result = intr.Node.parse(
reply.body[0], validate_property_names=validate_property_names
) # type: ignore[union-attr]
reply.body[0], # type: ignore[union-attr]
validate_property_names=validate_property_names,
)
except Exception as e:
callback(None, e)
return
@ -330,7 +327,7 @@ class BaseMessageBus:
interface: ServiceInterface,
result: Any,
user_data: Any,
e: Optional[Exception],
e: Exception | None,
) -> None:
if e is not None:
try:
@ -385,9 +382,8 @@ class BaseMessageBus:
self,
name: str,
flags: NameFlag = NameFlag.NONE,
callback: Optional[
Callable[[Optional[RequestNameReply], Optional[Exception]], None]
] = None,
callback: None
| (Callable[[RequestNameReply | None, Exception | None], None]) = None,
check_callback_type: bool = True,
) -> None:
"""Request that this message bus owns the given name.
@ -424,24 +420,23 @@ class BaseMessageBus:
self._call(message, None)
return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
def reply_notify(reply: Message | None, err: Exception | None) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = RequestNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e:
callback(None, e) # type: ignore[misc]
callback(None, e)
return
callback(result, None) # type: ignore[misc]
callback(result, None)
self._call(message, reply_notify)
def release_name(
self,
name: str,
callback: Optional[
Callable[[Optional[ReleaseNameReply], Optional[Exception]], None]
] = None,
callback: None
| (Callable[[ReleaseNameReply | None, Exception | None], None]) = None,
check_callback_type: bool = True,
) -> None:
"""Request that this message bus release the given name.
@ -474,20 +469,20 @@ class BaseMessageBus:
self._call(message, None)
return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
def reply_notify(reply: Message | None, err: Exception | None) -> None:
try:
BaseMessageBus._check_method_return(reply, err, "u")
result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e:
callback(None, e) # type: ignore[misc]
callback(None, e)
return
callback(result, None) # type: ignore[misc]
callback(result, None)
self._call(message, reply_notify)
def get_proxy_object(
self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element]
self, bus_name: str, path: str, introspection: intr.Node | str | ET.Element
) -> BaseProxyObject:
"""Get a proxy object for the path exported on the bus that owns the
name. The object is expected to export the interfaces and nodes
@ -526,10 +521,11 @@ class BaseMessageBus:
All pending and future calls will error with a connection error.
"""
self._user_disconnect = True
try:
self._sock.shutdown(socket.SHUT_RDWR)
except Exception:
logging.warning("could not shut down socket", exc_info=True)
if self._sock:
try:
self._sock.shutdown(socket.SHUT_RDWR)
except Exception:
logging.warning("could not shut down socket", exc_info=True)
def next_serial(self) -> int:
"""Get the next serial for this bus. This can be used as the ``serial``
@ -543,7 +539,7 @@ class BaseMessageBus:
return self._serial
def add_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]]
self, handler: Callable[[Message], Message | bool | None]
) -> None:
"""Add a custom message handler for incoming messages.
@ -568,7 +564,7 @@ class BaseMessageBus:
self._user_message_handlers.append(handler)
def remove_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]]
self, handler: Callable[[Message], Message | bool | None]
) -> None:
"""Remove a message handler that was previously added by
:func:`add_message_handler()
@ -592,7 +588,7 @@ class BaseMessageBus:
'the "send" method must be implemented in the inheriting class'
)
def _finalize(self, err: Optional[Exception]) -> None:
def _finalize(self, err: Exception | None) -> None:
"""should be called after the socket disconnects with the disconnection
error to clean up resources and put the bus in a disconnected state"""
if self._disconnected:
@ -615,14 +611,6 @@ class BaseMessageBus:
self._user_message_handlers.clear()
def _has_interface(self, interface: ServiceInterface) -> bool:
for _, exports in self._path_exports.items():
for iface in exports:
if iface is interface:
return True
return False
def _interface_signal_notify(
self,
interface: ServiceInterface,
@ -632,9 +620,9 @@ class BaseMessageBus:
body: list[Any],
unix_fds: list[int] = [],
) -> None:
path = None
path: str | None = None
for p, ifaces in self._path_exports.items():
for i in ifaces:
for i in ifaces.values():
if i is interface:
path = p
@ -657,9 +645,9 @@ class BaseMessageBus:
def _introspect_export_path(self, path: str) -> intr.Node:
assert_object_path_valid(path)
if path in self._path_exports:
if (interfaces := self._path_exports.get(path)) is not None:
node = intr.Node.default(path)
for interface in self._path_exports[path]:
for interface in interfaces.values():
node.interfaces.append(interface.introspect())
else:
node = intr.Node(path)
@ -687,7 +675,7 @@ class BaseMessageBus:
err = None
for transport, options in self._bus_address:
filename = None
filename: bytes | str | None = None
ip_addr = ""
ip_port = 0
@ -738,9 +726,9 @@ class BaseMessageBus:
def _reply_notify(
self,
msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
reply: Optional[Message],
err: Optional[Exception],
callback: Callable[[Message | None, Exception | None], None],
reply: Message | None,
err: Exception | None,
) -> None:
"""Callback on reply."""
if reply and msg.destination and reply.sender:
@ -750,24 +738,23 @@ class BaseMessageBus:
def _call(
self,
msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
callback: Callable[[Message | None, Exception | None], None] | None,
) -> None:
if not msg.serial:
msg.serial = self.next_serial()
reply_expected = _expects_reply(msg)
# Make sure the return reply handler is installed
# before sending the message to avoid a race condition
# where the reply is lost in case the backend can
# send it right away.
if reply_expected:
if (reply_expected := _expects_reply(msg)) and callback is not None:
self._method_return_handlers[msg.serial] = partial(
self._reply_notify, msg, callback
)
self.send(msg)
if not reply_expected:
if not reply_expected and callback is not None:
callback(None, None)
@staticmethod
@ -785,7 +772,7 @@ class BaseMessageBus:
@staticmethod
def _check_method_return(
msg: Optional[Message], err: Optional[Exception], signature: str
msg: Message | None, err: Exception | None, signature: str
) -> None:
if err:
raise err
@ -809,8 +796,7 @@ class BaseMessageBus:
handled = False
for user_handler in self._user_message_handlers:
try:
result = user_handler(msg)
if result:
if result := user_handler(msg):
if type(result) is Message:
self.send(result)
handled = True
@ -842,8 +828,8 @@ class BaseMessageBus:
and msg.path == "/org/freedesktop/DBus"
and msg.interface == "org.freedesktop.DBus"
):
[name, old_owner, new_owner] = msg.body
if new_owner:
name = msg.body[0]
if new_owner := msg.body[2]:
self._name_owners[name] = new_owner
elif name in self._name_owners:
del self._name_owners[name]
@ -852,9 +838,9 @@ class BaseMessageBus:
if msg.message_type is MESSAGE_TYPE_CALL:
if not handled:
handler = self._find_message_handler(msg)
if _expects_reply(msg) is False:
if not _expects_reply(msg):
if handler:
handler(msg, BLOCK_UNEXPECTED_REPLY)
handler(msg, BLOCK_UNEXPECTED_REPLY) # type: ignore[arg-type]
else:
_LOGGER.error(
'"%s.%s" with signature "%s" could not be found',
@ -891,17 +877,17 @@ class BaseMessageBus:
interface: ServiceInterface,
method: _Method,
msg: Message,
send_reply: Callable[[Message], None],
send_reply: SendReply,
) -> None:
"""This is the callback that will be called when a method call is."""
args = ServiceInterface._c_msg_body_to_args(msg) if msg.unix_fds else msg.body
result = method.fn(interface, *args)
if send_reply is BLOCK_UNEXPECTED_REPLY or _expects_reply(msg) is False:
return
body, fds = ServiceInterface._c_fn_result_to_body(
body_fds = ServiceInterface._c_fn_result_to_body(
result,
signature_tree=method.out_signature_tree,
replace_fds=self._negotiate_unix_fd,
method.out_signature_tree,
self._negotiate_unix_fd,
)
send_reply(
Message(
@ -909,19 +895,20 @@ class BaseMessageBus:
reply_serial=msg.serial,
destination=msg.sender,
signature=method.out_signature,
body=body,
unix_fds=fds,
body=body_fds[0],
unix_fds=body_fds[1],
)
)
def _make_method_handler(
self, interface: ServiceInterface, method: _Method
) -> Callable[[Message, Callable[[Message], None]], None]:
) -> HandlerType:
return partial(self._callback_method_handler, interface, method)
def _find_message_handler(
self, msg: _Message
) -> Optional[Callable[[Message, Callable[[Message], None]], None]]:
def _find_message_handler(self, msg: _Message) -> HandlerType | None:
if TYPE_CHECKING:
assert msg.interface is not None
if "org.freedesktop.DBus." in msg.interface:
if (
msg.interface == "org.freedesktop.DBus.Introspectable"
@ -945,54 +932,52 @@ class BaseMessageBus:
):
return self._default_get_managed_objects_handler
msg_path = msg.path
if msg_path:
interfaces = self._path_exports.get(msg_path)
if not interfaces:
return None
for interface in interfaces:
methods = ServiceInterface._c_get_methods(interface)
for method in methods:
if method.disabled:
continue
if (
msg.interface == interface.name
and msg.member == method.name
and msg.signature == method.in_signature
):
return ServiceInterface._c_get_handler(interface, method, self)
if (
msg.path is not None
and msg.member is not None
and (interfaces := self._path_exports.get(msg.path)) is not None
and (interface := interfaces.get(msg.interface)) is not None
and (
handler := ServiceInterface._get_enabled_handler_by_name_signature(
interface, self, msg.member, msg.signature
)
)
is not None
):
return handler
return None
def _default_introspect_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
def _default_introspect_handler(self, msg: Message, send_reply: SendReply) -> None:
if TYPE_CHECKING:
assert msg.path is not None
introspection = self._introspect_export_path(msg.path).tostring()
send_reply(Message.new_method_return(msg, "s", [introspection]))
def _default_ping_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
def _default_ping_handler(self, msg: Message, send_reply: SendReply) -> None:
send_reply(Message.new_method_return(msg))
def _send_machine_id_reply(self, msg: Message, send_reply: SendReply) -> None:
send_reply(Message.new_method_return(msg, "s", [self._machine_id]))
def _default_get_machine_id_handler(
self, msg: Message, send_reply: Callable[[Message], None]
self, msg: Message, send_reply: SendReply
) -> None:
if self._machine_id:
send_reply(Message.new_method_return(msg, "s", self._machine_id))
self._send_machine_id_reply(msg, send_reply)
return
def reply_handler(reply, err):
if err:
def reply_handler(reply: Message | None, err: Exception | None) -> None:
if err or reply is None:
# the bus has been disconnected, cannot send a reply
return
if reply.message_type == MessageType.METHOD_RETURN:
self._machine_id = reply.body[0]
send_reply(Message.new_method_return(msg, "s", [self._machine_id]))
elif reply.message_type == MessageType.ERROR:
send_reply(Message.new_error(msg, reply.error_name, reply.body))
self._send_machine_id_reply(msg, send_reply)
elif (
reply.message_type == MessageType.ERROR and reply.error_name is not None
):
send_reply(Message.new_error(msg, reply.error_name, str(reply.body)))
else:
send_reply(
Message.new_error(msg, ErrorType.FAILED, "could not get machine_id")
@ -1009,13 +994,12 @@ class BaseMessageBus:
)
def _default_get_managed_objects_handler(
self, msg: Message, send_reply: Callable[[Message], None]
self, msg: Message, send_reply: SendReply
) -> None:
result = {}
result_signature = "a{oa{sa{sv}}}"
error_handled = False
def is_result_complete():
def is_result_complete() -> bool:
if not result:
return True
for n, interfaces in result.items():
@ -1025,6 +1009,9 @@ class BaseMessageBus:
return True
if TYPE_CHECKING:
assert msg.path is not None
nodes = [
node
for node in self._path_exports
@ -1032,16 +1019,18 @@ class BaseMessageBus:
]
# first build up the result object to know when it's complete
for node in nodes:
result[node] = {}
for interface in self._path_exports[node]:
result[node][interface.name] = None
result: dict[str, dict[str, Any]] = {
node: {interface: None for interface in self._path_exports[node]}
for node in nodes
}
if is_result_complete():
send_reply(Message.new_method_return(msg, result_signature, [result]))
return
def get_all_properties_callback(interface, values, node, err):
def get_all_properties_callback(
interface: ServiceInterface, values: Any, node: str, err: Exception | None
) -> None:
nonlocal error_handled
if err is not None:
if not error_handled:
@ -1055,14 +1044,12 @@ class BaseMessageBus:
send_reply(Message.new_method_return(msg, result_signature, [result]))
for node in nodes:
for interface in self._path_exports[node]:
for interface in self._path_exports[node].values():
ServiceInterface._get_all_property_values(
interface, get_all_properties_callback, node
)
def _default_properties_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
def _default_properties_handler(self, msg: Message, send_reply: SendReply) -> None:
methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"}
if msg.member not in methods or methods[msg.member] != msg.signature:
raise DBusError(
@ -1082,12 +1069,7 @@ class BaseMessageBus:
ErrorType.UNKNOWN_OBJECT, f'no interfaces at path: "{msg.path}"'
)
match = [
iface
for iface in self._path_exports[msg.path]
if iface.name == interface_name
]
if not match:
if (interface := self._path_exports[msg.path].get(interface_name)) is None:
if interface_name in [
"org.freedesktop.DBus.Properties",
"org.freedesktop.DBus.Introspectable",
@ -1111,7 +1093,6 @@ class BaseMessageBus:
f'could not find an interface "{interface_name}" at path: "{msg.path}"',
)
interface = match[0]
properties = ServiceInterface._get_properties(interface)
if msg.member == "Get" or msg.member == "Set":
@ -1135,7 +1116,12 @@ class BaseMessageBus:
"the property does not have read access",
)
def get_property_callback(interface, prop, prop_value, err):
def get_property_callback(
interface: ServiceInterface,
prop: _Property,
prop_value: Any,
err: Exception | None,
) -> None:
try:
if err is not None:
send_reply.send_error(err)
@ -1173,7 +1159,9 @@ class BaseMessageBus:
)
assert prop.prop_setter
def set_property_callback(interface, prop, err):
def set_property_callback(
interface: ServiceInterface, prop: _Property, err: Exception | None
) -> None:
if err is not None:
send_reply.send_error(err)
return
@ -1188,7 +1176,12 @@ class BaseMessageBus:
elif msg.member == "GetAll":
def get_all_properties_callback(interface, values, user_data, err):
def get_all_properties_callback(
interface: ServiceInterface,
values: Any,
user_data: Any,
err: Exception | None,
) -> None:
if err is not None:
send_reply.send_error(err)
return
@ -1212,12 +1205,12 @@ class BaseMessageBus:
return
self._high_level_client_initialized = True
def add_match_notify(msg, err):
def add_match_notify(msg: Message | None, err: Exception | None) -> None:
if err:
logging.error(
f'add match request failed. match="{self._name_owner_match_rule}", {err}'
)
elif msg.message_type == MessageType.ERROR:
elif msg is not None and msg.message_type == MessageType.ERROR:
logging.error(
f'add match request failed. match="{self._name_owner_match_rule}", {msg.body[0]}'
)
@ -1234,7 +1227,7 @@ class BaseMessageBus:
add_match_notify,
)
def _add_match_rule(self, match_rule):
def _add_match_rule(self, match_rule: str) -> None:
"""Add a match rule. Match rules added by this function are refcounted
and must be removed by _remove_match_rule(). This is for use in the
high level client only."""
@ -1247,10 +1240,10 @@ class BaseMessageBus:
self._match_rules[match_rule] = 1
def add_match_notify(msg: Message, err: Optional[Exception]) -> None:
def add_match_notify(msg: Message | None, err: Exception | None) -> None:
if err:
logging.error(f'add match request failed. match="{match_rule}", {err}')
elif msg.message_type == MessageType.ERROR:
elif msg is not None and msg.message_type == MessageType.ERROR:
logging.error(
f'add match request failed. match="{match_rule}", {msg.body[0]}'
)
@ -1267,7 +1260,7 @@ class BaseMessageBus:
add_match_notify,
)
def _remove_match_rule(self, match_rule):
def _remove_match_rule(self, match_rule: str) -> None:
"""Remove a match rule added with _add_match_rule(). This is for use in
the high level client only."""
if match_rule == self._name_owner_match_rule:
@ -1280,7 +1273,7 @@ class BaseMessageBus:
del self._match_rules[match_rule]
def remove_match_notify(msg, err):
def remove_match_notify(msg: Message | None, err: Exception | None) -> None:
if self._disconnected:
return
@ -1288,7 +1281,7 @@ class BaseMessageBus:
logging.error(
f'remove match request failed. match="{match_rule}", {err}'
)
elif msg.message_type == MessageType.ERROR:
elif msg is not None and msg.message_type == MessageType.ERROR:
logging.error(
f'remove match request failed. match="{match_rule}", {msg.body[0]}'
)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import traceback
from types import TracebackType
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from .constants import ErrorType
from .errors import DBusError
@ -15,12 +16,12 @@ class SendReply:
__slots__ = ("_bus", "_msg")
def __init__(self, bus: "BaseMessageBus", msg: Message) -> None:
def __init__(self, bus: BaseMessageBus, msg: Message) -> None:
"""Create a new reply context manager."""
self._bus = bus
self._msg = msg
def __enter__(self):
def __enter__(self) -> SendReply:
return self
def __call__(self, reply: Message) -> None:
@ -28,9 +29,9 @@ class SendReply:
def _exit(
self,
exc_type: Optional[type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
exc_type: type[Exception] | None,
exc_value: Exception | None,
tb: TracebackType | None,
) -> bool:
if exc_value:
if isinstance(exc_value, DBusError):
@ -49,9 +50,9 @@ class SendReply:
def __exit__(
self,
exc_type: Optional[type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
exc_type: type[Exception] | None,
exc_value: Exception | None,
tb: TracebackType | None,
) -> bool:
return self._exit(exc_type, exc_value, tb)

View File

@ -33,12 +33,11 @@ cdef class ServiceInterface:
cdef list __signals
cdef set __buses
cdef dict __handlers
cdef dict __enabled_handlers_by_name_signature
@cython.locals(handlers=dict,in_signature=str,method=_Method)
@staticmethod
cdef list _c_get_methods(ServiceInterface interface)
@staticmethod
cdef object _c_get_handler(ServiceInterface interface, _Method method, object bus)
cdef object _get_enabled_handler_by_name_signature(ServiceInterface interface, object bus, object name, object signature)
@staticmethod
cdef list _c_msg_body_to_args(Message msg)

View File

@ -1,8 +1,9 @@
from __future__ import annotations
import asyncio
import copy
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Protocol
from . import introspection as intr
from ._private.util import (
@ -20,19 +21,30 @@ from .signature import (
Variant,
get_signature_tree,
)
from .send_reply import SendReply
if TYPE_CHECKING:
from .message_bus import BaseMessageBus
str_ = str
HandlerType = Callable[[Message, SendReply], None]
class _MethodCallbackProtocol(Protocol):
def __call__(self, interface: ServiceInterface, *args: Any) -> Any: ...
class _Method:
def __init__(self, fn, name: str, disabled=False):
def __init__(
self, fn: _MethodCallbackProtocol, name: str, disabled: bool = False
) -> None:
in_signature = ""
out_signature = ""
inspection = inspect.signature(fn)
in_args = []
in_args: list[intr.Arg] = []
for i, param in enumerate(inspection.parameters.values()):
if i == 0:
# first is self
@ -45,7 +57,7 @@ class _Method:
in_args.append(intr.Arg(annotation, intr.ArgDirection.IN, param.name))
in_signature += annotation
out_args = []
out_args: list[intr.Arg] = []
out_signature = parse_annotation(inspection.return_annotation)
if out_signature:
for type_ in get_signature_tree(out_signature).types:
@ -61,7 +73,7 @@ class _Method:
self.out_signature_tree = get_signature_tree(out_signature)
def method(name: Optional[str] = None, disabled: bool = False):
def method(name: str | None = None, disabled: bool = False) -> Callable:
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus service method.
The parameters and return value must each be annotated with a signature
@ -99,9 +111,9 @@ def method(name: Optional[str] = None, disabled: bool = False):
if type(disabled) is not bool:
raise TypeError("disabled must be a bool")
def decorator(fn):
def decorator(fn: Callable) -> Callable:
@wraps(fn)
def wrapped(*args, **kwargs):
def wrapped(*args: Any, **kwargs: Any) -> None:
fn(*args, **kwargs)
fn_name = name if name else fn.__name__
@ -113,7 +125,7 @@ def method(name: Optional[str] = None, disabled: bool = False):
class _Signal:
def __init__(self, fn, name, disabled=False):
def __init__(self, fn: Callable, name: str, disabled: bool = False) -> None:
inspection = inspect.signature(fn)
args = []
@ -138,7 +150,7 @@ class _Signal:
self.introspection = intr.Signal(self.name, args)
def signal(name: Optional[str] = None, disabled: bool = False):
def signal(name: str | None = None, disabled: bool = False) -> Callable:
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus signal.
The signal is broadcast on the bus when the decorated class method is
@ -173,12 +185,12 @@ def signal(name: Optional[str] = None, disabled: bool = False):
if type(disabled) is not bool:
raise TypeError("disabled must be a bool")
def decorator(fn):
def decorator(fn: Callable) -> Callable:
fn_name = name if name else fn.__name__
signal = _Signal(fn, fn_name, disabled)
@wraps(fn)
def wrapped(self, *args, **kwargs):
def wrapped(self, *args: Any, **kwargs: Any) -> Any:
if signal.disabled:
raise SignalDisabledError("Tried to call a disabled signal")
result = fn(self, *args, **kwargs)
@ -259,9 +271,9 @@ class _Property(property):
def dbus_property(
access: PropertyAccess = PropertyAccess.READWRITE,
name: Optional[str] = None,
name: str | None = None,
disabled: bool = False,
):
) -> Callable:
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus property.
The class method must be a Python getter method with a return annotation
@ -306,7 +318,7 @@ def dbus_property(
if type(disabled) is not bool:
raise TypeError("disabled must be a bool")
def decorator(fn):
def decorator(fn: Callable) -> _Property:
options = {"name": name, "access": access, "disabled": disabled}
return _Property(fn, options=options)
@ -314,7 +326,7 @@ def dbus_property(
def _real_fn_result_to_body(
result: Optional[Any],
result: Any | None,
signature_tree: SignatureTree,
replace_fds: bool,
) -> tuple[list[Any], list[int]]:
@ -334,7 +346,8 @@ def _real_fn_result_to_body(
if out_len != len(final_result):
raise SignatureBodyMismatchError(
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}"
f"Signature and function return mismatch, expected "
f"{len(signature_tree.types)} arguments but got {len(result)}" # type: ignore[arg-type]
)
if not replace_fds:
@ -365,12 +378,12 @@ class ServiceInterface:
self.__methods: list[_Method] = []
self.__properties: list[_Property] = []
self.__signals: list[_Signal] = []
self.__buses = set()
self.__handlers: dict[
BaseMessageBus,
dict[_Method, Callable[[Message, Callable[[Message], None]], None]],
self.__buses: set[BaseMessageBus] = set()
self.__handlers: dict[BaseMessageBus, dict[_Method, HandlerType]] = {}
# Map of methods by bus of name -> method, handler
self.__handlers_by_name_signature: dict[
BaseMessageBus, dict[str, tuple[_Method, HandlerType]]
] = {}
for name, member in inspect.getmembers(type(self)):
member_dict = getattr(member, "__dict__", {})
if type(member) is _Property:
@ -405,7 +418,7 @@ class ServiceInterface:
def emit_properties_changed(
self, changed_properties: dict[str, Any], invalidated_properties: list[str] = []
):
) -> None:
"""Emit the ``org.freedesktop.DBus.Properties.PropertiesChanged`` signal.
This signal is intended to be used to alert clients when a property of
@ -464,58 +477,59 @@ class ServiceInterface:
)
@staticmethod
def _get_properties(interface: "ServiceInterface") -> list[_Property]:
def _get_properties(interface: ServiceInterface) -> list[_Property]:
return interface.__properties
@staticmethod
def _get_methods(interface: "ServiceInterface") -> list[_Method]:
def _get_methods(interface: ServiceInterface) -> list[_Method]:
return interface.__methods
@staticmethod
def _c_get_methods(interface: "ServiceInterface") -> list[_Method]:
# _c_get_methods is used by the C code to get the methods for an
# interface
# https://github.com/cython/cython/issues/3327
return interface.__methods
@staticmethod
def _get_signals(interface: "ServiceInterface") -> list[_Signal]:
def _get_signals(interface: ServiceInterface) -> list[_Signal]:
return interface.__signals
@staticmethod
def _get_buses(interface: "ServiceInterface") -> set["BaseMessageBus"]:
def _get_buses(interface: ServiceInterface) -> set[BaseMessageBus]:
return interface.__buses
@staticmethod
def _get_handler(
interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus"
) -> Callable[[Message, Callable[[Message], None]], None]:
interface: ServiceInterface, method: _Method, bus: BaseMessageBus
) -> HandlerType:
return interface.__handlers[bus][method]
@staticmethod
def _c_get_handler(
interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus"
) -> Callable[[Message, Callable[[Message], None]], None]:
# _c_get_handler is used by the C code to get the handler for a method
# https://github.com/cython/cython/issues/3327
return interface.__handlers[bus][method]
def _get_enabled_handler_by_name_signature(
interface: ServiceInterface,
bus: BaseMessageBus,
name: str_,
signature: str_,
) -> HandlerType | None:
handlers = interface.__handlers_by_name_signature[bus]
if (method_handler := handlers.get(name)) is None:
return None
method = method_handler[0]
if method.disabled:
return None
return method_handler[1] if method.in_signature == signature else None
@staticmethod
def _add_bus(
interface: "ServiceInterface",
bus: "BaseMessageBus",
maker: Callable[
["ServiceInterface", _Method],
Callable[[Message, Callable[[Message], None]], None],
],
interface: ServiceInterface,
bus: BaseMessageBus,
maker: Callable[[ServiceInterface, _Method], HandlerType],
) -> None:
interface.__buses.add(bus)
interface.__handlers[bus] = {
method: maker(interface, method) for method in interface.__methods
}
interface.__handlers_by_name_signature[bus] = {
method.name: (method, handler)
for method, handler in interface.__handlers[bus].items()
}
@staticmethod
def _remove_bus(interface: "ServiceInterface", bus: "BaseMessageBus") -> None:
def _remove_bus(interface: ServiceInterface, bus: BaseMessageBus) -> None:
interface.__buses.remove(bus)
del interface.__handlers[bus]
@ -538,7 +552,7 @@ class ServiceInterface:
@staticmethod
def _fn_result_to_body(
result: Optional[Any],
result: Any | None,
signature_tree: SignatureTree,
replace_fds: bool = True,
) -> tuple[list[Any], list[int]]:
@ -546,7 +560,7 @@ class ServiceInterface:
@staticmethod
def _c_fn_result_to_body(
result: Optional[Any],
result: Any | None,
signature_tree: SignatureTree,
replace_fds: bool,
) -> tuple[list[Any], list[int]]:
@ -558,7 +572,7 @@ class ServiceInterface:
@staticmethod
def _handle_signal(
interface: "ServiceInterface", signal: _Signal, result: Optional[Any]
interface: ServiceInterface, signal: _Signal, result: Any | None
) -> None:
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
for bus in ServiceInterface._get_buses(interface):
@ -567,15 +581,19 @@ class ServiceInterface:
)
@staticmethod
def _get_property_value(interface: "ServiceInterface", prop: _Property, callback):
def _get_property_value(
interface: ServiceInterface,
prop: _Property,
callback: Callable[[ServiceInterface, _Property, Any, Exception | None], None],
) -> None:
# XXX MUST CHECK TYPE RETURNED BY GETTER
try:
if asyncio.iscoroutinefunction(prop.prop_getter):
task = asyncio.ensure_future(prop.prop_getter(interface))
task: asyncio.Task = asyncio.ensure_future(prop.prop_getter(interface))
def get_property_callback(task):
def get_property_callback(task_: asyncio.Task) -> None:
try:
result = task.result()
result = task_.result()
except Exception as e:
callback(interface, prop, None, e)
return
@ -592,15 +610,22 @@ class ServiceInterface:
callback(interface, prop, None, e)
@staticmethod
def _set_property_value(interface: "ServiceInterface", prop, value, callback):
def _set_property_value(
interface: ServiceInterface,
prop: _Property,
value: Any,
callback: Callable[[ServiceInterface, _Property, Exception | None], None],
) -> None:
# XXX MUST CHECK TYPE TO SET
try:
if asyncio.iscoroutinefunction(prop.prop_setter):
task = asyncio.ensure_future(prop.prop_setter(interface, value))
task: asyncio.Task = asyncio.ensure_future(
prop.prop_setter(interface, value)
)
def set_property_callback(task):
def set_property_callback(task_: asyncio.Task) -> None:
try:
task.result()
task_.result()
except Exception as e:
callback(interface, prop, e)
return
@ -617,9 +642,11 @@ class ServiceInterface:
@staticmethod
def _get_all_property_values(
interface: "ServiceInterface", callback, user_data=None
):
result = {}
interface: ServiceInterface,
callback: Callable[[ServiceInterface, Any, Any, Exception | None], None],
user_data: Any | None = None,
) -> None:
result: dict[str, Variant | None] = {}
result_error = None
for prop in ServiceInterface._get_properties(interface):
@ -632,10 +659,10 @@ class ServiceInterface:
return
def get_property_callback(
interface: "ServiceInterface",
interface: ServiceInterface,
prop: _Property,
value: Any,
e: Optional[Exception],
e: Exception | None,
) -> None:
nonlocal result_error
if e is not None:

View File

@ -28,9 +28,14 @@ async def test_export_unexport():
bus = await MessageBus().connect()
bus.export(export_path, interface)
with pytest.raises(ValueError):
# Already exported
bus.export(export_path, interface)
assert export_path in bus._path_exports
assert len(bus._path_exports[export_path]) == 1
assert bus._path_exports[export_path][0] is interface
assert bus._path_exports[export_path][interface.name] is interface
assert len(ServiceInterface._get_buses(interface)) == 1
bus.export(export_path2, interface2)
@ -60,11 +65,23 @@ async def test_export_unexport():
assert not bus._path_exports
assert not ServiceInterface._get_buses(interface)
# test unexporting by ServiceInterface
bus.export(export_path, interface)
bus.unexport(export_path, interface)
assert not bus._path_exports
assert not ServiceInterface._get_buses(interface)
with pytest.raises(TypeError):
bus.unexport(export_path, object())
node = bus._introspect_export_path("/path/doesnt/exist")
assert type(node) is intr.Node
assert not node.interfaces
assert not node.nodes
# Should to nothing
bus.unexport("/path/doesnt/exist", interface)
bus.disconnect()

View File

@ -149,6 +149,19 @@ async def test_peer_interface():
assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0]
assert reply.signature == "s"
reply2 = await bus2.call(
Message(
destination=bus1.unique_name,
path="/path/doesnt/exist",
interface="org.freedesktop.DBus.Peer",
member="GetMachineId",
signature="",
)
)
assert reply2.message_type == MessageType.METHOD_RETURN, reply.body[0]
assert reply2.signature == "s"
bus1.disconnect()
bus2.disconnect()
@ -213,9 +226,9 @@ async def test_object_manager():
)
)
assert reply_root.signature == "a{oa{sa{sv}}}"
assert reply_level1.signature == "a{oa{sa{sv}}}"
assert reply_level2.signature == "a{oa{sa{sv}}}"
assert reply_root.signature == "a{oa{sa{sv}}}", reply_root
assert reply_level1.signature == "a{oa{sa{sv}}}", reply_level1
assert reply_level2.signature == "a{oa{sa{sv}}}", reply_level2
assert reply_level2.body == [{}]
assert reply_level1.body == [expected_reply]