feat: refactor service bus handler lookup to avoid linear searches (#400)
This commit is contained in:
parent
640e1f8d87
commit
996659e1b5
@ -37,7 +37,7 @@ class MessageFlag(IntFlag):
|
||||
ALLOW_INTERACTIVE_AUTHORIZATION = 4
|
||||
|
||||
@cached_property
|
||||
def value(self) -> str:
|
||||
def value(self) -> int:
|
||||
"""Return the value."""
|
||||
return self._value_
|
||||
|
||||
|
||||
@ -51,7 +51,7 @@ class Arg:
|
||||
def __init__(
|
||||
self,
|
||||
signature: Union[SignatureType, str],
|
||||
direction: Optional[list[ArgDirection]] = None,
|
||||
direction: Optional[ArgDirection] = None,
|
||||
name: Optional[str] = None,
|
||||
annotations: Optional[dict[str, str]] = None,
|
||||
):
|
||||
|
||||
@ -4,6 +4,7 @@ from ._private.address cimport get_bus_address, parse_address
|
||||
from .message cimport Message
|
||||
from .service cimport ServiceInterface, _Method
|
||||
|
||||
cdef bint TYPE_CHECKING
|
||||
|
||||
cdef object MessageType
|
||||
cdef object DBusError
|
||||
@ -39,24 +40,26 @@ cdef class BaseMessageBus:
|
||||
cdef public object _high_level_client_initialized
|
||||
cdef public object _ProxyObject
|
||||
cdef public object _machine_id
|
||||
cdef public object _negotiate_unix_fd
|
||||
cdef public bint _negotiate_unix_fd
|
||||
cdef public object _sock
|
||||
cdef public object _stream
|
||||
cdef public object _fd
|
||||
|
||||
cpdef _process_message(self, Message msg)
|
||||
cpdef void _process_message(self, Message msg)
|
||||
|
||||
@cython.locals(exported_service_interface=ServiceInterface)
|
||||
cpdef export(self, str path, ServiceInterface interface)
|
||||
|
||||
@cython.locals(
|
||||
methods=cython.list,
|
||||
method=_Method,
|
||||
interface=ServiceInterface,
|
||||
interfaces=cython.list,
|
||||
interfaces=dict,
|
||||
)
|
||||
cdef _find_message_handler(self, Message msg)
|
||||
|
||||
cdef _setup_socket(self)
|
||||
|
||||
@cython.locals(no_reply_expected=bint)
|
||||
cpdef _call(self, Message msg, object callback)
|
||||
|
||||
cpdef next_serial(self)
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
import logging
|
||||
import socket
|
||||
import traceback
|
||||
import xml.etree.ElementTree as ET
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
from . import introspection as intr
|
||||
from ._private.address import get_bus_address, parse_address
|
||||
@ -22,7 +23,7 @@ from .errors import DBusError, InvalidAddressError
|
||||
from .message import Message
|
||||
from .proxy_object import BaseProxyObject
|
||||
from .send_reply import SendReply
|
||||
from .service import ServiceInterface, _Method
|
||||
from .service import ServiceInterface, _Method, _Property, HandlerType
|
||||
from .signature import Variant
|
||||
from .validators import assert_bus_name_valid, assert_object_path_valid
|
||||
|
||||
@ -119,12 +120,12 @@ class BaseMessageBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus_address: Optional[str] = None,
|
||||
bus_address: str | None = None,
|
||||
bus_type: BusType = BusType.SESSION,
|
||||
ProxyObject: Optional[type[BaseProxyObject]] = None,
|
||||
ProxyObject: type[BaseProxyObject] | None = None,
|
||||
negotiate_unix_fd: bool = False,
|
||||
) -> None:
|
||||
self.unique_name: Optional[str] = None
|
||||
self.unique_name: str | None = None
|
||||
self._disconnected = False
|
||||
self._negotiate_unix_fd = negotiate_unix_fd
|
||||
|
||||
@ -133,11 +134,11 @@ class BaseMessageBus:
|
||||
self._user_disconnect = False
|
||||
|
||||
self._method_return_handlers: dict[
|
||||
int, Callable[[Optional[Message], Optional[Exception]], None]
|
||||
int, Callable[[Message | None, Exception | None], None]
|
||||
] = {}
|
||||
self._serial = 0
|
||||
self._user_message_handlers: list[
|
||||
Callable[[Message], Union[Message, bool, None]]
|
||||
Callable[[Message], Message | bool | None]
|
||||
] = []
|
||||
# the key is the name and the value is the unique name of the owner.
|
||||
# This cache is kept up to date by the NameOwnerChanged signal and is
|
||||
@ -145,7 +146,7 @@ class BaseMessageBus:
|
||||
# high level client only)
|
||||
self._name_owners: dict[str, str] = {}
|
||||
# used for the high level service
|
||||
self._path_exports: dict[str, list[ServiceInterface]] = {}
|
||||
self._path_exports: dict[str, dict[str, ServiceInterface]] = {}
|
||||
self._bus_address = (
|
||||
parse_address(bus_address)
|
||||
if bus_address
|
||||
@ -161,10 +162,10 @@ class BaseMessageBus:
|
||||
self._ProxyObject = ProxyObject
|
||||
|
||||
# machine id is lazy loaded
|
||||
self._machine_id: Optional[int] = None
|
||||
self._sock: Optional[socket.socket] = None
|
||||
self._fd: Optional[int] = None
|
||||
self._stream: Optional[Any] = None
|
||||
self._machine_id: int | None = None
|
||||
self._sock: socket.socket | None = None
|
||||
self._fd: int | None = None
|
||||
self._stream: Any | None = None
|
||||
|
||||
self._setup_socket()
|
||||
|
||||
@ -193,20 +194,18 @@ class BaseMessageBus:
|
||||
raise TypeError("interface must be a ServiceInterface")
|
||||
|
||||
if path not in self._path_exports:
|
||||
self._path_exports[path] = []
|
||||
self._path_exports[path] = {}
|
||||
elif interface.name in self._path_exports[path]:
|
||||
raise ValueError(
|
||||
f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"'
|
||||
)
|
||||
|
||||
for f in self._path_exports[path]:
|
||||
if f.name == interface.name:
|
||||
raise ValueError(
|
||||
f'An interface with this name is already exported on this bus at path "{path}": "{interface.name}"'
|
||||
)
|
||||
|
||||
self._path_exports[path].append(interface)
|
||||
self._path_exports[path][interface.name] = interface
|
||||
ServiceInterface._add_bus(interface, self, self._make_method_handler)
|
||||
self._emit_interface_added(path, interface)
|
||||
|
||||
def unexport(
|
||||
self, path: str, interface: Optional[Union[ServiceInterface, str]] = None
|
||||
self, path: str, interface: ServiceInterface | str | None = None
|
||||
) -> None:
|
||||
"""Unexport the path or service interface to make it no longer
|
||||
available to clients.
|
||||
@ -222,45 +221,42 @@ class BaseMessageBus:
|
||||
- :class:`InvalidObjectPathError <dbus_fast.InvalidObjectPathError>` - If the given object path is not valid.
|
||||
"""
|
||||
assert_object_path_valid(path)
|
||||
if type(interface) not in [str, type(None)] and not isinstance(
|
||||
interface, ServiceInterface
|
||||
):
|
||||
raise TypeError("interface must be a ServiceInterface or interface name")
|
||||
|
||||
if path not in self._path_exports:
|
||||
return
|
||||
|
||||
exports = self._path_exports[path]
|
||||
|
||||
if type(interface) is str:
|
||||
try:
|
||||
interface = next(iface for iface in exports if iface.name == interface)
|
||||
except StopIteration:
|
||||
return
|
||||
|
||||
removed_interfaces = []
|
||||
interface_name: str | None
|
||||
if interface is None:
|
||||
del self._path_exports[path]
|
||||
for iface in filter(lambda e: not self._has_interface(e), exports):
|
||||
removed_interfaces.append(iface.name)
|
||||
ServiceInterface._remove_bus(iface, self)
|
||||
interface_name = None
|
||||
elif type(interface) is str:
|
||||
interface_name = interface
|
||||
elif isinstance(interface, ServiceInterface):
|
||||
interface_name = interface.name
|
||||
else:
|
||||
for i, iface in enumerate(exports):
|
||||
if iface is interface:
|
||||
removed_interfaces.append(iface.name)
|
||||
del self._path_exports[path][i]
|
||||
if not self._path_exports[path]:
|
||||
del self._path_exports[path]
|
||||
if not self._has_interface(iface):
|
||||
ServiceInterface._remove_bus(iface, self)
|
||||
break
|
||||
self._emit_interface_removed(path, removed_interfaces)
|
||||
raise TypeError(
|
||||
f"interface must be a ServiceInterface or interface name not {type(interface)}"
|
||||
)
|
||||
|
||||
if (interfaces := self._path_exports.get(path)) is None:
|
||||
return
|
||||
removed_interface_names: list[str] = []
|
||||
|
||||
if interface_name is not None:
|
||||
if (removed_interface := interfaces.pop(interface_name, None)) is None:
|
||||
return
|
||||
removed_interface_names.append(interface_name)
|
||||
if not interfaces:
|
||||
del self._path_exports[path]
|
||||
ServiceInterface._remove_bus(removed_interface, self)
|
||||
else:
|
||||
del self._path_exports[path]
|
||||
for removed_interface in interfaces.values():
|
||||
removed_interface_names.append(removed_interface.name)
|
||||
ServiceInterface._remove_bus(removed_interface, self)
|
||||
|
||||
self._emit_interface_removed(path, removed_interface_names)
|
||||
|
||||
def introspect(
|
||||
self,
|
||||
bus_name: str,
|
||||
path: str,
|
||||
callback: Callable[[Optional[intr.Node], Optional[Exception]], None],
|
||||
callback: Callable[[intr.Node | None, Exception | None], None],
|
||||
check_callback_type: bool = True,
|
||||
validate_property_names: bool = True,
|
||||
) -> None:
|
||||
@ -289,12 +285,13 @@ class BaseMessageBus:
|
||||
if check_callback_type:
|
||||
BaseMessageBus._check_callback_type(callback)
|
||||
|
||||
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
|
||||
def reply_notify(reply: Message | None, err: Exception | None) -> None:
|
||||
try:
|
||||
BaseMessageBus._check_method_return(reply, err, "s")
|
||||
result = intr.Node.parse(
|
||||
reply.body[0], validate_property_names=validate_property_names
|
||||
) # type: ignore[union-attr]
|
||||
reply.body[0], # type: ignore[union-attr]
|
||||
validate_property_names=validate_property_names,
|
||||
)
|
||||
except Exception as e:
|
||||
callback(None, e)
|
||||
return
|
||||
@ -330,7 +327,7 @@ class BaseMessageBus:
|
||||
interface: ServiceInterface,
|
||||
result: Any,
|
||||
user_data: Any,
|
||||
e: Optional[Exception],
|
||||
e: Exception | None,
|
||||
) -> None:
|
||||
if e is not None:
|
||||
try:
|
||||
@ -385,9 +382,8 @@ class BaseMessageBus:
|
||||
self,
|
||||
name: str,
|
||||
flags: NameFlag = NameFlag.NONE,
|
||||
callback: Optional[
|
||||
Callable[[Optional[RequestNameReply], Optional[Exception]], None]
|
||||
] = None,
|
||||
callback: None
|
||||
| (Callable[[RequestNameReply | None, Exception | None], None]) = None,
|
||||
check_callback_type: bool = True,
|
||||
) -> None:
|
||||
"""Request that this message bus owns the given name.
|
||||
@ -424,24 +420,23 @@ class BaseMessageBus:
|
||||
self._call(message, None)
|
||||
return
|
||||
|
||||
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
|
||||
def reply_notify(reply: Message | None, err: Exception | None) -> None:
|
||||
try:
|
||||
BaseMessageBus._check_method_return(reply, err, "u")
|
||||
result = RequestNameReply(reply.body[0]) # type: ignore[union-attr]
|
||||
except Exception as e:
|
||||
callback(None, e) # type: ignore[misc]
|
||||
callback(None, e)
|
||||
return
|
||||
|
||||
callback(result, None) # type: ignore[misc]
|
||||
callback(result, None)
|
||||
|
||||
self._call(message, reply_notify)
|
||||
|
||||
def release_name(
|
||||
self,
|
||||
name: str,
|
||||
callback: Optional[
|
||||
Callable[[Optional[ReleaseNameReply], Optional[Exception]], None]
|
||||
] = None,
|
||||
callback: None
|
||||
| (Callable[[ReleaseNameReply | None, Exception | None], None]) = None,
|
||||
check_callback_type: bool = True,
|
||||
) -> None:
|
||||
"""Request that this message bus release the given name.
|
||||
@ -474,20 +469,20 @@ class BaseMessageBus:
|
||||
self._call(message, None)
|
||||
return
|
||||
|
||||
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
|
||||
def reply_notify(reply: Message | None, err: Exception | None) -> None:
|
||||
try:
|
||||
BaseMessageBus._check_method_return(reply, err, "u")
|
||||
result = ReleaseNameReply(reply.body[0]) # type: ignore[union-attr]
|
||||
except Exception as e:
|
||||
callback(None, e) # type: ignore[misc]
|
||||
callback(None, e)
|
||||
return
|
||||
|
||||
callback(result, None) # type: ignore[misc]
|
||||
callback(result, None)
|
||||
|
||||
self._call(message, reply_notify)
|
||||
|
||||
def get_proxy_object(
|
||||
self, bus_name: str, path: str, introspection: Union[intr.Node, str, ET.Element]
|
||||
self, bus_name: str, path: str, introspection: intr.Node | str | ET.Element
|
||||
) -> BaseProxyObject:
|
||||
"""Get a proxy object for the path exported on the bus that owns the
|
||||
name. The object is expected to export the interfaces and nodes
|
||||
@ -526,10 +521,11 @@ class BaseMessageBus:
|
||||
All pending and future calls will error with a connection error.
|
||||
"""
|
||||
self._user_disconnect = True
|
||||
try:
|
||||
self._sock.shutdown(socket.SHUT_RDWR)
|
||||
except Exception:
|
||||
logging.warning("could not shut down socket", exc_info=True)
|
||||
if self._sock:
|
||||
try:
|
||||
self._sock.shutdown(socket.SHUT_RDWR)
|
||||
except Exception:
|
||||
logging.warning("could not shut down socket", exc_info=True)
|
||||
|
||||
def next_serial(self) -> int:
|
||||
"""Get the next serial for this bus. This can be used as the ``serial``
|
||||
@ -543,7 +539,7 @@ class BaseMessageBus:
|
||||
return self._serial
|
||||
|
||||
def add_message_handler(
|
||||
self, handler: Callable[[Message], Optional[Union[Message, bool]]]
|
||||
self, handler: Callable[[Message], Message | bool | None]
|
||||
) -> None:
|
||||
"""Add a custom message handler for incoming messages.
|
||||
|
||||
@ -568,7 +564,7 @@ class BaseMessageBus:
|
||||
self._user_message_handlers.append(handler)
|
||||
|
||||
def remove_message_handler(
|
||||
self, handler: Callable[[Message], Optional[Union[Message, bool]]]
|
||||
self, handler: Callable[[Message], Message | bool | None]
|
||||
) -> None:
|
||||
"""Remove a message handler that was previously added by
|
||||
:func:`add_message_handler()
|
||||
@ -592,7 +588,7 @@ class BaseMessageBus:
|
||||
'the "send" method must be implemented in the inheriting class'
|
||||
)
|
||||
|
||||
def _finalize(self, err: Optional[Exception]) -> None:
|
||||
def _finalize(self, err: Exception | None) -> None:
|
||||
"""should be called after the socket disconnects with the disconnection
|
||||
error to clean up resources and put the bus in a disconnected state"""
|
||||
if self._disconnected:
|
||||
@ -615,14 +611,6 @@ class BaseMessageBus:
|
||||
|
||||
self._user_message_handlers.clear()
|
||||
|
||||
def _has_interface(self, interface: ServiceInterface) -> bool:
|
||||
for _, exports in self._path_exports.items():
|
||||
for iface in exports:
|
||||
if iface is interface:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _interface_signal_notify(
|
||||
self,
|
||||
interface: ServiceInterface,
|
||||
@ -632,9 +620,9 @@ class BaseMessageBus:
|
||||
body: list[Any],
|
||||
unix_fds: list[int] = [],
|
||||
) -> None:
|
||||
path = None
|
||||
path: str | None = None
|
||||
for p, ifaces in self._path_exports.items():
|
||||
for i in ifaces:
|
||||
for i in ifaces.values():
|
||||
if i is interface:
|
||||
path = p
|
||||
|
||||
@ -657,9 +645,9 @@ class BaseMessageBus:
|
||||
def _introspect_export_path(self, path: str) -> intr.Node:
|
||||
assert_object_path_valid(path)
|
||||
|
||||
if path in self._path_exports:
|
||||
if (interfaces := self._path_exports.get(path)) is not None:
|
||||
node = intr.Node.default(path)
|
||||
for interface in self._path_exports[path]:
|
||||
for interface in interfaces.values():
|
||||
node.interfaces.append(interface.introspect())
|
||||
else:
|
||||
node = intr.Node(path)
|
||||
@ -687,7 +675,7 @@ class BaseMessageBus:
|
||||
err = None
|
||||
|
||||
for transport, options in self._bus_address:
|
||||
filename = None
|
||||
filename: bytes | str | None = None
|
||||
ip_addr = ""
|
||||
ip_port = 0
|
||||
|
||||
@ -738,9 +726,9 @@ class BaseMessageBus:
|
||||
def _reply_notify(
|
||||
self,
|
||||
msg: Message,
|
||||
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
|
||||
reply: Optional[Message],
|
||||
err: Optional[Exception],
|
||||
callback: Callable[[Message | None, Exception | None], None],
|
||||
reply: Message | None,
|
||||
err: Exception | None,
|
||||
) -> None:
|
||||
"""Callback on reply."""
|
||||
if reply and msg.destination and reply.sender:
|
||||
@ -750,24 +738,23 @@ class BaseMessageBus:
|
||||
def _call(
|
||||
self,
|
||||
msg: Message,
|
||||
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
|
||||
callback: Callable[[Message | None, Exception | None], None] | None,
|
||||
) -> None:
|
||||
if not msg.serial:
|
||||
msg.serial = self.next_serial()
|
||||
|
||||
reply_expected = _expects_reply(msg)
|
||||
# Make sure the return reply handler is installed
|
||||
# before sending the message to avoid a race condition
|
||||
# where the reply is lost in case the backend can
|
||||
# send it right away.
|
||||
if reply_expected:
|
||||
if (reply_expected := _expects_reply(msg)) and callback is not None:
|
||||
self._method_return_handlers[msg.serial] = partial(
|
||||
self._reply_notify, msg, callback
|
||||
)
|
||||
|
||||
self.send(msg)
|
||||
|
||||
if not reply_expected:
|
||||
if not reply_expected and callback is not None:
|
||||
callback(None, None)
|
||||
|
||||
@staticmethod
|
||||
@ -785,7 +772,7 @@ class BaseMessageBus:
|
||||
|
||||
@staticmethod
|
||||
def _check_method_return(
|
||||
msg: Optional[Message], err: Optional[Exception], signature: str
|
||||
msg: Message | None, err: Exception | None, signature: str
|
||||
) -> None:
|
||||
if err:
|
||||
raise err
|
||||
@ -809,8 +796,7 @@ class BaseMessageBus:
|
||||
handled = False
|
||||
for user_handler in self._user_message_handlers:
|
||||
try:
|
||||
result = user_handler(msg)
|
||||
if result:
|
||||
if result := user_handler(msg):
|
||||
if type(result) is Message:
|
||||
self.send(result)
|
||||
handled = True
|
||||
@ -842,8 +828,8 @@ class BaseMessageBus:
|
||||
and msg.path == "/org/freedesktop/DBus"
|
||||
and msg.interface == "org.freedesktop.DBus"
|
||||
):
|
||||
[name, old_owner, new_owner] = msg.body
|
||||
if new_owner:
|
||||
name = msg.body[0]
|
||||
if new_owner := msg.body[2]:
|
||||
self._name_owners[name] = new_owner
|
||||
elif name in self._name_owners:
|
||||
del self._name_owners[name]
|
||||
@ -852,9 +838,9 @@ class BaseMessageBus:
|
||||
if msg.message_type is MESSAGE_TYPE_CALL:
|
||||
if not handled:
|
||||
handler = self._find_message_handler(msg)
|
||||
if _expects_reply(msg) is False:
|
||||
if not _expects_reply(msg):
|
||||
if handler:
|
||||
handler(msg, BLOCK_UNEXPECTED_REPLY)
|
||||
handler(msg, BLOCK_UNEXPECTED_REPLY) # type: ignore[arg-type]
|
||||
else:
|
||||
_LOGGER.error(
|
||||
'"%s.%s" with signature "%s" could not be found',
|
||||
@ -891,17 +877,17 @@ class BaseMessageBus:
|
||||
interface: ServiceInterface,
|
||||
method: _Method,
|
||||
msg: Message,
|
||||
send_reply: Callable[[Message], None],
|
||||
send_reply: SendReply,
|
||||
) -> None:
|
||||
"""This is the callback that will be called when a method call is."""
|
||||
args = ServiceInterface._c_msg_body_to_args(msg) if msg.unix_fds else msg.body
|
||||
result = method.fn(interface, *args)
|
||||
if send_reply is BLOCK_UNEXPECTED_REPLY or _expects_reply(msg) is False:
|
||||
return
|
||||
body, fds = ServiceInterface._c_fn_result_to_body(
|
||||
body_fds = ServiceInterface._c_fn_result_to_body(
|
||||
result,
|
||||
signature_tree=method.out_signature_tree,
|
||||
replace_fds=self._negotiate_unix_fd,
|
||||
method.out_signature_tree,
|
||||
self._negotiate_unix_fd,
|
||||
)
|
||||
send_reply(
|
||||
Message(
|
||||
@ -909,19 +895,20 @@ class BaseMessageBus:
|
||||
reply_serial=msg.serial,
|
||||
destination=msg.sender,
|
||||
signature=method.out_signature,
|
||||
body=body,
|
||||
unix_fds=fds,
|
||||
body=body_fds[0],
|
||||
unix_fds=body_fds[1],
|
||||
)
|
||||
)
|
||||
|
||||
def _make_method_handler(
|
||||
self, interface: ServiceInterface, method: _Method
|
||||
) -> Callable[[Message, Callable[[Message], None]], None]:
|
||||
) -> HandlerType:
|
||||
return partial(self._callback_method_handler, interface, method)
|
||||
|
||||
def _find_message_handler(
|
||||
self, msg: _Message
|
||||
) -> Optional[Callable[[Message, Callable[[Message], None]], None]]:
|
||||
def _find_message_handler(self, msg: _Message) -> HandlerType | None:
|
||||
if TYPE_CHECKING:
|
||||
assert msg.interface is not None
|
||||
|
||||
if "org.freedesktop.DBus." in msg.interface:
|
||||
if (
|
||||
msg.interface == "org.freedesktop.DBus.Introspectable"
|
||||
@ -945,54 +932,52 @@ class BaseMessageBus:
|
||||
):
|
||||
return self._default_get_managed_objects_handler
|
||||
|
||||
msg_path = msg.path
|
||||
if msg_path:
|
||||
interfaces = self._path_exports.get(msg_path)
|
||||
if not interfaces:
|
||||
return None
|
||||
for interface in interfaces:
|
||||
methods = ServiceInterface._c_get_methods(interface)
|
||||
for method in methods:
|
||||
if method.disabled:
|
||||
continue
|
||||
|
||||
if (
|
||||
msg.interface == interface.name
|
||||
and msg.member == method.name
|
||||
and msg.signature == method.in_signature
|
||||
):
|
||||
return ServiceInterface._c_get_handler(interface, method, self)
|
||||
|
||||
if (
|
||||
msg.path is not None
|
||||
and msg.member is not None
|
||||
and (interfaces := self._path_exports.get(msg.path)) is not None
|
||||
and (interface := interfaces.get(msg.interface)) is not None
|
||||
and (
|
||||
handler := ServiceInterface._get_enabled_handler_by_name_signature(
|
||||
interface, self, msg.member, msg.signature
|
||||
)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
return handler
|
||||
return None
|
||||
|
||||
def _default_introspect_handler(
|
||||
self, msg: Message, send_reply: Callable[[Message], None]
|
||||
) -> None:
|
||||
def _default_introspect_handler(self, msg: Message, send_reply: SendReply) -> None:
|
||||
if TYPE_CHECKING:
|
||||
assert msg.path is not None
|
||||
introspection = self._introspect_export_path(msg.path).tostring()
|
||||
send_reply(Message.new_method_return(msg, "s", [introspection]))
|
||||
|
||||
def _default_ping_handler(
|
||||
self, msg: Message, send_reply: Callable[[Message], None]
|
||||
) -> None:
|
||||
def _default_ping_handler(self, msg: Message, send_reply: SendReply) -> None:
|
||||
send_reply(Message.new_method_return(msg))
|
||||
|
||||
def _send_machine_id_reply(self, msg: Message, send_reply: SendReply) -> None:
|
||||
send_reply(Message.new_method_return(msg, "s", [self._machine_id]))
|
||||
|
||||
def _default_get_machine_id_handler(
|
||||
self, msg: Message, send_reply: Callable[[Message], None]
|
||||
self, msg: Message, send_reply: SendReply
|
||||
) -> None:
|
||||
if self._machine_id:
|
||||
send_reply(Message.new_method_return(msg, "s", self._machine_id))
|
||||
self._send_machine_id_reply(msg, send_reply)
|
||||
return
|
||||
|
||||
def reply_handler(reply, err):
|
||||
if err:
|
||||
def reply_handler(reply: Message | None, err: Exception | None) -> None:
|
||||
if err or reply is None:
|
||||
# the bus has been disconnected, cannot send a reply
|
||||
return
|
||||
|
||||
if reply.message_type == MessageType.METHOD_RETURN:
|
||||
self._machine_id = reply.body[0]
|
||||
send_reply(Message.new_method_return(msg, "s", [self._machine_id]))
|
||||
elif reply.message_type == MessageType.ERROR:
|
||||
send_reply(Message.new_error(msg, reply.error_name, reply.body))
|
||||
self._send_machine_id_reply(msg, send_reply)
|
||||
elif (
|
||||
reply.message_type == MessageType.ERROR and reply.error_name is not None
|
||||
):
|
||||
send_reply(Message.new_error(msg, reply.error_name, str(reply.body)))
|
||||
else:
|
||||
send_reply(
|
||||
Message.new_error(msg, ErrorType.FAILED, "could not get machine_id")
|
||||
@ -1009,13 +994,12 @@ class BaseMessageBus:
|
||||
)
|
||||
|
||||
def _default_get_managed_objects_handler(
|
||||
self, msg: Message, send_reply: Callable[[Message], None]
|
||||
self, msg: Message, send_reply: SendReply
|
||||
) -> None:
|
||||
result = {}
|
||||
result_signature = "a{oa{sa{sv}}}"
|
||||
error_handled = False
|
||||
|
||||
def is_result_complete():
|
||||
def is_result_complete() -> bool:
|
||||
if not result:
|
||||
return True
|
||||
for n, interfaces in result.items():
|
||||
@ -1025,6 +1009,9 @@ class BaseMessageBus:
|
||||
|
||||
return True
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert msg.path is not None
|
||||
|
||||
nodes = [
|
||||
node
|
||||
for node in self._path_exports
|
||||
@ -1032,16 +1019,18 @@ class BaseMessageBus:
|
||||
]
|
||||
|
||||
# first build up the result object to know when it's complete
|
||||
for node in nodes:
|
||||
result[node] = {}
|
||||
for interface in self._path_exports[node]:
|
||||
result[node][interface.name] = None
|
||||
result: dict[str, dict[str, Any]] = {
|
||||
node: {interface: None for interface in self._path_exports[node]}
|
||||
for node in nodes
|
||||
}
|
||||
|
||||
if is_result_complete():
|
||||
send_reply(Message.new_method_return(msg, result_signature, [result]))
|
||||
return
|
||||
|
||||
def get_all_properties_callback(interface, values, node, err):
|
||||
def get_all_properties_callback(
|
||||
interface: ServiceInterface, values: Any, node: str, err: Exception | None
|
||||
) -> None:
|
||||
nonlocal error_handled
|
||||
if err is not None:
|
||||
if not error_handled:
|
||||
@ -1055,14 +1044,12 @@ class BaseMessageBus:
|
||||
send_reply(Message.new_method_return(msg, result_signature, [result]))
|
||||
|
||||
for node in nodes:
|
||||
for interface in self._path_exports[node]:
|
||||
for interface in self._path_exports[node].values():
|
||||
ServiceInterface._get_all_property_values(
|
||||
interface, get_all_properties_callback, node
|
||||
)
|
||||
|
||||
def _default_properties_handler(
|
||||
self, msg: Message, send_reply: Callable[[Message], None]
|
||||
) -> None:
|
||||
def _default_properties_handler(self, msg: Message, send_reply: SendReply) -> None:
|
||||
methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"}
|
||||
if msg.member not in methods or methods[msg.member] != msg.signature:
|
||||
raise DBusError(
|
||||
@ -1082,12 +1069,7 @@ class BaseMessageBus:
|
||||
ErrorType.UNKNOWN_OBJECT, f'no interfaces at path: "{msg.path}"'
|
||||
)
|
||||
|
||||
match = [
|
||||
iface
|
||||
for iface in self._path_exports[msg.path]
|
||||
if iface.name == interface_name
|
||||
]
|
||||
if not match:
|
||||
if (interface := self._path_exports[msg.path].get(interface_name)) is None:
|
||||
if interface_name in [
|
||||
"org.freedesktop.DBus.Properties",
|
||||
"org.freedesktop.DBus.Introspectable",
|
||||
@ -1111,7 +1093,6 @@ class BaseMessageBus:
|
||||
f'could not find an interface "{interface_name}" at path: "{msg.path}"',
|
||||
)
|
||||
|
||||
interface = match[0]
|
||||
properties = ServiceInterface._get_properties(interface)
|
||||
|
||||
if msg.member == "Get" or msg.member == "Set":
|
||||
@ -1135,7 +1116,12 @@ class BaseMessageBus:
|
||||
"the property does not have read access",
|
||||
)
|
||||
|
||||
def get_property_callback(interface, prop, prop_value, err):
|
||||
def get_property_callback(
|
||||
interface: ServiceInterface,
|
||||
prop: _Property,
|
||||
prop_value: Any,
|
||||
err: Exception | None,
|
||||
) -> None:
|
||||
try:
|
||||
if err is not None:
|
||||
send_reply.send_error(err)
|
||||
@ -1173,7 +1159,9 @@ class BaseMessageBus:
|
||||
)
|
||||
assert prop.prop_setter
|
||||
|
||||
def set_property_callback(interface, prop, err):
|
||||
def set_property_callback(
|
||||
interface: ServiceInterface, prop: _Property, err: Exception | None
|
||||
) -> None:
|
||||
if err is not None:
|
||||
send_reply.send_error(err)
|
||||
return
|
||||
@ -1188,7 +1176,12 @@ class BaseMessageBus:
|
||||
|
||||
elif msg.member == "GetAll":
|
||||
|
||||
def get_all_properties_callback(interface, values, user_data, err):
|
||||
def get_all_properties_callback(
|
||||
interface: ServiceInterface,
|
||||
values: Any,
|
||||
user_data: Any,
|
||||
err: Exception | None,
|
||||
) -> None:
|
||||
if err is not None:
|
||||
send_reply.send_error(err)
|
||||
return
|
||||
@ -1212,12 +1205,12 @@ class BaseMessageBus:
|
||||
return
|
||||
self._high_level_client_initialized = True
|
||||
|
||||
def add_match_notify(msg, err):
|
||||
def add_match_notify(msg: Message | None, err: Exception | None) -> None:
|
||||
if err:
|
||||
logging.error(
|
||||
f'add match request failed. match="{self._name_owner_match_rule}", {err}'
|
||||
)
|
||||
elif msg.message_type == MessageType.ERROR:
|
||||
elif msg is not None and msg.message_type == MessageType.ERROR:
|
||||
logging.error(
|
||||
f'add match request failed. match="{self._name_owner_match_rule}", {msg.body[0]}'
|
||||
)
|
||||
@ -1234,7 +1227,7 @@ class BaseMessageBus:
|
||||
add_match_notify,
|
||||
)
|
||||
|
||||
def _add_match_rule(self, match_rule):
|
||||
def _add_match_rule(self, match_rule: str) -> None:
|
||||
"""Add a match rule. Match rules added by this function are refcounted
|
||||
and must be removed by _remove_match_rule(). This is for use in the
|
||||
high level client only."""
|
||||
@ -1247,10 +1240,10 @@ class BaseMessageBus:
|
||||
|
||||
self._match_rules[match_rule] = 1
|
||||
|
||||
def add_match_notify(msg: Message, err: Optional[Exception]) -> None:
|
||||
def add_match_notify(msg: Message | None, err: Exception | None) -> None:
|
||||
if err:
|
||||
logging.error(f'add match request failed. match="{match_rule}", {err}')
|
||||
elif msg.message_type == MessageType.ERROR:
|
||||
elif msg is not None and msg.message_type == MessageType.ERROR:
|
||||
logging.error(
|
||||
f'add match request failed. match="{match_rule}", {msg.body[0]}'
|
||||
)
|
||||
@ -1267,7 +1260,7 @@ class BaseMessageBus:
|
||||
add_match_notify,
|
||||
)
|
||||
|
||||
def _remove_match_rule(self, match_rule):
|
||||
def _remove_match_rule(self, match_rule: str) -> None:
|
||||
"""Remove a match rule added with _add_match_rule(). This is for use in
|
||||
the high level client only."""
|
||||
if match_rule == self._name_owner_match_rule:
|
||||
@ -1280,7 +1273,7 @@ class BaseMessageBus:
|
||||
|
||||
del self._match_rules[match_rule]
|
||||
|
||||
def remove_match_notify(msg, err):
|
||||
def remove_match_notify(msg: Message | None, err: Exception | None) -> None:
|
||||
if self._disconnected:
|
||||
return
|
||||
|
||||
@ -1288,7 +1281,7 @@ class BaseMessageBus:
|
||||
logging.error(
|
||||
f'remove match request failed. match="{match_rule}", {err}'
|
||||
)
|
||||
elif msg.message_type == MessageType.ERROR:
|
||||
elif msg is not None and msg.message_type == MessageType.ERROR:
|
||||
logging.error(
|
||||
f'remove match request failed. match="{match_rule}", {msg.body[0]}'
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import traceback
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .constants import ErrorType
|
||||
from .errors import DBusError
|
||||
@ -15,12 +16,12 @@ class SendReply:
|
||||
|
||||
__slots__ = ("_bus", "_msg")
|
||||
|
||||
def __init__(self, bus: "BaseMessageBus", msg: Message) -> None:
|
||||
def __init__(self, bus: BaseMessageBus, msg: Message) -> None:
|
||||
"""Create a new reply context manager."""
|
||||
self._bus = bus
|
||||
self._msg = msg
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> SendReply:
|
||||
return self
|
||||
|
||||
def __call__(self, reply: Message) -> None:
|
||||
@ -28,9 +29,9 @@ class SendReply:
|
||||
|
||||
def _exit(
|
||||
self,
|
||||
exc_type: Optional[type[Exception]],
|
||||
exc_value: Optional[Exception],
|
||||
tb: Optional[TracebackType],
|
||||
exc_type: type[Exception] | None,
|
||||
exc_value: Exception | None,
|
||||
tb: TracebackType | None,
|
||||
) -> bool:
|
||||
if exc_value:
|
||||
if isinstance(exc_value, DBusError):
|
||||
@ -49,9 +50,9 @@ class SendReply:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[Exception]],
|
||||
exc_value: Optional[Exception],
|
||||
tb: Optional[TracebackType],
|
||||
exc_type: type[Exception] | None,
|
||||
exc_value: Exception | None,
|
||||
tb: TracebackType | None,
|
||||
) -> bool:
|
||||
return self._exit(exc_type, exc_value, tb)
|
||||
|
||||
|
||||
@ -33,12 +33,11 @@ cdef class ServiceInterface:
|
||||
cdef list __signals
|
||||
cdef set __buses
|
||||
cdef dict __handlers
|
||||
cdef dict __enabled_handlers_by_name_signature
|
||||
|
||||
@cython.locals(handlers=dict,in_signature=str,method=_Method)
|
||||
@staticmethod
|
||||
cdef list _c_get_methods(ServiceInterface interface)
|
||||
|
||||
@staticmethod
|
||||
cdef object _c_get_handler(ServiceInterface interface, _Method method, object bus)
|
||||
cdef object _get_enabled_handler_by_name_signature(ServiceInterface interface, object bus, object name, object signature)
|
||||
|
||||
@staticmethod
|
||||
cdef list _c_msg_body_to_args(Message msg)
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Protocol
|
||||
|
||||
from . import introspection as intr
|
||||
from ._private.util import (
|
||||
@ -20,19 +21,30 @@ from .signature import (
|
||||
Variant,
|
||||
get_signature_tree,
|
||||
)
|
||||
from .send_reply import SendReply
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .message_bus import BaseMessageBus
|
||||
|
||||
str_ = str
|
||||
|
||||
HandlerType = Callable[[Message, SendReply], None]
|
||||
|
||||
|
||||
class _MethodCallbackProtocol(Protocol):
|
||||
def __call__(self, interface: ServiceInterface, *args: Any) -> Any: ...
|
||||
|
||||
|
||||
class _Method:
|
||||
def __init__(self, fn, name: str, disabled=False):
|
||||
def __init__(
|
||||
self, fn: _MethodCallbackProtocol, name: str, disabled: bool = False
|
||||
) -> None:
|
||||
in_signature = ""
|
||||
out_signature = ""
|
||||
|
||||
inspection = inspect.signature(fn)
|
||||
|
||||
in_args = []
|
||||
in_args: list[intr.Arg] = []
|
||||
for i, param in enumerate(inspection.parameters.values()):
|
||||
if i == 0:
|
||||
# first is self
|
||||
@ -45,7 +57,7 @@ class _Method:
|
||||
in_args.append(intr.Arg(annotation, intr.ArgDirection.IN, param.name))
|
||||
in_signature += annotation
|
||||
|
||||
out_args = []
|
||||
out_args: list[intr.Arg] = []
|
||||
out_signature = parse_annotation(inspection.return_annotation)
|
||||
if out_signature:
|
||||
for type_ in get_signature_tree(out_signature).types:
|
||||
@ -61,7 +73,7 @@ class _Method:
|
||||
self.out_signature_tree = get_signature_tree(out_signature)
|
||||
|
||||
|
||||
def method(name: Optional[str] = None, disabled: bool = False):
|
||||
def method(name: str | None = None, disabled: bool = False) -> Callable:
|
||||
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus service method.
|
||||
|
||||
The parameters and return value must each be annotated with a signature
|
||||
@ -99,9 +111,9 @@ def method(name: Optional[str] = None, disabled: bool = False):
|
||||
if type(disabled) is not bool:
|
||||
raise TypeError("disabled must be a bool")
|
||||
|
||||
def decorator(fn):
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
@wraps(fn)
|
||||
def wrapped(*args, **kwargs):
|
||||
def wrapped(*args: Any, **kwargs: Any) -> None:
|
||||
fn(*args, **kwargs)
|
||||
|
||||
fn_name = name if name else fn.__name__
|
||||
@ -113,7 +125,7 @@ def method(name: Optional[str] = None, disabled: bool = False):
|
||||
|
||||
|
||||
class _Signal:
|
||||
def __init__(self, fn, name, disabled=False):
|
||||
def __init__(self, fn: Callable, name: str, disabled: bool = False) -> None:
|
||||
inspection = inspect.signature(fn)
|
||||
|
||||
args = []
|
||||
@ -138,7 +150,7 @@ class _Signal:
|
||||
self.introspection = intr.Signal(self.name, args)
|
||||
|
||||
|
||||
def signal(name: Optional[str] = None, disabled: bool = False):
|
||||
def signal(name: str | None = None, disabled: bool = False) -> Callable:
|
||||
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus signal.
|
||||
|
||||
The signal is broadcast on the bus when the decorated class method is
|
||||
@ -173,12 +185,12 @@ def signal(name: Optional[str] = None, disabled: bool = False):
|
||||
if type(disabled) is not bool:
|
||||
raise TypeError("disabled must be a bool")
|
||||
|
||||
def decorator(fn):
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
fn_name = name if name else fn.__name__
|
||||
signal = _Signal(fn, fn_name, disabled)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
def wrapped(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if signal.disabled:
|
||||
raise SignalDisabledError("Tried to call a disabled signal")
|
||||
result = fn(self, *args, **kwargs)
|
||||
@ -259,9 +271,9 @@ class _Property(property):
|
||||
|
||||
def dbus_property(
|
||||
access: PropertyAccess = PropertyAccess.READWRITE,
|
||||
name: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
disabled: bool = False,
|
||||
):
|
||||
) -> Callable:
|
||||
"""A decorator to mark a class method of a :class:`ServiceInterface` to be a DBus property.
|
||||
|
||||
The class method must be a Python getter method with a return annotation
|
||||
@ -306,7 +318,7 @@ def dbus_property(
|
||||
if type(disabled) is not bool:
|
||||
raise TypeError("disabled must be a bool")
|
||||
|
||||
def decorator(fn):
|
||||
def decorator(fn: Callable) -> _Property:
|
||||
options = {"name": name, "access": access, "disabled": disabled}
|
||||
return _Property(fn, options=options)
|
||||
|
||||
@ -314,7 +326,7 @@ def dbus_property(
|
||||
|
||||
|
||||
def _real_fn_result_to_body(
|
||||
result: Optional[Any],
|
||||
result: Any | None,
|
||||
signature_tree: SignatureTree,
|
||||
replace_fds: bool,
|
||||
) -> tuple[list[Any], list[int]]:
|
||||
@ -334,7 +346,8 @@ def _real_fn_result_to_body(
|
||||
|
||||
if out_len != len(final_result):
|
||||
raise SignatureBodyMismatchError(
|
||||
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}"
|
||||
f"Signature and function return mismatch, expected "
|
||||
f"{len(signature_tree.types)} arguments but got {len(result)}" # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if not replace_fds:
|
||||
@ -365,12 +378,12 @@ class ServiceInterface:
|
||||
self.__methods: list[_Method] = []
|
||||
self.__properties: list[_Property] = []
|
||||
self.__signals: list[_Signal] = []
|
||||
self.__buses = set()
|
||||
self.__handlers: dict[
|
||||
BaseMessageBus,
|
||||
dict[_Method, Callable[[Message, Callable[[Message], None]], None]],
|
||||
self.__buses: set[BaseMessageBus] = set()
|
||||
self.__handlers: dict[BaseMessageBus, dict[_Method, HandlerType]] = {}
|
||||
# Map of methods by bus of name -> method, handler
|
||||
self.__handlers_by_name_signature: dict[
|
||||
BaseMessageBus, dict[str, tuple[_Method, HandlerType]]
|
||||
] = {}
|
||||
|
||||
for name, member in inspect.getmembers(type(self)):
|
||||
member_dict = getattr(member, "__dict__", {})
|
||||
if type(member) is _Property:
|
||||
@ -405,7 +418,7 @@ class ServiceInterface:
|
||||
|
||||
def emit_properties_changed(
|
||||
self, changed_properties: dict[str, Any], invalidated_properties: list[str] = []
|
||||
):
|
||||
) -> None:
|
||||
"""Emit the ``org.freedesktop.DBus.Properties.PropertiesChanged`` signal.
|
||||
|
||||
This signal is intended to be used to alert clients when a property of
|
||||
@ -464,58 +477,59 @@ class ServiceInterface:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_properties(interface: "ServiceInterface") -> list[_Property]:
|
||||
def _get_properties(interface: ServiceInterface) -> list[_Property]:
|
||||
return interface.__properties
|
||||
|
||||
@staticmethod
|
||||
def _get_methods(interface: "ServiceInterface") -> list[_Method]:
|
||||
def _get_methods(interface: ServiceInterface) -> list[_Method]:
|
||||
return interface.__methods
|
||||
|
||||
@staticmethod
|
||||
def _c_get_methods(interface: "ServiceInterface") -> list[_Method]:
|
||||
# _c_get_methods is used by the C code to get the methods for an
|
||||
# interface
|
||||
# https://github.com/cython/cython/issues/3327
|
||||
return interface.__methods
|
||||
|
||||
@staticmethod
|
||||
def _get_signals(interface: "ServiceInterface") -> list[_Signal]:
|
||||
def _get_signals(interface: ServiceInterface) -> list[_Signal]:
|
||||
return interface.__signals
|
||||
|
||||
@staticmethod
|
||||
def _get_buses(interface: "ServiceInterface") -> set["BaseMessageBus"]:
|
||||
def _get_buses(interface: ServiceInterface) -> set[BaseMessageBus]:
|
||||
return interface.__buses
|
||||
|
||||
@staticmethod
|
||||
def _get_handler(
|
||||
interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus"
|
||||
) -> Callable[[Message, Callable[[Message], None]], None]:
|
||||
interface: ServiceInterface, method: _Method, bus: BaseMessageBus
|
||||
) -> HandlerType:
|
||||
return interface.__handlers[bus][method]
|
||||
|
||||
@staticmethod
|
||||
def _c_get_handler(
|
||||
interface: "ServiceInterface", method: _Method, bus: "BaseMessageBus"
|
||||
) -> Callable[[Message, Callable[[Message], None]], None]:
|
||||
# _c_get_handler is used by the C code to get the handler for a method
|
||||
# https://github.com/cython/cython/issues/3327
|
||||
return interface.__handlers[bus][method]
|
||||
def _get_enabled_handler_by_name_signature(
|
||||
interface: ServiceInterface,
|
||||
bus: BaseMessageBus,
|
||||
name: str_,
|
||||
signature: str_,
|
||||
) -> HandlerType | None:
|
||||
handlers = interface.__handlers_by_name_signature[bus]
|
||||
if (method_handler := handlers.get(name)) is None:
|
||||
return None
|
||||
method = method_handler[0]
|
||||
if method.disabled:
|
||||
return None
|
||||
return method_handler[1] if method.in_signature == signature else None
|
||||
|
||||
@staticmethod
|
||||
def _add_bus(
|
||||
interface: "ServiceInterface",
|
||||
bus: "BaseMessageBus",
|
||||
maker: Callable[
|
||||
["ServiceInterface", _Method],
|
||||
Callable[[Message, Callable[[Message], None]], None],
|
||||
],
|
||||
interface: ServiceInterface,
|
||||
bus: BaseMessageBus,
|
||||
maker: Callable[[ServiceInterface, _Method], HandlerType],
|
||||
) -> None:
|
||||
interface.__buses.add(bus)
|
||||
interface.__handlers[bus] = {
|
||||
method: maker(interface, method) for method in interface.__methods
|
||||
}
|
||||
interface.__handlers_by_name_signature[bus] = {
|
||||
method.name: (method, handler)
|
||||
for method, handler in interface.__handlers[bus].items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _remove_bus(interface: "ServiceInterface", bus: "BaseMessageBus") -> None:
|
||||
def _remove_bus(interface: ServiceInterface, bus: BaseMessageBus) -> None:
|
||||
interface.__buses.remove(bus)
|
||||
del interface.__handlers[bus]
|
||||
|
||||
@ -538,7 +552,7 @@ class ServiceInterface:
|
||||
|
||||
@staticmethod
|
||||
def _fn_result_to_body(
|
||||
result: Optional[Any],
|
||||
result: Any | None,
|
||||
signature_tree: SignatureTree,
|
||||
replace_fds: bool = True,
|
||||
) -> tuple[list[Any], list[int]]:
|
||||
@ -546,7 +560,7 @@ class ServiceInterface:
|
||||
|
||||
@staticmethod
|
||||
def _c_fn_result_to_body(
|
||||
result: Optional[Any],
|
||||
result: Any | None,
|
||||
signature_tree: SignatureTree,
|
||||
replace_fds: bool,
|
||||
) -> tuple[list[Any], list[int]]:
|
||||
@ -558,7 +572,7 @@ class ServiceInterface:
|
||||
|
||||
@staticmethod
|
||||
def _handle_signal(
|
||||
interface: "ServiceInterface", signal: _Signal, result: Optional[Any]
|
||||
interface: ServiceInterface, signal: _Signal, result: Any | None
|
||||
) -> None:
|
||||
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
|
||||
for bus in ServiceInterface._get_buses(interface):
|
||||
@ -567,15 +581,19 @@ class ServiceInterface:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_property_value(interface: "ServiceInterface", prop: _Property, callback):
|
||||
def _get_property_value(
|
||||
interface: ServiceInterface,
|
||||
prop: _Property,
|
||||
callback: Callable[[ServiceInterface, _Property, Any, Exception | None], None],
|
||||
) -> None:
|
||||
# XXX MUST CHECK TYPE RETURNED BY GETTER
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(prop.prop_getter):
|
||||
task = asyncio.ensure_future(prop.prop_getter(interface))
|
||||
task: asyncio.Task = asyncio.ensure_future(prop.prop_getter(interface))
|
||||
|
||||
def get_property_callback(task):
|
||||
def get_property_callback(task_: asyncio.Task) -> None:
|
||||
try:
|
||||
result = task.result()
|
||||
result = task_.result()
|
||||
except Exception as e:
|
||||
callback(interface, prop, None, e)
|
||||
return
|
||||
@ -592,15 +610,22 @@ class ServiceInterface:
|
||||
callback(interface, prop, None, e)
|
||||
|
||||
@staticmethod
|
||||
def _set_property_value(interface: "ServiceInterface", prop, value, callback):
|
||||
def _set_property_value(
|
||||
interface: ServiceInterface,
|
||||
prop: _Property,
|
||||
value: Any,
|
||||
callback: Callable[[ServiceInterface, _Property, Exception | None], None],
|
||||
) -> None:
|
||||
# XXX MUST CHECK TYPE TO SET
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(prop.prop_setter):
|
||||
task = asyncio.ensure_future(prop.prop_setter(interface, value))
|
||||
task: asyncio.Task = asyncio.ensure_future(
|
||||
prop.prop_setter(interface, value)
|
||||
)
|
||||
|
||||
def set_property_callback(task):
|
||||
def set_property_callback(task_: asyncio.Task) -> None:
|
||||
try:
|
||||
task.result()
|
||||
task_.result()
|
||||
except Exception as e:
|
||||
callback(interface, prop, e)
|
||||
return
|
||||
@ -617,9 +642,11 @@ class ServiceInterface:
|
||||
|
||||
@staticmethod
|
||||
def _get_all_property_values(
|
||||
interface: "ServiceInterface", callback, user_data=None
|
||||
):
|
||||
result = {}
|
||||
interface: ServiceInterface,
|
||||
callback: Callable[[ServiceInterface, Any, Any, Exception | None], None],
|
||||
user_data: Any | None = None,
|
||||
) -> None:
|
||||
result: dict[str, Variant | None] = {}
|
||||
result_error = None
|
||||
|
||||
for prop in ServiceInterface._get_properties(interface):
|
||||
@ -632,10 +659,10 @@ class ServiceInterface:
|
||||
return
|
||||
|
||||
def get_property_callback(
|
||||
interface: "ServiceInterface",
|
||||
interface: ServiceInterface,
|
||||
prop: _Property,
|
||||
value: Any,
|
||||
e: Optional[Exception],
|
||||
e: Exception | None,
|
||||
) -> None:
|
||||
nonlocal result_error
|
||||
if e is not None:
|
||||
|
||||
@ -28,9 +28,14 @@ async def test_export_unexport():
|
||||
|
||||
bus = await MessageBus().connect()
|
||||
bus.export(export_path, interface)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# Already exported
|
||||
bus.export(export_path, interface)
|
||||
|
||||
assert export_path in bus._path_exports
|
||||
assert len(bus._path_exports[export_path]) == 1
|
||||
assert bus._path_exports[export_path][0] is interface
|
||||
assert bus._path_exports[export_path][interface.name] is interface
|
||||
assert len(ServiceInterface._get_buses(interface)) == 1
|
||||
|
||||
bus.export(export_path2, interface2)
|
||||
@ -60,11 +65,23 @@ async def test_export_unexport():
|
||||
assert not bus._path_exports
|
||||
assert not ServiceInterface._get_buses(interface)
|
||||
|
||||
# test unexporting by ServiceInterface
|
||||
bus.export(export_path, interface)
|
||||
bus.unexport(export_path, interface)
|
||||
assert not bus._path_exports
|
||||
assert not ServiceInterface._get_buses(interface)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
bus.unexport(export_path, object())
|
||||
|
||||
node = bus._introspect_export_path("/path/doesnt/exist")
|
||||
assert type(node) is intr.Node
|
||||
assert not node.interfaces
|
||||
assert not node.nodes
|
||||
|
||||
# Should to nothing
|
||||
bus.unexport("/path/doesnt/exist", interface)
|
||||
|
||||
bus.disconnect()
|
||||
|
||||
|
||||
|
||||
@ -149,6 +149,19 @@ async def test_peer_interface():
|
||||
assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0]
|
||||
assert reply.signature == "s"
|
||||
|
||||
reply2 = await bus2.call(
|
||||
Message(
|
||||
destination=bus1.unique_name,
|
||||
path="/path/doesnt/exist",
|
||||
interface="org.freedesktop.DBus.Peer",
|
||||
member="GetMachineId",
|
||||
signature="",
|
||||
)
|
||||
)
|
||||
|
||||
assert reply2.message_type == MessageType.METHOD_RETURN, reply.body[0]
|
||||
assert reply2.signature == "s"
|
||||
|
||||
bus1.disconnect()
|
||||
bus2.disconnect()
|
||||
|
||||
@ -213,9 +226,9 @@ async def test_object_manager():
|
||||
)
|
||||
)
|
||||
|
||||
assert reply_root.signature == "a{oa{sa{sv}}}"
|
||||
assert reply_level1.signature == "a{oa{sa{sv}}}"
|
||||
assert reply_level2.signature == "a{oa{sa{sv}}}"
|
||||
assert reply_root.signature == "a{oa{sa{sv}}}", reply_root
|
||||
assert reply_level1.signature == "a{oa{sa{sv}}}", reply_level1
|
||||
assert reply_level2.signature == "a{oa{sa{sv}}}", reply_level2
|
||||
|
||||
assert reply_level2.body == [{}]
|
||||
assert reply_level1.body == [expected_reply]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user