feat: speed up sending messages with call on the MessageBus (#271)

This commit is contained in:
J. Nick Koston
2023-12-03 15:58:39 -10:00
committed by GitHub
parent e3bc922eed
commit 6d7f522e1c
4 changed files with 24 additions and 20 deletions

View File

@@ -394,7 +394,7 @@ class MessageBus(BaseMessageBus):
else: else:
_future_set_result(future, reply) _future_set_result(future, reply)
self._call(msg, reply_handler, check_callback=False) self._call(msg, reply_handler)
await future await future

View File

@@ -288,6 +288,7 @@ class MessageBus(BaseMessageBus):
this message. May return an :class:`Exception` on connection errors. this message. May return an :class:`Exception` on connection errors.
:type reply_notify: Callable :type reply_notify: Callable
""" """
BaseMessageBus._check_callback_type(reply_notify)
self._call(msg, reply_notify) self._call(msg, reply_notify)
def call_sync(self, msg: Message) -> Optional[Message]: def call_sync(self, msg: Message) -> Optional[Message]:

View File

@@ -52,3 +52,8 @@ cdef class BaseMessageBus:
cdef _find_message_handler(self, Message msg) cdef _find_message_handler(self, Message msg)
cdef _setup_socket(self) cdef _setup_socket(self)
@cython.locals(no_reply_expected=bint)
cpdef _call(self, Message msg, object callback)
cpdef next_serial(self)

View File

@@ -3,6 +3,7 @@ import logging
import socket import socket
import traceback import traceback
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
from . import introspection as intr from . import introspection as intr
@@ -713,34 +714,35 @@ class BaseMessageBus:
if err: if err:
raise err raise err
def _reply_notify(
self,
msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
reply: Optional[Message],
err: Optional[Exception],
) -> None:
"""Callback on reply."""
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err)
def _call( def _call(
self, self,
msg: Message, msg: Message,
callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]], callback: Optional[Callable[[Optional[Message], Optional[Exception]], None]],
check_callback: bool = True,
) -> None: ) -> None:
if check_callback:
BaseMessageBus._check_callback_type(callback)
if not msg.serial: if not msg.serial:
msg.serial = self.next_serial() msg.serial = self.next_serial()
no_reply_expected = _expects_reply(msg) is False no_reply_expected = not _expects_reply(msg)
# Make sure the return reply handler is installed # Make sure the return reply handler is installed
# before sending the message to avoid a race condition # before sending the message to avoid a race condition
# where the reply is lost in case the backend can # where the reply is lost in case the backend can
# send it right away. # send it right away.
if not no_reply_expected: if not no_reply_expected:
self._method_return_handlers[msg.serial] = partial(
def _reply_notify( self._reply_notify, msg, callback
reply: Optional[Message], err: Optional[Exception] )
) -> None:
"""Callback on reply."""
if reply and msg.destination and reply.sender:
self._name_owners[msg.destination] = reply.sender
callback(reply, err)
self._method_return_handlers[msg.serial] = _reply_notify
self.send(msg) self.send(msg)
@@ -986,7 +988,6 @@ class BaseMessageBus:
member="GetMachineId", member="GetMachineId",
), ),
reply_handler, reply_handler,
check_callback=False,
) )
def _default_get_managed_objects_handler( def _default_get_managed_objects_handler(
@@ -1213,7 +1214,6 @@ class BaseMessageBus:
body=[self._name_owner_match_rule], body=[self._name_owner_match_rule],
), ),
add_match_notify, add_match_notify,
check_callback=False,
) )
def _add_match_rule(self, match_rule): def _add_match_rule(self, match_rule):
@@ -1247,7 +1247,6 @@ class BaseMessageBus:
body=[match_rule], body=[match_rule],
), ),
add_match_notify, add_match_notify,
check_callback=False,
) )
def _remove_match_rule(self, match_rule): def _remove_match_rule(self, match_rule):
@@ -1286,5 +1285,4 @@ class BaseMessageBus:
body=[match_rule], body=[match_rule],
), ),
remove_match_notify, remove_match_notify,
check_callback=False,
) )