feat: speed up Message creation and callbacks (#217)

This commit is contained in:
J. Nick Koston
2023-08-02 08:11:54 -10:00
committed by GitHub
parent 3e6560ded1
commit 04d6451157
6 changed files with 45 additions and 35 deletions

View File

@@ -158,6 +158,7 @@ try:
except ImportError: except ImportError:
from ._cython_compat import FAKE_CYTHON as cython from ._cython_compat import FAKE_CYTHON as cython
# #
# Alignment padding is handled with the following formula below # Alignment padding is handled with the following formula below
# #

View File

@@ -413,7 +413,7 @@ class MessageBus(BaseMessageBus):
if not asyncio.iscoroutinefunction(method.fn): if not asyncio.iscoroutinefunction(method.fn):
return super()._make_method_handler(interface, method) return super()._make_method_handler(interface, method)
def handler(msg, send_reply): def _coro_method_handler(msg, send_reply):
def done(fut): def done(fut):
with send_reply: with send_reply:
result = fut.result() result = fut.result()
@@ -430,7 +430,7 @@ class MessageBus(BaseMessageBus):
fut = asyncio.ensure_future(method.fn(interface, *args)) fut = asyncio.ensure_future(method.fn(interface, *args))
fut.add_done_callback(done) fut.add_done_callback(done)
return handler return _coro_method_handler
async def _auth_readline(self) -> str: async def _auth_readline(self) -> str:
buf = b"" buf = b""

View File

@@ -13,10 +13,10 @@ from .validators import (
) )
REQUIRED_FIELDS = { REQUIRED_FIELDS = {
MessageType.METHOD_CALL: ("path", "member"), MessageType.METHOD_CALL.value: ("path", "member"),
MessageType.SIGNAL: ("path", "member", "interface"), MessageType.SIGNAL.value: ("path", "member", "interface"),
MessageType.ERROR: ("error_name", "reply_serial"), MessageType.ERROR.value: ("error_name", "reply_serial"),
MessageType.METHOD_RETURN: ("reply_serial",), MessageType.METHOD_RETURN.value: ("reply_serial",),
} }
HEADER_PATH = HeaderField.PATH.value HEADER_PATH = HeaderField.PATH.value
@@ -146,7 +146,7 @@ class Message:
if self.error_name is not None: if self.error_name is not None:
assert_interface_name_valid(self.error_name) # type: ignore[arg-type] assert_interface_name_valid(self.error_name) # type: ignore[arg-type]
required_fields = REQUIRED_FIELDS.get(self.message_type) required_fields = REQUIRED_FIELDS.get(self.message_type.value)
if not required_fields: if not required_fields:
raise InvalidMessageError(f"got unknown message type: {self.message_type}") raise InvalidMessageError(f"got unknown message type: {self.message_type}")
for field in required_fields: for field in required_fields:

View File

@@ -881,69 +881,75 @@ class BaseMessageBus:
def _make_method_handler( def _make_method_handler(
self, interface: ServiceInterface, method: _Method self, interface: ServiceInterface, method: _Method
) -> Callable[[Message, Callable[[Message], None]], None]: ) -> Callable[[Message, Callable[[Message], None]], None]:
def handler(msg: Message, send_reply: Callable[[Message], None]) -> None: method_fn = method.fn
args = ServiceInterface._msg_body_to_args(msg) out_signature_tree = method.out_signature_tree
result = method.fn(interface, *args) negotiate_unix_fd = self._negotiate_unix_fd
body, fds = ServiceInterface._fn_result_to_body( out_signature = method.out_signature
result, message_type_method_return = MessageType.METHOD_RETURN
signature_tree=method.out_signature_tree, msg_body_to_args = ServiceInterface._msg_body_to_args
replace_fds=self._negotiate_unix_fd, fn_result_to_body = ServiceInterface._fn_result_to_body
def _callback_method_handler(
msg: Message, send_reply: Callable[[Message], None]
) -> None:
body, fds = fn_result_to_body(
method_fn(interface, *msg_body_to_args(msg)),
signature_tree=out_signature_tree,
replace_fds=negotiate_unix_fd,
) )
send_reply( send_reply(
Message( Message(
message_type=MessageType.METHOD_RETURN, message_type=message_type_method_return,
reply_serial=msg.serial, reply_serial=msg.serial,
destination=msg.sender, destination=msg.sender,
signature=method.out_signature, signature=out_signature,
body=body, body=body,
unix_fds=fds, unix_fds=fds,
) )
) )
return handler return _callback_method_handler
def _find_message_handler( def _find_message_handler(
self, msg self, msg
) -> Optional[Callable[[Message, Callable], None]]: ) -> Optional[Callable[[Message, Callable], None]]:
handler: Optional[Callable[[Message, Callable], None]] = None
if ( if (
msg.interface == "org.freedesktop.DBus.Introspectable" msg.interface == "org.freedesktop.DBus.Introspectable"
and msg.member == "Introspect" and msg.member == "Introspect"
and msg.signature == "" and msg.signature == ""
): ):
handler = self._default_introspect_handler return self._default_introspect_handler
elif msg.interface == "org.freedesktop.DBus.Properties": if msg.interface == "org.freedesktop.DBus.Properties":
handler = self._default_properties_handler return self._default_properties_handler
elif msg.interface == "org.freedesktop.DBus.Peer": if msg.interface == "org.freedesktop.DBus.Peer":
if msg.member == "Ping" and msg.signature == "": if msg.member == "Ping" and msg.signature == "":
handler = self._default_ping_handler return self._default_ping_handler
elif msg.member == "GetMachineId" and msg.signature == "": elif msg.member == "GetMachineId" and msg.signature == "":
handler = self._default_get_machine_id_handler return self._default_get_machine_id_handler
elif (
if (
msg.interface == "org.freedesktop.DBus.ObjectManager" msg.interface == "org.freedesktop.DBus.ObjectManager"
and msg.member == "GetManagedObjects" and msg.member == "GetManagedObjects"
): ):
handler = self._default_get_managed_objects_handler return self._default_get_managed_objects_handler
elif msg.path: msg_path = msg.path
for interface in self._path_exports.get(msg.path, []): if msg_path:
for interface in self._path_exports.get(msg_path, []):
for method in ServiceInterface._get_methods(interface): for method in ServiceInterface._get_methods(interface):
if method.disabled: if method.disabled:
continue continue
if ( if (
msg.interface == interface.name msg.interface == interface.name
and msg.member == method.name and msg.member == method.name
and msg.signature == method.in_signature and msg.signature == method.in_signature
): ):
handler = ServiceInterface._get_handler(interface, method, self) return ServiceInterface._get_handler(interface, method, self)
break
if handler:
break
return handler return None
def _default_introspect_handler( def _default_introspect_handler(
self, msg: Message, send_reply: Callable[[Message], None] self, msg: Message, send_reply: Callable[[Message], None]

View File

@@ -58,7 +58,6 @@ class BaseProxyInterface:
introspection: intr.Interface, introspection: intr.Interface,
bus: "message_bus.BaseMessageBus", bus: "message_bus.BaseMessageBus",
) -> None: ) -> None:
self.bus_name = bus_name self.bus_name = bus_name
self.path = path self.path = path
self.introspection = introspection self.introspection = introspection

View File

@@ -135,6 +135,7 @@ def is_member_name_valid(member: str) -> bool:
return True return True
@lru_cache(maxsize=32)
def assert_bus_name_valid(name: str) -> None: def assert_bus_name_valid(name: str) -> None:
"""Raise an error if this is not a valid bus name. """Raise an error if this is not a valid bus name.
@@ -150,6 +151,7 @@ def assert_bus_name_valid(name: str) -> None:
raise InvalidBusNameError(name) raise InvalidBusNameError(name)
@lru_cache(maxsize=1024)
def assert_object_path_valid(path: str) -> None: def assert_object_path_valid(path: str) -> None:
"""Raise an error if this is not a valid object path. """Raise an error if this is not a valid object path.
@@ -165,6 +167,7 @@ def assert_object_path_valid(path: str) -> None:
raise InvalidObjectPathError(path) raise InvalidObjectPathError(path)
@lru_cache(maxsize=32)
def assert_interface_name_valid(name: str) -> None: def assert_interface_name_valid(name: str) -> None:
"""Raise an error if this is not a valid interface name. """Raise an error if this is not a valid interface name.
@@ -180,6 +183,7 @@ def assert_interface_name_valid(name: str) -> None:
raise InvalidInterfaceNameError(name) raise InvalidInterfaceNameError(name)
@lru_cache(maxsize=512)
def assert_member_name_valid(member: str) -> None: def assert_member_name_valid(member: str) -> None:
"""Raise an error if this is not a valid member name. """Raise an error if this is not a valid member name.