diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 71b97eb..e4bb466 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -4,7 +4,7 @@ import logging import socket from collections import deque from copy import copy -from typing import Any, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Set, Tuple from .. import introspection as intr from ..auth import Authenticator, AuthExternal @@ -18,7 +18,7 @@ from ..constants import ( ) from ..errors import AuthError from ..message import Message -from ..message_bus import BaseMessageBus +from ..message_bus import BaseMessageBus, _block_unexpected_reply from ..service import ServiceInterface from .message_reader import build_message_reader from .proxy_object import ProxyObject @@ -173,7 +173,7 @@ class MessageBus(BaseMessageBus): :vartype connected: bool """ - __slots__ = ("_loop", "_auth", "_writer", "_disconnect_future") + __slots__ = ("_loop", "_auth", "_writer", "_disconnect_future", "_pending_futures") def __init__( self, @@ -193,6 +193,7 @@ class MessageBus(BaseMessageBus): self._auth = auth self._disconnect_future = self._loop.create_future() + self._pending_futures: Set[asyncio.Future] = set() async def connect(self) -> "MessageBus": """Connect this message bus to the DBus daemon. @@ -431,12 +432,33 @@ class MessageBus(BaseMessageBus): if not asyncio.iscoroutinefunction(method.fn): return super()._make_method_handler(interface, method) - def _coro_method_handler(msg, send_reply): - def done(fut): + negotiate_unix_fd = self._negotiate_unix_fd + msg_body_to_args = ServiceInterface._msg_body_to_args + fn_result_to_body = ServiceInterface._fn_result_to_body + + def _coroutine_method_handler( + msg: Message, send_reply: Callable[[Message], None] + ) -> None: + """A coroutine method handler.""" + args = msg_body_to_args(msg) if msg.unix_fds else msg.body + fut = asyncio.ensure_future(method.fn(interface, *args)) + # Hold a strong reference to the future to ensure + # it is not garbage collected before it is done. + self._pending_futures.add(fut) + if ( + send_reply is _block_unexpected_reply + or msg.flags.value & NO_REPLY_EXPECTED_VALUE + ): + fut.add_done_callback(self._pending_futures.discard) + return + + # We only create the closure function if we are actually going to reply + def _done(fut: asyncio.Future) -> None: + """The callback for when the method is done.""" with send_reply: result = fut.result() - body, unix_fds = ServiceInterface._fn_result_to_body( - result, method.out_signature_tree + body, unix_fds = fn_result_to_body( + result, method.out_signature_tree, replace_fds=negotiate_unix_fd ) send_reply( Message.new_method_return( @@ -444,11 +466,11 @@ class MessageBus(BaseMessageBus): ) ) - args = ServiceInterface._msg_body_to_args(msg) - fut = asyncio.ensure_future(method.fn(interface, *args)) - fut.add_done_callback(done) + fut.add_done_callback(_done) + # Discard the future only after running the done callback + fut.add_done_callback(self._pending_futures.discard) - return _coro_method_handler + return _coroutine_method_handler async def _auth_readline(self) -> str: buf = b"" diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index 8248105..98c23d4 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -770,24 +770,27 @@ class BaseMessageBus: if not msg.serial: msg.serial = self.next_serial() - def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None: - if reply and msg.destination and reply.sender: - self._name_owners[msg.destination] = reply.sender - callback(reply, err) # type: ignore[misc] - no_reply_expected = not _expects_reply(msg) - # Make sure the return reply handler is installed # before sending the message to avoid a race condition # where the reply is lost in case the backend can # send it right away. if not no_reply_expected: - self._method_return_handlers[msg.serial] = reply_notify + + def _reply_notify( + reply: Optional[Message], err: Optional[Exception] + ) -> None: + """Callback on reply.""" + if reply and msg.destination and reply.sender: + self._name_owners[msg.destination] = reply.sender + callback(reply, err) + + self._method_return_handlers[msg.serial] = _reply_notify self.send(msg) if no_reply_expected: - callback(None, None) # type: ignore[misc] + callback(None, None) @staticmethod def _check_callback_type(callback: Callable) -> None: @@ -921,7 +924,9 @@ class BaseMessageBus: def _callback_method_handler( msg: Message, send_reply: Callable[[Message], None] ) -> None: - result = method_fn(interface, *msg_body_to_args(msg)) + """This is the callback that will be called when a method call is.""" + args = msg_body_to_args(msg) if msg.unix_fds else msg.body + result = method_fn(interface, *args) if send_reply is BLOCK_UNEXPECTED_REPLY or not _expects_reply(msg): return body, fds = fn_result_to_body( diff --git a/src/dbus_fast/service.py b/src/dbus_fast/service.py index 19226a1..8868531 100644 --- a/src/dbus_fast/service.py +++ b/src/dbus_fast/service.py @@ -490,9 +490,7 @@ class ServiceInterface: @staticmethod def _msg_body_to_args(msg: Message) -> List[Any]: - if not msg.unix_fds or not signature_contains_type( - msg.signature_tree, msg.body, "h" - ): + if not signature_contains_type(msg.signature_tree, msg.body, "h"): return msg.body # XXX: This deep copy could be expensive if messages are very