feat: reduce overhead to dispatch method handlers (#227)
This commit is contained in:
parent
8f4f9451b7
commit
b2225527ae
@ -4,7 +4,7 @@ import logging
|
||||
import socket
|
||||
from collections import deque
|
||||
from copy import copy
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Set, Tuple
|
||||
|
||||
from .. import introspection as intr
|
||||
from ..auth import Authenticator, AuthExternal
|
||||
@ -18,7 +18,7 @@ from ..constants import (
|
||||
)
|
||||
from ..errors import AuthError
|
||||
from ..message import Message
|
||||
from ..message_bus import BaseMessageBus
|
||||
from ..message_bus import BaseMessageBus, _block_unexpected_reply
|
||||
from ..service import ServiceInterface
|
||||
from .message_reader import build_message_reader
|
||||
from .proxy_object import ProxyObject
|
||||
@ -173,7 +173,7 @@ class MessageBus(BaseMessageBus):
|
||||
:vartype connected: bool
|
||||
"""
|
||||
|
||||
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future")
|
||||
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future", "_pending_futures")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -193,6 +193,7 @@ class MessageBus(BaseMessageBus):
|
||||
self._auth = auth
|
||||
|
||||
self._disconnect_future = self._loop.create_future()
|
||||
self._pending_futures: Set[asyncio.Future] = set()
|
||||
|
||||
async def connect(self) -> "MessageBus":
|
||||
"""Connect this message bus to the DBus daemon.
|
||||
@ -431,12 +432,33 @@ class MessageBus(BaseMessageBus):
|
||||
if not asyncio.iscoroutinefunction(method.fn):
|
||||
return super()._make_method_handler(interface, method)
|
||||
|
||||
def _coro_method_handler(msg, send_reply):
|
||||
def done(fut):
|
||||
negotiate_unix_fd = self._negotiate_unix_fd
|
||||
msg_body_to_args = ServiceInterface._msg_body_to_args
|
||||
fn_result_to_body = ServiceInterface._fn_result_to_body
|
||||
|
||||
def _coroutine_method_handler(
|
||||
msg: Message, send_reply: Callable[[Message], None]
|
||||
) -> None:
|
||||
"""A coroutine method handler."""
|
||||
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
|
||||
fut = asyncio.ensure_future(method.fn(interface, *args))
|
||||
# Hold a strong reference to the future to ensure
|
||||
# it is not garbage collected before it is done.
|
||||
self._pending_futures.add(fut)
|
||||
if (
|
||||
send_reply is _block_unexpected_reply
|
||||
or msg.flags.value & NO_REPLY_EXPECTED_VALUE
|
||||
):
|
||||
fut.add_done_callback(self._pending_futures.discard)
|
||||
return
|
||||
|
||||
# We only create the closure function if we are actually going to reply
|
||||
def _done(fut: asyncio.Future) -> None:
|
||||
"""The callback for when the method is done."""
|
||||
with send_reply:
|
||||
result = fut.result()
|
||||
body, unix_fds = ServiceInterface._fn_result_to_body(
|
||||
result, method.out_signature_tree
|
||||
body, unix_fds = fn_result_to_body(
|
||||
result, method.out_signature_tree, replace_fds=negotiate_unix_fd
|
||||
)
|
||||
send_reply(
|
||||
Message.new_method_return(
|
||||
@ -444,11 +466,11 @@ class MessageBus(BaseMessageBus):
|
||||
)
|
||||
)
|
||||
|
||||
args = ServiceInterface._msg_body_to_args(msg)
|
||||
fut = asyncio.ensure_future(method.fn(interface, *args))
|
||||
fut.add_done_callback(done)
|
||||
fut.add_done_callback(_done)
|
||||
# Discard the future only after running the done callback
|
||||
fut.add_done_callback(self._pending_futures.discard)
|
||||
|
||||
return _coro_method_handler
|
||||
return _coroutine_method_handler
|
||||
|
||||
async def _auth_readline(self) -> str:
|
||||
buf = b""
|
||||
|
||||
@ -770,24 +770,27 @@ class BaseMessageBus:
|
||||
if not msg.serial:
|
||||
msg.serial = self.next_serial()
|
||||
|
||||
def reply_notify(reply: Optional[Message], err: Optional[Exception]) -> None:
|
||||
if reply and msg.destination and reply.sender:
|
||||
self._name_owners[msg.destination] = reply.sender
|
||||
callback(reply, err) # type: ignore[misc]
|
||||
|
||||
no_reply_expected = not _expects_reply(msg)
|
||||
|
||||
# Make sure the return reply handler is installed
|
||||
# before sending the message to avoid a race condition
|
||||
# where the reply is lost in case the backend can
|
||||
# send it right away.
|
||||
if not no_reply_expected:
|
||||
self._method_return_handlers[msg.serial] = reply_notify
|
||||
|
||||
def _reply_notify(
|
||||
reply: Optional[Message], err: Optional[Exception]
|
||||
) -> None:
|
||||
"""Callback on reply."""
|
||||
if reply and msg.destination and reply.sender:
|
||||
self._name_owners[msg.destination] = reply.sender
|
||||
callback(reply, err)
|
||||
|
||||
self._method_return_handlers[msg.serial] = _reply_notify
|
||||
|
||||
self.send(msg)
|
||||
|
||||
if no_reply_expected:
|
||||
callback(None, None) # type: ignore[misc]
|
||||
callback(None, None)
|
||||
|
||||
@staticmethod
|
||||
def _check_callback_type(callback: Callable) -> None:
|
||||
@ -921,7 +924,9 @@ class BaseMessageBus:
|
||||
def _callback_method_handler(
|
||||
msg: Message, send_reply: Callable[[Message], None]
|
||||
) -> None:
|
||||
result = method_fn(interface, *msg_body_to_args(msg))
|
||||
"""This is the callback that will be called when a method call is."""
|
||||
args = msg_body_to_args(msg) if msg.unix_fds else msg.body
|
||||
result = method_fn(interface, *args)
|
||||
if send_reply is BLOCK_UNEXPECTED_REPLY or not _expects_reply(msg):
|
||||
return
|
||||
body, fds = fn_result_to_body(
|
||||
|
||||
@ -490,9 +490,7 @@ class ServiceInterface:
|
||||
|
||||
@staticmethod
|
||||
def _msg_body_to_args(msg: Message) -> List[Any]:
|
||||
if not msg.unix_fds or not signature_contains_type(
|
||||
msg.signature_tree, msg.body, "h"
|
||||
):
|
||||
if not signature_contains_type(msg.signature_tree, msg.body, "h"):
|
||||
return msg.body
|
||||
|
||||
# XXX: This deep copy could be expensive if messages are very
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user