feat: reduce overhead to dispatch method handlers (#227)

This commit is contained in:
J. Nick Koston
2023-08-18 14:33:23 -05:00
committed by GitHub
parent 8f4f9451b7
commit b2225527ae
3 changed files with 48 additions and 23 deletions

View File

@@ -4,7 +4,7 @@ import logging
import socket import socket
from collections import deque from collections import deque
from copy import copy from copy import copy
from typing import Any, List, Optional, Tuple from typing import Any, Callable, List, Optional, Set, Tuple
from .. import introspection as intr from .. import introspection as intr
from ..auth import Authenticator, AuthExternal from ..auth import Authenticator, AuthExternal
@@ -18,7 +18,7 @@ from ..constants import (
) )
from ..errors import AuthError from ..errors import AuthError
from ..message import Message from ..message import Message
from ..message_bus import BaseMessageBus from ..message_bus import BaseMessageBus, _block_unexpected_reply
from ..service import ServiceInterface from ..service import ServiceInterface
from .message_reader import build_message_reader from .message_reader import build_message_reader
from .proxy_object import ProxyObject from .proxy_object import ProxyObject
@@ -173,7 +173,7 @@ class MessageBus(BaseMessageBus):
:vartype connected: bool :vartype connected: bool
""" """
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future") __slots__ = ("_loop", "_auth", "_writer", "_disconnect_future", "_pending_futures")
def __init__( def __init__(
self, self,
@@ -193,6 +193,7 @@ class MessageBus(BaseMessageBus):
self._auth = auth self._auth = auth
self._disconnect_future = self._loop.create_future() self._disconnect_future = self._loop.create_future()
self._pending_futures: Set[asyncio.Future] = set()
async def connect(self) -> "MessageBus": async def connect(self) -> "MessageBus":
"""Connect this message bus to the DBus daemon. """Connect this message bus to the DBus daemon.
@@ -431,12 +432,33 @@ class MessageBus(BaseMessageBus):
if not asyncio.iscoroutinefunction(method.fn): if not asyncio.iscoroutinefunction(method.fn):
return super()._make_method_handler(interface, method) return super()._make_method_handler(interface, method)
def _coro_method_handler(msg, send_reply): negotiate_unix_fd = self._negotiate_unix_fd
def done(fut): msg_body_to_args = ServiceInterface._msg_body_to_args
fn_result_to_body = ServiceInterface._fn_result_to_body
def _coroutine_method_handler(
msg: Message, send_reply: Callable[[Message], None]
) -> None:
"""A coroutine method handler."""
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
fut = asyncio.ensure_future(method.fn(interface, *args))
# Hold a strong reference to the future to ensure
# it is not garbage collected before it is done.
self._pending_futures.add(fut)
if (
send_reply is _block_unexpected_reply
or msg.flags.value & NO_REPLY_EXPECTED_VALUE
):
fut.add_done_callback(self._pending_futures.discard)
return
# We only create the closure function if we are actually going to reply
def _done(fut: asyncio.Future) -> None:
"""The callback for when the method is done."""
with send_reply: with send_reply:
result = fut.result() result = fut.result()
body, unix_fds = ServiceInterface._fn_result_to_body( body, unix_fds = fn_result_to_body(
result, method.out_signature_tree result, method.out_signature_tree, replace_fds=negotiate_unix_fd
) )
send_reply( send_reply(
Message.new_method_return( Message.new_method_return(
@@ -444,11 +466,11 @@ class MessageBus(BaseMessageBus):
) )
) )
args = ServiceInterface._msg_body_to_args(msg) fut.add_done_callback(_done)
fut = asyncio.ensure_future(method.fn(interface, *args)) # Discard the future only after running the done callback
fut.add_done_callback(done) fut.add_done_callback(self._pending_futures.discard)
return _coro_method_handler return _coroutine_method_handler
async def _auth_readline(self) -> str: async def _auth_readline(self) -> str:
buf = b"" buf = b""

View File

@@ -770,24 +770,27 @@ class BaseMessageBus:
if not msg.serial: if not msg.serial:
msg.serial = self.next_serial() msg.serial = self.next_serial()
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err) # type: ignore[misc]
no_reply_expected = not _expects_reply(msg) no_reply_expected = not _expects_reply(msg)
# Make sure the return reply handler is installed # Make sure the return reply handler is installed
# before sending the message to avoid a race condition # before sending the message to avoid a race condition
# where the reply is lost in case the backend can # where the reply is lost in case the backend can
# send it right away. # send it right away.
if not no_reply_expected: if not no_reply_expected:
self._method_return_handlers[msg.serial] = reply_notify
def _reply_notify(
reply: Optional[Message], err: Optional[Exception]
) -> None:
"""Callback on reply."""
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err)
self._method_return_handlers[msg.serial] = _reply_notify
self.send(msg) self.send(msg)
if no_reply_expected: if no_reply_expected:
callback(None, None) # type: ignore[misc] callback(None, None)
@staticmethod @staticmethod
def _check_callback_type(callback: Callable) -> None: def _check_callback_type(callback: Callable) -> None:
@@ -921,7 +924,9 @@ class BaseMessageBus:
def _callback_method_handler( def _callback_method_handler(
msg: Message, send_reply: Callable[[Message], None] msg: Message, send_reply: Callable[[Message], None]
) -> None: ) -> None:
result = method_fn(interface, *msg_body_to_args(msg)) """This is the callback that will be called when a method call is."""
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
result = method_fn(interface, *args)
if send_reply is BLOCK_UNEXPECTED_REPLY or not _expects_reply(msg): if send_reply is BLOCK_UNEXPECTED_REPLY or not _expects_reply(msg):
return return
body, fds = fn_result_to_body( body, fds = fn_result_to_body(

View File

@@ -490,9 +490,7 @@ class ServiceInterface:
@staticmethod @staticmethod
def _msg_body_to_args(msg: Message) -> List[Any]: def _msg_body_to_args(msg: Message) -> List[Any]:
if not msg.unix_fds or not signature_contains_type( if not signature_contains_type(msg.signature_tree, msg.body, "h"):
msg.signature_tree, msg.body, "h"
):
return msg.body return msg.body
# XXX: This deep copy could be expensive if messages are very # XXX: This deep copy could be expensive if messages are very