feat: avoid replacing unix_fds if there are no unix_fds (#176)

This commit is contained in:
J. Nick Koston
2022-12-08 12:32:45 -10:00
committed by GitHub
parent 7d1fedbc13
commit 06647d7e49
5 changed files with 64 additions and 24 deletions

View File

@@ -1,6 +1,6 @@
import ast import ast
import inspect import inspect
from typing import Any, List, Union from typing import Any, List, Tuple, Union
from ..signature import SignatureTree, Variant, get_signature_tree from ..signature import SignatureTree, Variant, get_signature_tree
@@ -50,7 +50,7 @@ def signature_contains_type(
def replace_fds_with_idx( def replace_fds_with_idx(
signature: Union[str, SignatureTree], body: List[Any] signature: Union[str, SignatureTree], body: List[Any]
) -> (List[Any], List[int]): ) -> Tuple[List[Any], List[int]]:
"""Take the high level body format and convert it into the low level body """Take the high level body format and convert it into the low level body
format. Type 'h' refers directly to the fd in the body. Replace that with format. Type 'h' refers directly to the fd in the body. Replace that with
an index and return the corresponding list of unix fds that can be set on an index and return the corresponding list of unix fds that can be set on

View File

@@ -160,15 +160,16 @@ class MessageBus(BaseMessageBus):
:vartype connected: bool :vartype connected: bool
""" """
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future")
def __init__( def __init__(
self, self,
bus_address: str = None, bus_address: str = None,
bus_type: BusType = BusType.SESSION, bus_type: BusType = BusType.SESSION,
auth: Authenticator = None, auth: Authenticator = None,
negotiate_unix_fd=False, negotiate_unix_fd: bool = False,
): ) -> None:
super().__init__(bus_address, bus_type, ProxyObject) super().__init__(bus_address, bus_type, ProxyObject, negotiate_unix_fd)
self._negotiate_unix_fd = negotiate_unix_fd
self._loop = asyncio.get_running_loop() self._loop = asyncio.get_running_loop()
self._writer = _MessageWriter(self) self._writer = _MessageWriter(self)

View File

@@ -6,6 +6,7 @@ from .message cimport Message
cdef object MessageType cdef object MessageType
cdef object DBusError cdef object DBusError
cdef object MessageFlag cdef object MessageFlag
cdef object ServiceInterface
cdef object MESSAGE_TYPE_CALL cdef object MESSAGE_TYPE_CALL
cdef object MESSAGE_TYPE_SIGNAL cdef object MESSAGE_TYPE_SIGNAL
@@ -17,6 +18,7 @@ cdef class BaseMessageBus:
cdef public object _user_disconnect cdef public object _user_disconnect
cdef public object _method_return_handlers cdef public object _method_return_handlers
cdef public object _serial cdef public object _serial
cdef public object _path_exports
cdef public cython.list _user_message_handlers cdef public cython.list _user_message_handlers
cdef public object _name_owners cdef public object _name_owners
cdef public object _bus_address cdef public object _bus_address
@@ -25,5 +27,11 @@ cdef class BaseMessageBus:
cdef public object _high_level_client_initialized cdef public object _high_level_client_initialized
cdef public object _ProxyObject cdef public object _ProxyObject
cdef public object _machine_id cdef public object _machine_id
cdef public object _negotiate_unix_fd
cdef public object _sock
cdef public object _stream
cdef public object _fd
cpdef _process_message(self, Message msg) cpdef _process_message(self, Message msg)
cdef _find_message_handler(self, Message msg)

View File

@@ -67,12 +67,17 @@ class BaseMessageBus:
"_serial", "_serial",
"_user_message_handlers", "_user_message_handlers",
"_name_owners", "_name_owners",
"_path_exports",
"_bus_address", "_bus_address",
"_name_owner_match_rule", "_name_owner_match_rule",
"_match_rules", "_match_rules",
"_high_level_client_initialized", "_high_level_client_initialized",
"_ProxyObject", "_ProxyObject",
"_machine_id", "_machine_id",
"_negotiate_unix_fd",
"_sock",
"_stream",
"_fd",
) )
def __init__( def __init__(
@@ -80,9 +85,11 @@ class BaseMessageBus:
bus_address: Optional[str] = None, bus_address: Optional[str] = None,
bus_type: BusType = BusType.SESSION, bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[Type[BaseProxyObject]] = None, ProxyObject: Optional[Type[BaseProxyObject]] = None,
negotiate_unix_fd: bool = False,
) -> None: ) -> None:
self.unique_name: Optional[str] = None self.unique_name: Optional[str] = None
self._disconnected = False self._disconnected = False
self._negotiate_unix_fd = negotiate_unix_fd
# True if the user disconnected himself, so don't throw errors out of # True if the user disconnected himself, so don't throw errors out of
# the main loop. # the main loop.
@@ -870,14 +877,16 @@ class BaseMessageBus:
args = ServiceInterface._msg_body_to_args(msg) args = ServiceInterface._msg_body_to_args(msg)
result = method.fn(interface, *args) result = method.fn(interface, *args)
body, fds = ServiceInterface._fn_result_to_body( body, fds = ServiceInterface._fn_result_to_body(
result, signature_tree=method.out_signature_tree result,
signature_tree=method.out_signature_tree,
replace_fds=self._negotiate_unix_fd,
) )
send_reply(Message.new_method_return(msg, method.out_signature, body, fds)) send_reply(Message.new_method_return(msg, method.out_signature, body, fds))
return handler return handler
def _find_message_handler( def _find_message_handler(
self, msg: Message self, msg
) -> Optional[Callable[[Message, Callable], None]]: ) -> Optional[Callable[[Message, Callable], None]]:
handler: Optional[Callable[[Message, Callable], None]] = None handler: Optional[Callable[[Message, Callable], None]] = None

View File

@@ -8,7 +8,9 @@ from typing import (
Callable, Callable,
Dict, Dict,
List, List,
Optional,
Set, Set,
Tuple,
no_type_check_decorator, no_type_check_decorator,
) )
@@ -22,7 +24,12 @@ from ._private.util import (
from .constants import PropertyAccess from .constants import PropertyAccess
from .errors import SignalDisabledError from .errors import SignalDisabledError
from .message import Message from .message import Message
from .signature import SignatureBodyMismatchError, Variant, get_signature_tree from .signature import (
SignatureBodyMismatchError,
SignatureTree,
Variant,
get_signature_tree,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .message_bus import BaseMessageBus from .message_bus import BaseMessageBus
@@ -482,19 +489,23 @@ class ServiceInterface:
del interface.__handlers[bus] del interface.__handlers[bus]
@staticmethod @staticmethod
def _msg_body_to_args(msg): def _msg_body_to_args(msg: Message) -> List[Any]:
if signature_contains_type(msg.signature_tree, msg.body, "h"): if not msg.unix_fds or not signature_contains_type(
# XXX: This deep copy could be expensive if messages are very msg.signature_tree, msg.body, "h"
# large. We could optimize this by only copying what we change ):
# here.
return replace_idx_with_fds(
msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds
)
else:
return msg.body return msg.body
# XXX: This deep copy could be expensive if messages are very
# large. We could optimize this by only copying what we change
# here.
return replace_idx_with_fds(
msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds
)
@staticmethod @staticmethod
def _fn_result_to_body(result, signature_tree): def _fn_result_to_body(
result: List[Any], signature_tree: SignatureTree, replace_fds: bool = True
) -> Tuple[List[Any], List[int]]:
"""The high level interfaces may return single values which may be """The high level interfaces may return single values which may be
wrapped in a list to be a message body. Also they may return fds wrapped in a list to be a message body. Also they may return fds
directly for type 'h' which need to be put into an external list.""" directly for type 'h' which need to be put into an external list."""
@@ -515,10 +526,14 @@ class ServiceInterface:
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}" f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}"
) )
if not replace_fds:
return result, []
return replace_fds_with_idx(signature_tree, result) return replace_fds_with_idx(signature_tree, result)
@staticmethod @staticmethod
def _handle_signal(interface, signal, result): def _handle_signal(
interface: "ServiceInterface", signal: _Signal, result: List[Any]
) -> None:
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree) body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
for bus in ServiceInterface._get_buses(interface): for bus in ServiceInterface._get_buses(interface):
bus._interface_signal_notify( bus._interface_signal_notify(
@@ -526,7 +541,7 @@ class ServiceInterface:
) )
@staticmethod @staticmethod
def _get_property_value(interface, prop, callback): def _get_property_value(interface: "ServiceInterface", prop: _Property, callback):
# XXX MUST CHECK TYPE RETURNED BY GETTER # XXX MUST CHECK TYPE RETURNED BY GETTER
try: try:
if asyncio.iscoroutinefunction(prop.prop_getter): if asyncio.iscoroutinefunction(prop.prop_getter):
@@ -551,7 +566,7 @@ class ServiceInterface:
callback(interface, prop, None, e) callback(interface, prop, None, e)
@staticmethod @staticmethod
def _set_property_value(interface, prop, value, callback): def _set_property_value(interface: "ServiceInterface", prop, value, callback):
# XXX MUST CHECK TYPE TO SET # XXX MUST CHECK TYPE TO SET
try: try:
if asyncio.iscoroutinefunction(prop.prop_setter): if asyncio.iscoroutinefunction(prop.prop_setter):
@@ -575,7 +590,9 @@ class ServiceInterface:
callback(interface, prop, e) callback(interface, prop, e)
@staticmethod @staticmethod
def _get_all_property_values(interface, callback, user_data=None): def _get_all_property_values(
interface: "ServiceInterface", callback, user_data=None
):
result = {} result = {}
result_error = None result_error = None
@@ -588,7 +605,12 @@ class ServiceInterface:
callback(interface, result, user_data, None) callback(interface, result, user_data, None)
return return
def get_property_callback(interface, prop, value, e): def get_property_callback(
interface: "ServiceInterface",
prop: _Property,
value: Any,
e: Optional[Exception],
) -> None:
nonlocal result_error nonlocal result_error
if e is not None: if e is not None:
result_error = e result_error = e