feat: refactor service bus handler lookup to avoid linear searches (#400)

This commit is contained in:
J. Nick Koston
2025-03-05 13:28:00 -10:00
committed by GitHub
parent 640e1f8d87
commit 996659e1b5
9 changed files with 315 additions and 262 deletions

View File

@@ -37,7 +37,7 @@ class MessageFlag(IntFlag):
ALLOW_INTERACTIVE_AUTHORIZATION = 4 ALLOW_INTERACTIVE_AUTHORIZATION = 4
@cached_property @cached_property
def value(self) -> str: def value(self) -> int:
"""Return the value.""" """Return the value."""
return self._value_ return self._value_

View File

@@ -51,7 +51,7 @@ class Arg:
def __init__( def __init__(
self, self,
signature: Union[SignatureType, str], signature: Union[SignatureType, str],
direction: Optional[list[ArgDirection]] = None, direction: Optional[ArgDirection] = None,
name: Optional[str] = None, name: Optional[str] = None,
annotations: Optional[dict[str, str]] = None, annotations: Optional[dict[str, str]] = None,
): ):

View File

@@ -4,6 +4,7 @@ from ._private.address cimport get_bus_address, parse_address
from .message cimport Message from .message cimport Message
from .service cimport ServiceInterface, _Method from .service cimport ServiceInterface, _Method
cdef bint TYPE_CHECKING
cdef object MessageType cdef object MessageType
cdef object DBusError cdef object DBusError
@@ -39,24 +40,26 @@ cdef class BaseMessageBus:
cdef public object _high_level_client_initialized cdef public object _high_level_client_initialized
cdef public object _ProxyObject cdef public object _ProxyObject
cdef public object _machine_id cdef public object _machine_id
cdef public object _negotiate_unix_fd cdef public bint _negotiate_unix_fd
cdef public object _sock cdef public object _sock
cdef public object _stream cdef public object _stream
cdef public object _fd cdef public object _fd
cpdef _process_message(self, Message msg) cpdef void _process_message(self, Message msg)
@cython.locals(exported_service_interface=ServiceInterface)
cpdef export(self, str path, ServiceInterface interface)
@cython.locals( @cython.locals(
methods=cython.list, methods=cython.list,
method=_Method, method=_Method,
interface=ServiceInterface, interface=ServiceInterface,
interfaces=cython.list, interfaces=dict,
) )
cdef _find_message_handler(self, Message msg) cdef _find_message_handler(self, Message msg)
cdef _setup_socket(self) cdef _setup_socket(self)
@cython.locals(no_reply_expected=bint)
cpdef _call(self, Message msg, object callback) cpdef _call(self, Message msg, object callback)
cpdef next_serial(self) cpdef next_serial(self)

View File

@@ -1,10 +1,11 @@
from __future__ import annotations
import inspect import inspect
import logging import logging
import socket import socket
import traceback import traceback
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from functools import partial from functools import partial
from typing import Any, Callable, Optional, Union from typing import Any, Callable, TYPE_CHECKING
from . import introspection as intr from . import introspection as intr
from ._private.address import get_bus_address, parse_address from ._private.address import get_bus_address, parse_address
@@ -22,7 +23,7 @@ from .errors import DBusError, InvalidAddressError
from .message import Message from .message import Message
from .proxy_object import BaseProxyObject from .proxy_object import BaseProxyObject
from .send_reply import SendReply from .send_reply import SendReply
from .service import ServiceInterface, _Method from .service import ServiceInterface, _Method, _Property, HandlerType
from .signature import Variant from .signature import Variant
from .validators import assert_bus_name_valid, assert_object_path_valid from .validators import assert_bus_name_valid, assert_object_path_valid
@@ -119,12 +120,12 @@ class BaseMessageBus:
def __init__( def __init__(
self, self,
bus_address: Optional[str] = None, bus_address: str | None = None,
bus_type: BusType = BusType.SESSION, bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[type[BaseProxyObject]] = None, ProxyObject: type[BaseProxyObject] | None = None,
negotiate_unix_fd: bool = False, negotiate_unix_fd: bool = False,
) -> None: ) -> None:
self.unique_name: Optional[str] = None self.unique_name: str | None = None
self._disconnected = False self._disconnected = False
self._negotiate_unix_fd = negotiate_unix_fd self._negotiate_unix_fd = negotiate_unix_fd
@@ -133,11 +134,11 @@ class BaseMessageBus:
self._user_disconnect = False self._user_disconnect = False
self._method_return_handlers: dict[ self._method_return_handlers: dict[
int, Callable[[Optional[Message], Optional[Exception]], None] int, Callable[[Message | None, Exception | None], None]
] = {} ] = {}
self._serial = 0 self._serial = 0
self._user_message_handlers: list[ self._user_message_handlers: list[
Callable[[Message], Union[Message, bool, None]] Callable[[Message], Message | bool | None]
] = [] ] = []
# the key is the name and the value is the unique name of the owner. # the key is the name and the value is the unique name of the owner.
# This cache is kept up to date by the NameOwnerChanged signal and is # This cache is kept up to date by the NameOwnerChanged signal and is
@@ -145,7 +146,7 @@ class BaseMessageBus:
# high level client only) # high level client only)
self._name_owners: dict[str, str] = {} self._name_owners: dict[str, str] = {}
# used for the high level service # used for the high level service
self._path_exports: dict[str, list[ServiceInterface]] = {} self._path_exports: dict[str, dict[str, ServiceInterface]] = {}
self._bus_address = ( self._bus_address = (
parse_address(bus_address) parse_address(bus_address)
if bus_address if bus_address
@@ -161,10 +162,10 @@ class BaseMessageBus:
self._ProxyObject = ProxyObject self._ProxyObject = ProxyObject
# machine id is lazy loaded # machine id is lazy loaded
self._machine_id: Optional[int] = None self._machine_id: int | None = None
self._sock: Optional[socket.socket] = None self._sock: socket.socket | None = None
self._fd: Optional[int] = None self._fd: int | None = None
self._stream: Optional[Any] = None self._stream: Any | None = None
self._setup_socket() self._setup_socket()
@@ -193,20 +194,18 @@ class BaseMessageBus:
raise TypeError("interface must be a ServiceInterface") raise TypeError("interface must be a ServiceInterface")
if path not in self._path_exports: if path not in self._path_exports:
self._path_exports[path] = [] self._path_exports[path] = {}
elif interface.name in self._path_exports[path]:
raise ValueError(
f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"'
)
for f in self._path_exports[path]: self._path_exports[path][interface.name] = interface
if f.name == interface.name:
raise ValueError(
f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"'
)
self._path_exports[path].append(interface)
ServiceInterface._add_bus(interface, self, self._make_method_handler) ServiceInterface._add_bus(interface, self, self._make_method_handler)
self._emit_interface_added(path, interface) self._emit_interface_added(path, interface)
def unexport( def unexport(
self, path: str, interface: Optional[Union[ServiceInterface, str]] = None self, path: str, interface: ServiceInterface | str | None = None
) -> None: ) -> None:
"""Unexport the path or service interface to make it no longer """Unexport the path or service interface to make it no longer
available to clients. available to clients.
@@ -222,45 +221,42 @@ class BaseMessageBus:
- :class:`InvalidObjectPathError <dbus_fast.InvalidObjectPathError>` - If the given object path is not valid. - :class:`InvalidObjectPathError <dbus_fast.InvalidObjectPathError>` - If the given object path is not valid.
""" """
assert_object_path_valid(path) assert_object_path_valid(path)
if type(interface) not in [str, type(None)] and not isinstance( interface_name: str | None
interface, ServiceInterface
):
raise TypeError("interface must be a ServiceInterface or interface name")
if path not in self._path_exports:
return
exports = self._path_exports[path]
if type(interface) is str:
try:
interface = next(iface for iface in exports if iface.name == interface)
except StopIteration:
return
removed_interfaces = []
if interface is None: if interface is None:
del self._path_exports[path] interface_name = None
for iface in filter(lambda e: not self._has_interface(e), exports): elif type(interface) is str:
removed_interfaces.append(iface.name) interface_name = interface
ServiceInterface._remove_bus(iface, self) elif isinstance(interface, ServiceInterface):
interface_name = interface.name
else: else:
for i, iface in enumerate(exports): raise TypeError(
if iface is interface: f"interface must be a ServiceInterface or interface name not {type(interface)}"
removed_interfaces.append(iface.name) )
del self._path_exports[path][i]
if not self._path_exports[path]: if (interfaces := self._path_exports.get(path)) is None:
del self._path_exports[path] return
if not self._has_interface(iface): removed_interface_names: list[str] = []
ServiceInterface._remove_bus(iface, self)
break if interface_name is not None:
self._emit_interface_removed(path, removed_interfaces) if (removed_interface := interfaces.pop(interface_name, None)) is None:
return
removed_interface_names.append(interface_name)
if not interfaces:
del self._path_exports[path]
ServiceInterface._remove_bus(removed_interface, self)
else:
del self._path_exports[path]
for removed_interface in interfaces.values():
removed_interface_names.append(removed_interface.name)
ServiceInterface._remove_bus(removed_interface, self)
self._emit_interface_removed(path, removed_interface_names)
def introspect( def introspect(
self, self,
bus_name: str, bus_name: str,
path: str, path: str,
callback: Callable[[Optional[intr.Node], Optional[Exception]], None], callback: Callable[[intr.Node | None, Exception | None], None],
check_callback_type: bool = True, check_callback_type: bool = True,
validate_property_names: bool = True, validate_property_names: bool = True,
) -> None: ) -> None:
@@ -289,12 +285,13 @@ class BaseMessageBus:
if check_callback_type: if check_callback_type:
BaseMessageBus._check_callback_type(callback) BaseMessageBus._check_callback_type(callback)
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: def reply_notify(reply: Message | None, err: Exception | None) -> None:
try: try:
BaseMessageBus._check_method_return(reply, err, "s") BaseMessageBus._check_method_return(reply, err, "s")
result = intr.Node.parse( result = intr.Node.parse(
reply.body[0], validate_property_names=validate_property_names reply.body[0], # type: ignore[union-attr]
) # type: ignore[union-attr] validate_property_names=validate_property_names,
)
except Exception as e: except Exception as e:
callback(None, e) callback(None, e)
return return
@@ -330,7 +327,7 @@ class BaseMessageBus:
interface: ServiceInterface, interface: ServiceInterface,
result: Any, result: Any,
user_data: Any, user_data: Any,
e: Optional[Exception], e: Exception | None,
) -> None: ) -> None:
if e is not None: if e is not None:
try: try:
@@ -385,9 +382,8 @@ class BaseMessageBus:
self, self,
name: str, name: str,
flags: NameFlag = NameFlag.NONE, flags: NameFlag = NameFlag.NONE,
callback: Optional[ callback: None
Callable[[Optional[RequestNameReply], Optional[Exception]], None] | (Callable[[RequestNameReply | None, Exception | None], None]) = None,
] = None,
check_callback_type: bool = True, check_callback_type: bool = True,
) -> None: ) -> None:
"""Request that this message bus owns the given name. """Request that this message bus owns the given name.
@@ -424,24 +420,23 @@ class BaseMessageBus:
self._call(message, None) self._call(message, None)
return return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: def reply_notify(reply: Message | None, err: Exception | None) -> None:
try: try:
BaseMessageBus._check_method_return(reply, err, "u") BaseMessageBus._check_method_return(reply, err, "u")
result = RequestNameReply(reply.body[0]) # type: ignore[union-attr] result = RequestNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e: except Exception as e:
callback(None, e) # type: ignore[misc] callback(None, e)
return return
callback(result, None) # type: ignore[misc] callback(result, None)
self._call(message, reply_notify) self._call(message, reply_notify)
def release_name( def release_name(
self, self,
name: str, name: str,
callback: Optional[ callback: None
Callable[[Optional[ReleaseNameReply], Optional[Exception]], None] | (Callable[[ReleaseNameReply | None, Exception | None], None]) = None,
] = None,
check_callback_type: bool = True, check_callback_type: bool = True,
) -> None: ) -> None:
"""Request that this message bus release the given name. """Request that this message bus release the given name.
@@ -474,20 +469,20 @@ class BaseMessageBus:
self._call(message, None) self._call(message, None)
return return
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: def reply_notify(reply: Message | None, err: Exception | None) -> None:
try: try:
BaseMessageBus._check_method_return(reply, err, "u") BaseMessageBus._check_method_return(reply, err, "u")
result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr] result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr]
except Exception as e: except Exception as e:
callback(None, e) # type: ignore[misc] callback(None, e)
return return
callback(result, None) # type: ignore[misc] callback(result, None)
self._call(message, reply_notify) self._call(message, reply_notify)
def get_proxy_object( def get_proxy_object(
self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element] self, bus_name: str, path: str, introspection: intr.Node | str | ET.Element
) -> BaseProxyObject: ) -> BaseProxyObject:
"""Get a proxy object for the path exported on the bus that owns the """Get a proxy object for the path exported on the bus that owns the
name. The object is expected to export the interfaces and nodes name. The object is expected to export the interfaces and nodes
@@ -526,10 +521,11 @@ class BaseMessageBus:
All pending and future calls will error with a connection error. All pending and future calls will error with a connection error.
""" """
self._user_disconnect = True self._user_disconnect = True
try: if self._sock:
self._sock.shutdown(socket.SHUT_RDWR) try:
except Exception: self._sock.shutdown(socket.SHUT_RDWR)
logging.warning("could not shut down socket", exc_info=True) except Exception:
logging.warning("could not shut down socket", exc_info=True)
def next_serial(self) -> int: def next_serial(self) -> int:
"""Get the next serial for this bus. This can be used as the ``serial`` """Get the next serial for this bus. This can be used as the ``serial``
@@ -543,7 +539,7 @@ class BaseMessageBus:
return self._serial return self._serial
def add_message_handler( def add_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]] self, handler: Callable[[Message], Message | bool | None]
) -> None: ) -> None:
"""Add a custom message handler for incoming messages. """Add a custom message handler for incoming messages.
@@ -568,7 +564,7 @@ class BaseMessageBus:
self._user_message_handlers.append(handler) self._user_message_handlers.append(handler)
def remove_message_handler( def remove_message_handler(
self, handler: Callable[[Message], Optional[Union[Message, bool]]] self, handler: Callable[[Message], Message | bool | None]
) -> None: ) -> None:
"""Remove a message handler that was previously added by """Remove a message handler that was previously added by
:func:`add_message_handler() :func:`add_message_handler()
@@ -592,7 +588,7 @@ class BaseMessageBus:
'the "send" method must be implemented in the inheriting class' 'the "send" method must be implemented in the inheriting class'
) )
def _finalize(self, err: Optional[Exception]) -> None: def _finalize(self, err: Exception | None) -> None:
"""should be called after the socket disconnects with the disconnection """should be called after the socket disconnects with the disconnection
error to clean up resources and put the bus in a disconnected state""" error to clean up resources and put the bus in a disconnected state"""
if self._disconnected: if self._disconnected:
@@ -615,14 +611,6 @@ class BaseMessageBus:
self._user_message_handlers.clear() self._user_message_handlers.clear()
def _has_interface(self, interface: ServiceInterface) -> bool:
for _, exports in self._path_exports.items():
for iface in exports:
if iface is interface:
return True
return False
def _interface_signal_notify( def _interface_signal_notify(
self, self,
interface: ServiceInterface, interface: ServiceInterface,
@@ -632,9 +620,9 @@ class BaseMessageBus:
body: list[Any], body: list[Any],
unix_fds: list[int] = [], unix_fds: list[int] = [],
) -> None: ) -> None:
path = None path: str | None = None
for p, ifaces in self._path_exports.items(): for p, ifaces in self._path_exports.items():
for i in ifaces: for i in ifaces.values():
if i is interface: if i is interface:
path = p path = p
@@ -657,9 +645,9 @@ class BaseMessageBus:
def _introspect_export_path(self, path: str) -> intr.Node: def _introspect_export_path(self, path: str) -> intr.Node:
assert_object_path_valid(path) assert_object_path_valid(path)
if path in self._path_exports: if (interfaces := self._path_exports.get(path)) is not None:
node = intr.Node.default(path) node = intr.Node.default(path)
for interface in self._path_exports[path]: for interface in interfaces.values():
node.interfaces.append(interface.introspect()) node.interfaces.append(interface.introspect())
else: else:
node = intr.Node(path) node = intr.Node(path)
@@ -687,7 +675,7 @@ class BaseMessageBus:
err = None err = None
for transport, options in self._bus_address: for transport, options in self._bus_address:
filename = None filename: bytes | str | None = None
ip_addr = "" ip_addr = ""
ip_port = 0 ip_port = 0
@@ -738,9 +726,9 @@ class BaseMessageBus:
def _reply_notify( def _reply_notify(
self, self,
msg: Message, msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]], callback: Callable[[Message | None, Exception | None], None],
reply: Optional[Message], reply: Message | None,
err: Optional[Exception], err: Exception | None,
) -> None: ) -> None:
"""Callback on reply.""" """Callback on reply."""
if reply and msg.destination and reply.sender: if reply and msg.destination and reply.sender:
@@ -750,24 +738,23 @@ class BaseMessageBus:
def _call( def _call(
self, self,
msg: Message, msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]], callback: Callable[[Message | None, Exception | None], None] | None,
) -> None: ) -> None:
if not msg.serial: if not msg.serial:
msg.serial = self.next_serial() msg.serial = self.next_serial()
reply_expected = _expects_reply(msg)
# Make sure the return reply handler is installed # Make sure the return reply handler is installed
# before sending the message to avoid a race condition # before sending the message to avoid a race condition
# where the reply is lost in case the backend can # where the reply is lost in case the backend can
# send it right away. # send it right away.
if reply_expected: if (reply_expected := _expects_reply(msg)) and callback is not None:
self._method_return_handlers[msg.serial] = partial( self._method_return_handlers[msg.serial] = partial(
self._reply_notify, msg, callback self._reply_notify, msg, callback
) )
self.send(msg) self.send(msg)
if not reply_expected: if not reply_expected and callback is not None:
callback(None, None) callback(None, None)
@staticmethod @staticmethod
@@ -785,7 +772,7 @@ class BaseMessageBus:
@staticmethod @staticmethod
def _check_method_return( def _check_method_return(
msg: Optional[Message], err: Optional[Exception], signature: str msg: Message | None, err: Exception | None, signature: str
) -> None: ) -> None:
if err: if err:
raise err raise err
@@ -809,8 +796,7 @@ class BaseMessageBus:
handled = False handled = False
for user_handler in self._user_message_handlers: for user_handler in self._user_message_handlers:
try: try:
result = user_handler(msg) if result := user_handler(msg):
if result:
if type(result) is Message: if type(result) is Message:
self.send(result) self.send(result)
handled = True handled = True
@@ -842,8 +828,8 @@ class BaseMessageBus:
and msg.path == "/org/freedesktop/DBus" and msg.path == "/org/freedesktop/DBus"
and msg.interface == "org.freedesktop.DBus" and msg.interface == "org.freedesktop.DBus"
): ):
[name, old_owner, new_owner] = msg.body name = msg.body[0]
if new_owner: if new_owner := msg.body[2]:
self._name_owners[name] = new_owner self._name_owners[name] = new_owner
elif name in self._name_owners: elif name in self._name_owners:
del self._name_owners[name] del self._name_owners[name]
@@ -852,9 +838,9 @@ class BaseMessageBus:
if msg.message_type is MESSAGE_TYPE_CALL: if msg.message_type is MESSAGE_TYPE_CALL:
if not handled: if not handled:
handler = self._find_message_handler(msg) handler = self._find_message_handler(msg)
if _expects_reply(msg) is False: if not _expects_reply(msg):
if handler: if handler:
handler(msg, BLOCK_UNEXPECTED_REPLY) handler(msg, BLOCK_UNEXPECTED_REPLY) # type: ignore[arg-type]
else: else:
_LOGGER.error( _LOGGER.error(
'"%s.%s" with signature "%s" could not be found', '"%s.%s" with signature "%s" could not be found',
@@ -891,17 +877,17 @@ class BaseMessageBus:
interface: ServiceInterface, interface: ServiceInterface,
method: _Method, method: _Method,
msg: Message, msg: Message,
send_reply: Callable[[Message], None], send_reply: SendReply,
) -> None: ) -> None:
"""This is the callback that will be called when a method call is.""" """This is the callback that will be called when a method call is."""
args = ServiceInterface._c_msg_body_to_args(msg) if msg.unix_fds else msg.body args = ServiceInterface._c_msg_body_to_args(msg) if msg.unix_fds else msg.body
result = method.fn(interface, *args) result = method.fn(interface, *args)
if send_reply is BLOCK_UNEXPECTED_REPLY or _expects_reply(msg) is False: if send_reply is BLOCK_UNEXPECTED_REPLY or _expects_reply(msg) is False:
return return
body, fds = ServiceInterface._c_fn_result_to_body( body_fds = ServiceInterface._c_fn_result_to_body(
result, result,
signature_tree=method.out_signature_tree, method.out_signature_tree,
replace_fds=self._negotiate_unix_fd, self._negotiate_unix_fd,
) )
send_reply( send_reply(
Message( Message(
@@ -909,19 +895,20 @@ class BaseMessageBus:
reply_serial=msg.serial, reply_serial=msg.serial,
destination=msg.sender, destination=msg.sender,
signature=method.out_signature, signature=method.out_signature,
body=body, body=body_fds[0],
unix_fds=fds, unix_fds=body_fds[1],
) )
) )
def _make_method_handler( def _make_method_handler(
self, interface: ServiceInterface, method: _Method self, interface: ServiceInterface, method: _Method
) -> Callable[[Message, Callable[[Message], None]], None]: ) -> HandlerType:
return partial(self._callback_method_handler, interface, method) return partial(self._callback_method_handler, interface, method)
def _find_message_handler( def _find_message_handler(self, msg: _Message) -> HandlerType | None:
self, msg: _Message if TYPE_CHECKING:
) -> Optional[Callable[[Message, Callable[[Message], None]], None]]: assert msg.interface is not None
if "org.freedesktop.DBus." in msg.interface: if "org.freedesktop.DBus." in msg.interface:
if ( if (
msg.interface == "org.freedesktop.DBus.Introspectable" msg.interface == "org.freedesktop.DBus.Introspectable"
@@ -945,54 +932,52 @@ class BaseMessageBus:
): ):
return self._default_get_managed_objects_handler return self._default_get_managed_objects_handler
msg_path = msg.path if (
if msg_path: msg.path is not None
interfaces = self._path_exports.get(msg_path) and msg.member is not None
if not interfaces: and (interfaces := self._path_exports.get(msg.path)) is not None
return None and (interface := interfaces.get(msg.interface)) is not None
for interface in interfaces: and (
methods = ServiceInterface._c_get_methods(interface) handler := ServiceInterface._get_enabled_handler_by_name_signature(
for method in methods: interface, self, msg.member, msg.signature
if method.disabled: )
continue )
is not None
if ( ):
msg.interface == interface.name return handler
and msg.member == method.name
and msg.signature == method.in_signature
):
return ServiceInterface._c_get_handler(interface, method, self)
return None return None
def _default_introspect_handler( def _default_introspect_handler(self, msg: Message, send_reply: SendReply) -> None:
self, msg: Message, send_reply: Callable[[Message], None] if TYPE_CHECKING:
) -> None: assert msg.path is not None
introspection = self._introspect_export_path(msg.path).tostring() introspection = self._introspect_export_path(msg.path).tostring()
send_reply(Message.new_method_return(msg, "s", [introspection])) send_reply(Message.new_method_return(msg, "s", [introspection]))
def _default_ping_handler( def _default_ping_handler(self, msg: Message, send_reply: SendReply) -> None:
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
send_reply(Message.new_method_return(msg)) send_reply(Message.new_method_return(msg))
def _send_machine_id_reply(self, msg: Message, send_reply: SendReply) -> None:
send_reply(Message.new_method_return(msg, "s", [self._machine_id]))
def _default_get_machine_id_handler( def _default_get_machine_id_handler(
self, msg: Message, send_reply: Callable[[Message], None] self, msg: Message, send_reply: SendReply
) -> None: ) -> None:
if self._machine_id: if self._machine_id:
send_reply(Message.new_method_return(msg, "s", self._machine_id)) self._send_machine_id_reply(msg, send_reply)
return return
def reply_handler(reply, err): def reply_handler(reply: Message | None, err: Exception | None) -> None:
if err: if err or reply is None:
# the bus has been disconnected, cannot send a reply # the bus has been disconnected, cannot send a reply
return return
if reply.message_type == MessageType.METHOD_RETURN: if reply.message_type == MessageType.METHOD_RETURN:
self._machine_id = reply.body[0] self._machine_id = reply.body[0]
send_reply(Message.new_method_return(msg, "s", [self._machine_id])) self._send_machine_id_reply(msg, send_reply)
elif reply.message_type == MessageType.ERROR: elif (
send_reply(Message.new_error(msg, reply.error_name, reply.body)) reply.message_type == MessageType.ERROR and reply.error_name is not None
):
send_reply(Message.new_error(msg, reply.error_name, str(reply.body)))
else: else:
send_reply( send_reply(
Message.new_error(msg, ErrorType.FAILED, "could not get machine_id") Message.new_error(msg, ErrorType.FAILED, "could not get machine_id")
@@ -1009,13 +994,12 @@ class BaseMessageBus:
) )
def _default_get_managed_objects_handler( def _default_get_managed_objects_handler(
self, msg: Message, send_reply: Callable[[Message], None] self, msg: Message, send_reply: SendReply
) -> None: ) -> None:
result = {}
result_signature = "a{oa{sa{sv}}}" result_signature = "a{oa{sa{sv}}}"
error_handled = False error_handled = False
def is_result_complete(): def is_result_complete() -> bool:
if not result: if not result:
return True return True
for n, interfaces in result.items(): for n, interfaces in result.items():
@@ -1025,6 +1009,9 @@ class BaseMessageBus:
return True return True
if TYPE_CHECKING:
assert msg.path is not None
nodes = [ nodes = [
node node
for node in self._path_exports for node in self._path_exports
@@ -1032,16 +1019,18 @@ class BaseMessageBus:
] ]
# first build up the result object to know when it's complete # first build up the result object to know when it's complete
for node in nodes: result: dict[str, dict[str, Any]] = {
result[node] = {} node: {interface: None for interface in self._path_exports[node]}
for interface in self._path_exports[node]: for node in nodes
result[node][interface.name] = None }
if is_result_complete(): if is_result_complete():
send_reply(Message.new_method_return(msg, result_signature, [result])) send_reply(Message.new_method_return(msg, result_signature, [result]))
return return
def get_all_properties_callback(interface, values, node, err): def get_all_properties_callback(
interface: ServiceInterface, values: Any, node: str, err: Exception | None
) -> None:
nonlocal error_handled nonlocal error_handled
if err is not None: if err is not None:
if not error_handled: if not error_handled:
@@ -1055,14 +1044,12 @@ class BaseMessageBus:
send_reply(Message.new_method_return(msg, result_signature, [result])) send_reply(Message.new_method_return(msg, result_signature, [result]))
for node in nodes: for node in nodes:
for interface in self._path_exports[node]: for interface in self._path_exports[node].values():
ServiceInterface._get_all_property_values( ServiceInterface._get_all_property_values(
interface, get_all_properties_callback, node interface, get_all_properties_callback, node
) )
def _default_properties_handler( def _default_properties_handler(self, msg: Message, send_reply: SendReply) -> None:
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"} methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"}
if msg.member not in methods or methods[msg.member] != msg.signature: if msg.member not in methods or methods[msg.member] != msg.signature:
raise DBusError( raise DBusError(
@@ -1082,12 +1069,7 @@ class BaseMessageBus:
ErrorType.UNKNOWN_OBJECT, f'no interfaces at path: "{msg.path}"' ErrorType.UNKNOWN_OBJECT, f'no interfaces at path: "{msg.path}"'
) )
match = [ if (interface := self._path_exports[msg.path].get(interface_name)) is None:
iface
for iface in self._path_exports[msg.path]
if iface.name == interface_name
]
if not match:
if interface_name in [ if interface_name in [
"org.freedesktop.DBus.Properties", "org.freedesktop.DBus.Properties",
"org.freedesktop.DBus.Introspectable", "org.freedesktop.DBus.Introspectable",
@@ -1111,7 +1093,6 @@ class BaseMessageBus:
f'could not find an interface "{interface_name}" at path: "{msg.path}"', f'could not find an interface "{interface_name}" at path: "{msg.path}"',
) )
interface = match[0]
properties = ServiceInterface._get_properties(interface) properties = ServiceInterface._get_properties(interface)
if msg.member == "Get" or msg.member == "Set": if msg.member == "Get" or msg.member == "Set":
@@ -1135,7 +1116,12 @@ class BaseMessageBus:
"the property does not have read access", "the property does not have read access",
) )
def get_property_callback(interface, prop, prop_value, err): def get_property_callback(
interface: ServiceInterface,
prop: _Property,
prop_value: Any,
err: Exception | None,
) -> None:
try: try:
if err is not None: if err is not None:
send_reply.send_error(err) send_reply.send_error(err)
@@ -1173,7 +1159,9 @@ class BaseMessageBus:
) )
assert prop.prop_setter assert prop.prop_setter
def set_property_callback(interface, prop, err): def set_property_callback(
interface: ServiceInterface, prop: _Property, err: Exception | None
) -> None:
if err is not None: if err is not None:
send_reply.send_error(err) send_reply.send_error(err)
return return
@@ -1188,7 +1176,12 @@ class BaseMessageBus:
elif msg.member == "GetAll": elif msg.member == "GetAll":
def get_all_properties_callback(interface, values, user_data, err): def get_all_properties_callback(
interface: ServiceInterface,
values: Any,
user_data: Any,
err: Exception | None,
) -> None:
if err is not None: if err is not None:
send_reply.send_error(err) send_reply.send_error(err)
return return
@@ -1212,12 +1205,12 @@ class BaseMessageBus:
return return
self._high_level_client_initialized = True self._high_level_client_initialized = True
def add_match_notify(msg, err): def add_match_notify(msg: Message | None, err: Exception | None) -> None:
if err: if err:
logging.error( logging.error(
f'add match request failed. match="{self._name_owner_match_rule}", {err}' f'add match request failed. match="{self._name_owner_match_rule}", {err}'
) )
elif msg.message_type == MessageType.ERROR: elif msg is not None and msg.message_type == MessageType.ERROR:
logging.error( logging.error(
f'add match request failed. match="{self._name_owner_match_rule}", {msg.body[0]}' f'add match request failed. match="{self._name_owner_match_rule}", {msg.body[0]}'
) )
@@ -1234,7 +1227,7 @@ class BaseMessageBus:
add_match_notify, add_match_notify,
) )
def _add_match_rule(self, match_rule): def _add_match_rule(self, match_rule: str) -> None:
"""Add a match rule. Match rules added by this function are refcounted """Add a match rule. Match rules added by this function are refcounted
and must be removed by _remove_match_rule(). This is for use in the and must be removed by _remove_match_rule(). This is for use in the
high level client only.""" high level client only."""
@@ -1247,10 +1240,10 @@ class BaseMessageBus:
self._match_rules[match_rule] = 1 self._match_rules[match_rule] = 1
def add_match_notify(msg: Message, err: Optional[Exception]) -> None: def add_match_notify(msg: Message | None, err: Exception | None) -> None:
if err: if err:
logging.error(f'add match request failed. match="{match_rule}", {err}') logging.error(f'add match request failed. match="{match_rule}", {err}')
elif msg.message_type == MessageType.ERROR: elif msg is not None and msg.message_type == MessageType.ERROR:
logging.error( logging.error(
f'add match request failed. match="{match_rule}", {msg.body[0]}' f'add match request failed. match="{match_rule}", {msg.body[0]}'
) )
@@ -1267,7 +1260,7 @@ class BaseMessageBus:
add_match_notify, add_match_notify,
) )
def _remove_match_rule(self, match_rule): def _remove_match_rule(self, match_rule: str) -> None:
"""Remove a match rule added with _add_match_rule(). This is for use in """Remove a match rule added with _add_match_rule(). This is for use in
the high level client only.""" the high level client only."""
if match_rule == self._name_owner_match_rule: if match_rule == self._name_owner_match_rule:
@@ -1280,7 +1273,7 @@ class BaseMessageBus:
del self._match_rules[match_rule] del self._match_rules[match_rule]
def remove_match_notify(msg, err): def remove_match_notify(msg: Message | None, err: Exception | None) -> None:
if self._disconnected: if self._disconnected:
return return
@@ -1288,7 +1281,7 @@ class BaseMessageBus:
logging.error( logging.error(
f'remove match request failed. match="{match_rule}", {err}' f'remove match request failed. match="{match_rule}", {err}'
) )
elif msg.message_type == MessageType.ERROR: elif msg is not None and msg.message_type == MessageType.ERROR:
logging.error( logging.error(
f'remove match request failed. match="{match_rule}", {msg.body[0]}' f'remove match request failed. match="{match_rule}", {msg.body[0]}'
) )

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import traceback import traceback
from types import TracebackType from types import TracebackType
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING
from .constants import ErrorType from .constants import ErrorType
from .errors import DBusError from .errors import DBusError
@@ -15,12 +16,12 @@ class SendReply:
__slots__ = ("_bus", "_msg") __slots__ = ("_bus", "_msg")
def __init__(self, bus: "BaseMessageBus", msg: Message) -> None: def __init__(self, bus: BaseMessageBus, msg: Message) -> None:
"""Create a new reply context manager.""" """Create a new reply context manager."""
self._bus = bus self._bus = bus
self._msg = msg self._msg = msg
def __enter__(self): def __enter__(self) -> SendReply:
return self return self
def __call__(self, reply: Message) -> None: def __call__(self, reply: Message) -> None:
@@ -28,9 +29,9 @@ class SendReply:
def _exit( def _exit(
self, self,
exc_type: Optional[type[Exception]], exc_type: type[Exception] | None,
exc_value: Optional[Exception], exc_value: Exception | None,
tb: Optional[TracebackType], tb: TracebackType | None,
) -> bool: ) -> bool:
if exc_value: if exc_value:
if isinstance(exc_value, DBusError): if isinstance(exc_value, DBusError):
@@ -49,9 +50,9 @@ class SendReply:
def __exit__( def __exit__(
self, self,
exc_type: Optional[type[Exception]], exc_type: type[Exception] | None,
exc_value: Optional[Exception], exc_value: Exception | None,
tb: Optional[TracebackType], tb: TracebackType | None,
) -> bool: ) -> bool:
return self._exit(exc_type, exc_value, tb) return self._exit(exc_type, exc_value, tb)

View File

@@ -33,12 +33,11 @@ cdef class ServiceInterface:
cdef list __signals cdef list __signals
cdef set __buses cdef set __buses
cdef dict __handlers cdef dict __handlers
cdef dict __enabled_handlers_by_name_signature
@cython.locals(handlers=dict,in_signature=str,method=_Method)
@staticmethod @staticmethod
cdef list _c_get_methods(ServiceInterface interface) cdef object _get_enabled_handler_by_name_signature(ServiceInterface interface, object bus, object name, object signature)
@staticmethod
cdef object _c_get_handler(ServiceInterface interface, _Method method, object bus)
@staticmethod @staticmethod
cdef list _c_msg_body_to_args(Message msg) cdef list _c_msg_body_to_args(Message msg)

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import asyncio import asyncio
import copy import copy
import inspect import inspect
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Protocol
from . import introspection as intr from . import introspection as intr
from ._private.util import ( from ._private.util import (
@@ -20,19 +21,30 @@ from .signature import (
Variant, Variant,
get_signature_tree, get_signature_tree,
) )
from .send_reply import SendReply
if TYPE_CHECKING: if TYPE_CHECKING:
from .message_bus import BaseMessageBus from .message_bus import BaseMessageBus
str_ = str
HandlerType = Callable[[Message, SendReply], None]
class _MethodCallbackProtocol(Protocol):
def __call__(self, interface: ServiceInterface, *args: Any) -> Any: ...
class _Method: class _Method:
def __init__(self, fn, name: str, disabled=False): def __init__(
self, fn: _MethodCallbackProtocol, name: str, disabled: bool = False
) -> None:
in_signature = "" in_signature = ""
out_signature = "" out_signature = ""
inspection = inspect.signature(fn) inspection = inspect.signature(fn)
in_args = [] in_args: list[intr.Arg] = []
for i, param in enumerate(inspection.parameters.values()): for i, param in enumerate(inspection.parameters.values()):
if i == 0: if i == 0:
# first is self # first is self
@@ -45,7 +57,7 @@ class _Method:
in_args.append(intr.Arg(annotation, intr.ArgDirection.IN, param.name)) in_args.append(intr.Arg(annotation, intr.ArgDirection.IN, param.name))
in_signature += annotation in_signature += annotation
out_args = [] out_args: list[intr.Arg] = []
out_signature = parse_annotation(inspection.return_annotation) out_signature = parse_annotation(inspection.return_annotation)
if out_signature: if out_signature:
for type_ in get_signature_tree(out_signature).types: for type_ in get_signature_tree(out_signature).types:
@@ -61,7 +73,7 @@ class _Method:
self.out_signature_tree = get_signature_tree(out_signature) self.out_signature_tree = get_signature_tree(out_signature)
def method(name: Optional[str] = None, disabled: bool = False): def method(name: str | None = None, disabled: bool = False) -> Callable:
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus service method. """A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus service method.
The parameters and return value must each be annotated with a signature The parameters and return value must each be annotated with a signature
@@ -99,9 +111,9 @@ def method(name: Optional[str] = None, disabled: bool = False):
if type(disabled) is not bool: if type(disabled) is not bool:
raise TypeError("disabled must be a bool") raise TypeError("disabled must be a bool")
def decorator(fn): def decorator(fn: Callable) -> Callable:
@wraps(fn) @wraps(fn)
def wrapped(*args, **kwargs): def wrapped(*args: Any, **kwargs: Any) -> None:
fn(*args, **kwargs) fn(*args, **kwargs)
fn_name = name if name else fn.__name__ fn_name = name if name else fn.__name__
@@ -113,7 +125,7 @@ def method(name: Optional[str] = None, disabled: bool = False):
class _Signal: class _Signal:
def __init__(self, fn, name, disabled=False): def __init__(self, fn: Callable, name: str, disabled: bool = False) -> None:
inspection = inspect.signature(fn) inspection = inspect.signature(fn)
args = [] args = []
@@ -138,7 +150,7 @@ class _Signal:
self.introspection = intr.Signal(self.name, args) self.introspection = intr.Signal(self.name, args)
def signal(name: Optional[str] = None, disabled: bool = False): def signal(name: str | None = None, disabled: bool = False) -> Callable:
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus signal. """A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus signal.
The signal is broadcast on the bus when the decorated class method is The signal is broadcast on the bus when the decorated class method is
@@ -173,12 +185,12 @@ def signal(name: Optional[str] = None, disabled: bool = False):
if type(disabled) is not bool: if type(disabled) is not bool:
raise TypeError("disabled must be a bool") raise TypeError("disabled must be a bool")
def decorator(fn): def decorator(fn: Callable) -> Callable:
fn_name = name if name else fn.__name__ fn_name = name if name else fn.__name__
signal = _Signal(fn, fn_name, disabled) signal = _Signal(fn, fn_name, disabled)
@wraps(fn) @wraps(fn)
def wrapped(self, *args, **kwargs): def wrapped(self, *args: Any, **kwargs: Any) -> Any:
if signal.disabled: if signal.disabled:
raise SignalDisabledError("Tried to call a disabled signal") raise SignalDisabledError("Tried to call a disabled signal")
result = fn(self, *args, **kwargs) result = fn(self, *args, **kwargs)
@@ -259,9 +271,9 @@ class _Property(property):
def dbus_property( def dbus_property(
access: PropertyAccess = PropertyAccess.READWRITE, access: PropertyAccess = PropertyAccess.READWRITE,
name: Optional[str] = None, name: str | None = None,
disabled: bool = False, disabled: bool = False,
): ) -> Callable:
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus property. """A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus property.
The class method must be a Python getter method with a return annotation The class method must be a Python getter method with a return annotation
@@ -306,7 +318,7 @@ def dbus_property(
if type(disabled) is not bool: if type(disabled) is not bool:
raise TypeError("disabled must be a bool") raise TypeError("disabled must be a bool")
def decorator(fn): def decorator(fn: Callable) -> _Property:
options = {"name": name, "access": access, "disabled": disabled} options = {"name": name, "access": access, "disabled": disabled}
return _Property(fn, options=options) return _Property(fn, options=options)
@@ -314,7 +326,7 @@ def dbus_property(
def _real_fn_result_to_body( def _real_fn_result_to_body(
result: Optional[Any], result: Any | None,
signature_tree: SignatureTree, signature_tree: SignatureTree,
replace_fds: bool, replace_fds: bool,
) -> tuple[list[Any], list[int]]: ) -> tuple[list[Any], list[int]]:
@@ -334,7 +346,8 @@ def _real_fn_result_to_body(
if out_len != len(final_result): if out_len != len(final_result):
raise SignatureBodyMismatchError( raise SignatureBodyMismatchError(
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}" f"Signature and function return mismatch, expected "
f"{len(signature_tree.types)} arguments but got {len(result)}" # type: ignore[arg-type]
) )
if not replace_fds: if not replace_fds:
@@ -365,12 +378,12 @@ class ServiceInterface:
self.__methods: list[_Method] = [] self.__methods: list[_Method] = []
self.__properties: list[_Property] = [] self.__properties: list[_Property] = []
self.__signals: list[_Signal] = [] self.__signals: list[_Signal] = []
self.__buses = set() self.__buses: set[BaseMessageBus] = set()
self.__handlers: dict[ self.__handlers: dict[BaseMessageBus, dict[_Method, HandlerType]] = {}
BaseMessageBus, # Map of methods by bus of name -> method, handler
dict[_Method, Callable[[Message, Callable[[Message], None]], None]], self.__handlers_by_name_signature: dict[
BaseMessageBus, dict[str, tuple[_Method, HandlerType]]
] = {} ] = {}
for name, member in inspect.getmembers(type(self)): for name, member in inspect.getmembers(type(self)):
member_dict = getattr(member, "__dict__", {}) member_dict = getattr(member, "__dict__", {})
if type(member) is _Property: if type(member) is _Property:
@@ -405,7 +418,7 @@ class ServiceInterface:
def emit_properties_changed( def emit_properties_changed(
self, changed_properties: dict[str, Any], invalidated_properties: list[str] = [] self, changed_properties: dict[str, Any], invalidated_properties: list[str] = []
): ) -> None:
"""Emit the ``org.freedesktop.DBus.Properties.PropertiesChanged`` signal. """Emit the ``org.freedesktop.DBus.Properties.PropertiesChanged`` signal.
This signal is intended to be used to alert clients when a property of This signal is intended to be used to alert clients when a property of
@@ -464,58 +477,59 @@ class ServiceInterface:
) )
@staticmethod @staticmethod
def _get_properties(interface: "ServiceInterface") -> list[_Property]: def _get_properties(interface: ServiceInterface) -> list[_Property]:
return interface.__properties return interface.__properties
@staticmethod @staticmethod
def _get_methods(interface: "ServiceInterface") -> list[_Method]: def _get_methods(interface: ServiceInterface) -> list[_Method]:
return interface.__methods return interface.__methods
@staticmethod @staticmethod
def _c_get_methods(interface: "ServiceInterface") -> list[_Method]: def _get_signals(interface: ServiceInterface) -> list[_Signal]:
# _c_get_methods is used by the C code to get the methods for an
# interface
# https://github.com/cython/cython/issues/3327
return interface.__methods
@staticmethod
def _get_signals(interface: "ServiceInterface") -> list[_Signal]:
return interface.__signals return interface.__signals
@staticmethod @staticmethod
def _get_buses(interface: "ServiceInterface") -> set["BaseMessageBus"]: def _get_buses(interface: ServiceInterface) -> set[BaseMessageBus]:
return interface.__buses return interface.__buses
@staticmethod @staticmethod
def _get_handler( def _get_handler(
interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus" interface: ServiceInterface, method: _Method, bus: BaseMessageBus
) -> Callable[[Message, Callable[[Message], None]], None]: ) -> HandlerType:
return interface.__handlers[bus][method] return interface.__handlers[bus][method]
@staticmethod @staticmethod
def _c_get_handler( def _get_enabled_handler_by_name_signature(
interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus" interface: ServiceInterface,
) -> Callable[[Message, Callable[[Message], None]], None]: bus: BaseMessageBus,
# _c_get_handler is used by the C code to get the handler for a method name: str_,
# https://github.com/cython/cython/issues/3327 signature: str_,
return interface.__handlers[bus][method] ) -> HandlerType | None:
handlers = interface.__handlers_by_name_signature[bus]
if (method_handler := handlers.get(name)) is None:
return None
method = method_handler[0]
if method.disabled:
return None
return method_handler[1] if method.in_signature == signature else None
@staticmethod @staticmethod
def _add_bus( def _add_bus(
interface: "ServiceInterface", interface: ServiceInterface,
bus: "BaseMessageBus", bus: BaseMessageBus,
maker: Callable[ maker: Callable[[ServiceInterface, _Method], HandlerType],
["ServiceInterface", _Method],
Callable[[Message, Callable[[Message], None]], None],
],
) -> None: ) -> None:
interface.__buses.add(bus) interface.__buses.add(bus)
interface.__handlers[bus] = { interface.__handlers[bus] = {
method: maker(interface, method) for method in interface.__methods method: maker(interface, method) for method in interface.__methods
} }
interface.__handlers_by_name_signature[bus] = {
method.name: (method, handler)
for method, handler in interface.__handlers[bus].items()
}
@staticmethod @staticmethod
def _remove_bus(interface: "ServiceInterface", bus: "BaseMessageBus") -> None: def _remove_bus(interface: ServiceInterface, bus: BaseMessageBus) -> None:
interface.__buses.remove(bus) interface.__buses.remove(bus)
del interface.__handlers[bus] del interface.__handlers[bus]
@@ -538,7 +552,7 @@ class ServiceInterface:
@staticmethod @staticmethod
def _fn_result_to_body( def _fn_result_to_body(
result: Optional[Any], result: Any | None,
signature_tree: SignatureTree, signature_tree: SignatureTree,
replace_fds: bool = True, replace_fds: bool = True,
) -> tuple[list[Any], list[int]]: ) -> tuple[list[Any], list[int]]:
@@ -546,7 +560,7 @@ class ServiceInterface:
@staticmethod @staticmethod
def _c_fn_result_to_body( def _c_fn_result_to_body(
result: Optional[Any], result: Any | None,
signature_tree: SignatureTree, signature_tree: SignatureTree,
replace_fds: bool, replace_fds: bool,
) -> tuple[list[Any], list[int]]: ) -> tuple[list[Any], list[int]]:
@@ -558,7 +572,7 @@ class ServiceInterface:
@staticmethod @staticmethod
def _handle_signal( def _handle_signal(
interface: "ServiceInterface", signal: _Signal, result: Optional[Any] interface: ServiceInterface, signal: _Signal, result: Any | None
) -> None: ) -> None:
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree) body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
for bus in ServiceInterface._get_buses(interface): for bus in ServiceInterface._get_buses(interface):
@@ -567,15 +581,19 @@ class ServiceInterface:
) )
@staticmethod @staticmethod
def _get_property_value(interface: "ServiceInterface", prop: _Property, callback): def _get_property_value(
interface: ServiceInterface,
prop: _Property,
callback: Callable[[ServiceInterface, _Property, Any, Exception | None], None],
) -> None:
# XXX MUST CHECK TYPE RETURNED BY GETTER # XXX MUST CHECK TYPE RETURNED BY GETTER
try: try:
if asyncio.iscoroutinefunction(prop.prop_getter): if asyncio.iscoroutinefunction(prop.prop_getter):
task = asyncio.ensure_future(prop.prop_getter(interface)) task: asyncio.Task = asyncio.ensure_future(prop.prop_getter(interface))
def get_property_callback(task): def get_property_callback(task_: asyncio.Task) -> None:
try: try:
result = task.result() result = task_.result()
except Exception as e: except Exception as e:
callback(interface, prop, None, e) callback(interface, prop, None, e)
return return
@@ -592,15 +610,22 @@ class ServiceInterface:
callback(interface, prop, None, e) callback(interface, prop, None, e)
@staticmethod @staticmethod
def _set_property_value(interface: "ServiceInterface", prop, value, callback): def _set_property_value(
interface: ServiceInterface,
prop: _Property,
value: Any,
callback: Callable[[ServiceInterface, _Property, Exception | None], None],
) -> None:
# XXX MUST CHECK TYPE TO SET # XXX MUST CHECK TYPE TO SET
try: try:
if asyncio.iscoroutinefunction(prop.prop_setter): if asyncio.iscoroutinefunction(prop.prop_setter):
task = asyncio.ensure_future(prop.prop_setter(interface, value)) task: asyncio.Task = asyncio.ensure_future(
prop.prop_setter(interface, value)
)
def set_property_callback(task): def set_property_callback(task_: asyncio.Task) -> None:
try: try:
task.result() task_.result()
except Exception as e: except Exception as e:
callback(interface, prop, e) callback(interface, prop, e)
return return
@@ -617,9 +642,11 @@ class ServiceInterface:
@staticmethod @staticmethod
def _get_all_property_values( def _get_all_property_values(
interface: "ServiceInterface", callback, user_data=None interface: ServiceInterface,
): callback: Callable[[ServiceInterface, Any, Any, Exception | None], None],
result = {} user_data: Any | None = None,
) -> None:
result: dict[str, Variant | None] = {}
result_error = None result_error = None
for prop in ServiceInterface._get_properties(interface): for prop in ServiceInterface._get_properties(interface):
@@ -632,10 +659,10 @@ class ServiceInterface:
return return
def get_property_callback( def get_property_callback(
interface: "ServiceInterface", interface: ServiceInterface,
prop: _Property, prop: _Property,
value: Any, value: Any,
e: Optional[Exception], e: Exception | None,
) -> None: ) -> None:
nonlocal result_error nonlocal result_error
if e is not None: if e is not None:

View File

@@ -28,9 +28,14 @@ async def test_export_unexport():
bus = await MessageBus().connect() bus = await MessageBus().connect()
bus.export(export_path, interface) bus.export(export_path, interface)
with pytest.raises(ValueError):
# Already exported
bus.export(export_path, interface)
assert export_path in bus._path_exports assert export_path in bus._path_exports
assert len(bus._path_exports[export_path]) == 1 assert len(bus._path_exports[export_path]) == 1
assert bus._path_exports[export_path][0] is interface assert bus._path_exports[export_path][interface.name] is interface
assert len(ServiceInterface._get_buses(interface)) == 1 assert len(ServiceInterface._get_buses(interface)) == 1
bus.export(export_path2, interface2) bus.export(export_path2, interface2)
@@ -60,11 +65,23 @@ async def test_export_unexport():
assert not bus._path_exports assert not bus._path_exports
assert not ServiceInterface._get_buses(interface) assert not ServiceInterface._get_buses(interface)
# test unexporting by ServiceInterface
bus.export(export_path, interface)
bus.unexport(export_path, interface)
assert not bus._path_exports
assert not ServiceInterface._get_buses(interface)
with pytest.raises(TypeError):
bus.unexport(export_path, object())
node = bus._introspect_export_path("/path/doesnt/exist") node = bus._introspect_export_path("/path/doesnt/exist")
assert type(node) is intr.Node assert type(node) is intr.Node
assert not node.interfaces assert not node.interfaces
assert not node.nodes assert not node.nodes
# Should to nothing
bus.unexport("/path/doesnt/exist", interface)
bus.disconnect() bus.disconnect()

View File

@@ -149,6 +149,19 @@ async def test_peer_interface():
assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0] assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0]
assert reply.signature == "s" assert reply.signature == "s"
reply2 = await bus2.call(
Message(
destination=bus1.unique_name,
path="/path/doesnt/exist",
interface="org.freedesktop.DBus.Peer",
member="GetMachineId",
signature="",
)
)
assert reply2.message_type == MessageType.METHOD_RETURN, reply.body[0]
assert reply2.signature == "s"
bus1.disconnect() bus1.disconnect()
bus2.disconnect() bus2.disconnect()
@@ -213,9 +226,9 @@ async def test_object_manager():
) )
) )
assert reply_root.signature == "a{oa{sa{sv}}}" assert reply_root.signature == "a{oa{sa{sv}}}", reply_root
assert reply_level1.signature == "a{oa{sa{sv}}}" assert reply_level1.signature == "a{oa{sa{sv}}}", reply_level1
assert reply_level2.signature == "a{oa{sa{sv}}}" assert reply_level2.signature == "a{oa{sa{sv}}}", reply_level2
assert reply_level2.body == [{}] assert reply_level2.body == [{}]
assert reply_level1.body == [expected_reply] assert reply_level1.body == [expected_reply]