feat: avoid replacing unix_fds if there are no unix_fds (#176)

This commit is contained in:
J. Nick Koston 2022-12-08 12:32:45 -10:00 committed by GitHub
parent 7d1fedbc13
commit 06647d7e49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 64 additions and 24 deletions

View File

@ -1,6 +1,6 @@
import ast
import inspect
from typing import Any, List, Union
from typing import Any, List, Tuple, Union
from ..signature import SignatureTree, Variant, get_signature_tree
@ -50,7 +50,7 @@ def signature_contains_type(
def replace_fds_with_idx(
signature: Union[str, SignatureTree], body: List[Any]
) -> (List[Any], List[int]):
) -> Tuple[List[Any], List[int]]:
"""Take the high level body format and convert it into the low level body
format. Type 'h' refers directly to the fd in the body. Replace that with
an index and return the corresponding list of unix fds that can be set on

View File

@ -160,15 +160,16 @@ class MessageBus(BaseMessageBus):
:vartype connected: bool
"""
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future")
def __init__(
self,
bus_address: str = None,
bus_type: BusType = BusType.SESSION,
auth: Authenticator = None,
negotiate_unix_fd=False,
):
super().__init__(bus_address, bus_type, ProxyObject)
self._negotiate_unix_fd = negotiate_unix_fd
negotiate_unix_fd: bool = False,
) -> None:
super().__init__(bus_address, bus_type, ProxyObject, negotiate_unix_fd)
self._loop = asyncio.get_running_loop()
self._writer = _MessageWriter(self)

View File

@ -6,6 +6,7 @@ from .message cimport Message
cdef object MessageType
cdef object DBusError
cdef object MessageFlag
cdef object ServiceInterface
cdef object MESSAGE_TYPE_CALL
cdef object MESSAGE_TYPE_SIGNAL
@ -17,6 +18,7 @@ cdef class BaseMessageBus:
cdef public object _user_disconnect
cdef public object _method_return_handlers
cdef public object _serial
cdef public object _path_exports
cdef public cython.list _user_message_handlers
cdef public object _name_owners
cdef public object _bus_address
@ -25,5 +27,11 @@ cdef class BaseMessageBus:
cdef public object _high_level_client_initialized
cdef public object _ProxyObject
cdef public object _machine_id
cdef public object _negotiate_unix_fd
cdef public object _sock
cdef public object _stream
cdef public object _fd
cpdef _process_message(self, Message msg)
cdef _find_message_handler(self, Message msg)

View File

@ -67,12 +67,17 @@ class BaseMessageBus:
"_serial",
"_user_message_handlers",
"_name_owners",
"_path_exports",
"_bus_address",
"_name_owner_match_rule",
"_match_rules",
"_high_level_client_initialized",
"_ProxyObject",
"_machine_id",
"_negotiate_unix_fd",
"_sock",
"_stream",
"_fd",
)
def __init__(
@ -80,9 +85,11 @@ class BaseMessageBus:
bus_address: Optional[str] = None,
bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[Type[BaseProxyObject]] = None,
negotiate_unix_fd: bool = False,
) -> None:
self.unique_name: Optional[str] = None
self._disconnected = False
self._negotiate_unix_fd = negotiate_unix_fd
# True if the user disconnected himself, so don't throw errors out of
# the main loop.
@ -870,14 +877,16 @@ class BaseMessageBus:
args = ServiceInterface._msg_body_to_args(msg)
result = method.fn(interface, *args)
body, fds = ServiceInterface._fn_result_to_body(
result, signature_tree=method.out_signature_tree
result,
signature_tree=method.out_signature_tree,
replace_fds=self._negotiate_unix_fd,
)
send_reply(Message.new_method_return(msg, method.out_signature, body, fds))
return handler
def _find_message_handler(
self, msg: Message
self, msg
) -> Optional[Callable[[Message, Callable], None]]:
handler: Optional[Callable[[Message, Callable], None]] = None

View File

@ -8,7 +8,9 @@ from typing import (
Callable,
Dict,
List,
Optional,
Set,
Tuple,
no_type_check_decorator,
)
@ -22,7 +24,12 @@ from ._private.util import (
from .constants import PropertyAccess
from .errors import SignalDisabledError
from .message import Message
from .signature import SignatureBodyMismatchError, Variant, get_signature_tree
from .signature import (
SignatureBodyMismatchError,
SignatureTree,
Variant,
get_signature_tree,
)
if TYPE_CHECKING:
from .message_bus import BaseMessageBus
@ -482,19 +489,23 @@ class ServiceInterface:
del interface.__handlers[bus]
@staticmethod
def _msg_body_to_args(msg):
if signature_contains_type(msg.signature_tree, msg.body, "h"):
# XXX: This deep copy could be expensive if messages are very
# large. We could optimize this by only copying what we change
# here.
return replace_idx_with_fds(
msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds
)
else:
def _msg_body_to_args(msg: Message) -> List[Any]:
if not msg.unix_fds or not signature_contains_type(
msg.signature_tree, msg.body, "h"
):
return msg.body
# XXX: This deep copy could be expensive if messages are very
# large. We could optimize this by only copying what we change
# here.
return replace_idx_with_fds(
msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds
)
@staticmethod
def _fn_result_to_body(result, signature_tree):
def _fn_result_to_body(
result: List[Any], signature_tree: SignatureTree, replace_fds: bool = True
) -> Tuple[List[Any], List[int]]:
"""The high level interfaces may return single values which may be
wrapped in a list to be a message body. Also they may return fds
directly for type 'h' which need to be put into an external list."""
@ -515,10 +526,14 @@ class ServiceInterface:
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}"
)
if not replace_fds:
return result, []
return replace_fds_with_idx(signature_tree, result)
@staticmethod
def _handle_signal(interface, signal, result):
def _handle_signal(
interface: "ServiceInterface", signal: _Signal, result: List[Any]
) -> None:
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
for bus in ServiceInterface._get_buses(interface):
bus._interface_signal_notify(
@ -526,7 +541,7 @@ class ServiceInterface:
)
@staticmethod
def _get_property_value(interface, prop, callback):
def _get_property_value(interface: "ServiceInterface", prop: _Property, callback):
# XXX MUST CHECK TYPE RETURNED BY GETTER
try:
if asyncio.iscoroutinefunction(prop.prop_getter):
@ -551,7 +566,7 @@ class ServiceInterface:
callback(interface, prop, None, e)
@staticmethod
def _set_property_value(interface, prop, value, callback):
def _set_property_value(interface: "ServiceInterface", prop, value, callback):
# XXX MUST CHECK TYPE TO SET
try:
if asyncio.iscoroutinefunction(prop.prop_setter):
@ -575,7 +590,9 @@ class ServiceInterface:
callback(interface, prop, e)
@staticmethod
def _get_all_property_values(interface, callback, user_data=None):
def _get_all_property_values(
interface: "ServiceInterface", callback, user_data=None
):
result = {}
result_error = None
@ -588,7 +605,12 @@ class ServiceInterface:
callback(interface, result, user_data, None)
return
def get_property_callback(interface, prop, value, e):
def get_property_callback(
interface: "ServiceInterface",
prop: _Property,
value: Any,
e: Optional[Exception],
) -> None:
nonlocal result_error
if e is not None:
result_error = e