feat: reduce overhead to dispatch method handlers (#227)

This commit is contained in:
J. Nick Koston 2023-08-18 14:33:23 -05:00 committed by GitHub
parent 8f4f9451b7
commit b2225527ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 23 deletions

View File

@ -4,7 +4,7 @@ import logging
import socket
from collections import deque
from copy import copy
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Set, Tuple
from .. import introspection as intr
from ..auth import Authenticator, AuthExternal
@ -18,7 +18,7 @@ from ..constants import (
)
from ..errors import AuthError
from ..message import Message
from ..message_bus import BaseMessageBus
from ..message_bus import BaseMessageBus, _block_unexpected_reply
from ..service import ServiceInterface
from .message_reader import build_message_reader
from .proxy_object import ProxyObject
@ -173,7 +173,7 @@ class MessageBus(BaseMessageBus):
:vartype connected: bool
"""
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future")
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future", "_pending_futures")
def __init__(
self,
@ -193,6 +193,7 @@ class MessageBus(BaseMessageBus):
self._auth = auth
self._disconnect_future = self._loop.create_future()
self._pending_futures: Set[asyncio.Future] = set()
async def connect(self) -> "MessageBus":
"""Connect this message bus to the DBus daemon.
@ -431,12 +432,33 @@ class MessageBus(BaseMessageBus):
if not asyncio.iscoroutinefunction(method.fn):
return super()._make_method_handler(interface, method)
def _coro_method_handler(msg, send_reply):
def done(fut):
negotiate_unix_fd = self._negotiate_unix_fd
msg_body_to_args = ServiceInterface._msg_body_to_args
fn_result_to_body = ServiceInterface._fn_result_to_body
def _coroutine_method_handler(
msg: Message, send_reply: Callable[[Message], None]
) -> None:
"""A coroutine method handler."""
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
fut = asyncio.ensure_future(method.fn(interface, *args))
# Hold a strong reference to the future to ensure
# it is not garbage collected before it is done.
self._pending_futures.add(fut)
if (
send_reply is _block_unexpected_reply
or msg.flags.value & NO_REPLY_EXPECTED_VALUE
):
fut.add_done_callback(self._pending_futures.discard)
return
# We only create the closure function if we are actually going to reply
def _done(fut: asyncio.Future) -> None:
"""The callback for when the method is done."""
with send_reply:
result = fut.result()
body, unix_fds = ServiceInterface._fn_result_to_body(
result, method.out_signature_tree
body, unix_fds = fn_result_to_body(
result, method.out_signature_tree, replace_fds=negotiate_unix_fd
)
send_reply(
Message.new_method_return(
@ -444,11 +466,11 @@ class MessageBus(BaseMessageBus):
)
)
args = ServiceInterface._msg_body_to_args(msg)
fut = asyncio.ensure_future(method.fn(interface, *args))
fut.add_done_callback(done)
fut.add_done_callback(_done)
# Discard the future only after running the done callback
fut.add_done_callback(self._pending_futures.discard)
return _coro_method_handler
return _coroutine_method_handler
async def _auth_readline(self) -> str:
buf = b""

View File

@ -770,24 +770,27 @@ class BaseMessageBus:
if not msg.serial:
msg.serial = self.next_serial()
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err) # type: ignore[misc]
no_reply_expected = not _expects_reply(msg)
# Make sure the return reply handler is installed
# before sending the message to avoid a race condition
# where the reply is lost in case the backend can
# send it right away.
if not no_reply_expected:
self._method_return_handlers[msg.serial] = reply_notify
def _reply_notify(
reply: Optional[Message], err: Optional[Exception]
) -> None:
"""Callback on reply."""
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err)
self._method_return_handlers[msg.serial] = _reply_notify
self.send(msg)
if no_reply_expected:
callback(None, None) # type: ignore[misc]
callback(None, None)
@staticmethod
def _check_callback_type(callback: Callable) -> None:
@ -921,7 +924,9 @@ class BaseMessageBus:
def _callback_method_handler(
msg: Message, send_reply: Callable[[Message], None]
) -> None:
result = method_fn(interface, *msg_body_to_args(msg))
"""This is the callback that will be called when a method call is."""
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
result = method_fn(interface, *args)
if send_reply is BLOCK_UNEXPECTED_REPLY or not _expects_reply(msg):
return
body, fds = fn_result_to_body(

View File

@ -490,9 +490,7 @@ class ServiceInterface:
@staticmethod
def _msg_body_to_args(msg: Message) -> List[Any]:
if not msg.unix_fds or not signature_contains_type(
msg.signature_tree, msg.body, "h"
):
if not signature_contains_type(msg.signature_tree, msg.body, "h"):
return msg.body
# XXX: This deep copy could be expensive if messages are very