feat: avoid replacing unix_fds if there are no unix_fds (#176)
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, List, Union
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
from ..signature import SignatureTree, Variant, get_signature_tree
|
from ..signature import SignatureTree, Variant, get_signature_tree
|
||||||
|
|
||||||
@@ -50,7 +50,7 @@ def signature_contains_type(
|
|||||||
|
|
||||||
def replace_fds_with_idx(
|
def replace_fds_with_idx(
|
||||||
signature: Union[str, SignatureTree], body: List[Any]
|
signature: Union[str, SignatureTree], body: List[Any]
|
||||||
) -> (List[Any], List[int]):
|
) -> Tuple[List[Any], List[int]]:
|
||||||
"""Take the high level body format and convert it into the low level body
|
"""Take the high level body format and convert it into the low level body
|
||||||
format. Type 'h' refers directly to the fd in the body. Replace that with
|
format. Type 'h' refers directly to the fd in the body. Replace that with
|
||||||
an index and return the corresponding list of unix fds that can be set on
|
an index and return the corresponding list of unix fds that can be set on
|
||||||
|
|||||||
@@ -160,15 +160,16 @@ class MessageBus(BaseMessageBus):
|
|||||||
:vartype connected: bool
|
:vartype connected: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("_loop", "_auth", "_writer", "_disconnect_future")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
bus_address: str = None,
|
bus_address: str = None,
|
||||||
bus_type: BusType = BusType.SESSION,
|
bus_type: BusType = BusType.SESSION,
|
||||||
auth: Authenticator = None,
|
auth: Authenticator = None,
|
||||||
negotiate_unix_fd=False,
|
negotiate_unix_fd: bool = False,
|
||||||
):
|
) -> None:
|
||||||
super().__init__(bus_address, bus_type, ProxyObject)
|
super().__init__(bus_address, bus_type, ProxyObject, negotiate_unix_fd)
|
||||||
self._negotiate_unix_fd = negotiate_unix_fd
|
|
||||||
self._loop = asyncio.get_running_loop()
|
self._loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
self._writer = _MessageWriter(self)
|
self._writer = _MessageWriter(self)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .message cimport Message
|
|||||||
cdef object MessageType
|
cdef object MessageType
|
||||||
cdef object DBusError
|
cdef object DBusError
|
||||||
cdef object MessageFlag
|
cdef object MessageFlag
|
||||||
|
cdef object ServiceInterface
|
||||||
|
|
||||||
cdef object MESSAGE_TYPE_CALL
|
cdef object MESSAGE_TYPE_CALL
|
||||||
cdef object MESSAGE_TYPE_SIGNAL
|
cdef object MESSAGE_TYPE_SIGNAL
|
||||||
@@ -17,6 +18,7 @@ cdef class BaseMessageBus:
|
|||||||
cdef public object _user_disconnect
|
cdef public object _user_disconnect
|
||||||
cdef public object _method_return_handlers
|
cdef public object _method_return_handlers
|
||||||
cdef public object _serial
|
cdef public object _serial
|
||||||
|
cdef public object _path_exports
|
||||||
cdef public cython.list _user_message_handlers
|
cdef public cython.list _user_message_handlers
|
||||||
cdef public object _name_owners
|
cdef public object _name_owners
|
||||||
cdef public object _bus_address
|
cdef public object _bus_address
|
||||||
@@ -25,5 +27,11 @@ cdef class BaseMessageBus:
|
|||||||
cdef public object _high_level_client_initialized
|
cdef public object _high_level_client_initialized
|
||||||
cdef public object _ProxyObject
|
cdef public object _ProxyObject
|
||||||
cdef public object _machine_id
|
cdef public object _machine_id
|
||||||
|
cdef public object _negotiate_unix_fd
|
||||||
|
cdef public object _sock
|
||||||
|
cdef public object _stream
|
||||||
|
cdef public object _fd
|
||||||
|
|
||||||
cpdef _process_message(self, Message msg)
|
cpdef _process_message(self, Message msg)
|
||||||
|
|
||||||
|
cdef _find_message_handler(self, Message msg)
|
||||||
|
|||||||
@@ -67,12 +67,17 @@ class BaseMessageBus:
|
|||||||
"_serial",
|
"_serial",
|
||||||
"_user_message_handlers",
|
"_user_message_handlers",
|
||||||
"_name_owners",
|
"_name_owners",
|
||||||
|
"_path_exports",
|
||||||
"_bus_address",
|
"_bus_address",
|
||||||
"_name_owner_match_rule",
|
"_name_owner_match_rule",
|
||||||
"_match_rules",
|
"_match_rules",
|
||||||
"_high_level_client_initialized",
|
"_high_level_client_initialized",
|
||||||
"_ProxyObject",
|
"_ProxyObject",
|
||||||
"_machine_id",
|
"_machine_id",
|
||||||
|
"_negotiate_unix_fd",
|
||||||
|
"_sock",
|
||||||
|
"_stream",
|
||||||
|
"_fd",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -80,9 +85,11 @@ class BaseMessageBus:
|
|||||||
bus_address: Optional[str] = None,
|
bus_address: Optional[str] = None,
|
||||||
bus_type: BusType = BusType.SESSION,
|
bus_type: BusType = BusType.SESSION,
|
||||||
ProxyObject: Optional[Type[BaseProxyObject]] = None,
|
ProxyObject: Optional[Type[BaseProxyObject]] = None,
|
||||||
|
negotiate_unix_fd: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.unique_name: Optional[str] = None
|
self.unique_name: Optional[str] = None
|
||||||
self._disconnected = False
|
self._disconnected = False
|
||||||
|
self._negotiate_unix_fd = negotiate_unix_fd
|
||||||
|
|
||||||
# True if the user disconnected himself, so don't throw errors out of
|
# True if the user disconnected himself, so don't throw errors out of
|
||||||
# the main loop.
|
# the main loop.
|
||||||
@@ -870,14 +877,16 @@ class BaseMessageBus:
|
|||||||
args = ServiceInterface._msg_body_to_args(msg)
|
args = ServiceInterface._msg_body_to_args(msg)
|
||||||
result = method.fn(interface, *args)
|
result = method.fn(interface, *args)
|
||||||
body, fds = ServiceInterface._fn_result_to_body(
|
body, fds = ServiceInterface._fn_result_to_body(
|
||||||
result, signature_tree=method.out_signature_tree
|
result,
|
||||||
|
signature_tree=method.out_signature_tree,
|
||||||
|
replace_fds=self._negotiate_unix_fd,
|
||||||
)
|
)
|
||||||
send_reply(Message.new_method_return(msg, method.out_signature, body, fds))
|
send_reply(Message.new_method_return(msg, method.out_signature, body, fds))
|
||||||
|
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
def _find_message_handler(
|
def _find_message_handler(
|
||||||
self, msg: Message
|
self, msg
|
||||||
) -> Optional[Callable[[Message, Callable], None]]:
|
) -> Optional[Callable[[Message, Callable], None]]:
|
||||||
handler: Optional[Callable[[Message, Callable], None]] = None
|
handler: Optional[Callable[[Message, Callable], None]] = None
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from typing import (
|
|||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
Tuple,
|
||||||
no_type_check_decorator,
|
no_type_check_decorator,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +24,12 @@ from ._private.util import (
|
|||||||
from .constants import PropertyAccess
|
from .constants import PropertyAccess
|
||||||
from .errors import SignalDisabledError
|
from .errors import SignalDisabledError
|
||||||
from .message import Message
|
from .message import Message
|
||||||
from .signature import SignatureBodyMismatchError, Variant, get_signature_tree
|
from .signature import (
|
||||||
|
SignatureBodyMismatchError,
|
||||||
|
SignatureTree,
|
||||||
|
Variant,
|
||||||
|
get_signature_tree,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .message_bus import BaseMessageBus
|
from .message_bus import BaseMessageBus
|
||||||
@@ -482,19 +489,23 @@ class ServiceInterface:
|
|||||||
del interface.__handlers[bus]
|
del interface.__handlers[bus]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _msg_body_to_args(msg):
|
def _msg_body_to_args(msg: Message) -> List[Any]:
|
||||||
if signature_contains_type(msg.signature_tree, msg.body, "h"):
|
if not msg.unix_fds or not signature_contains_type(
|
||||||
# XXX: This deep copy could be expensive if messages are very
|
msg.signature_tree, msg.body, "h"
|
||||||
# large. We could optimize this by only copying what we change
|
):
|
||||||
# here.
|
|
||||||
return replace_idx_with_fds(
|
|
||||||
msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return msg.body
|
return msg.body
|
||||||
|
|
||||||
|
# XXX: This deep copy could be expensive if messages are very
|
||||||
|
# large. We could optimize this by only copying what we change
|
||||||
|
# here.
|
||||||
|
return replace_idx_with_fds(
|
||||||
|
msg.signature_tree, copy.deepcopy(msg.body), msg.unix_fds
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fn_result_to_body(result, signature_tree):
|
def _fn_result_to_body(
|
||||||
|
result: List[Any], signature_tree: SignatureTree, replace_fds: bool = True
|
||||||
|
) -> Tuple[List[Any], List[int]]:
|
||||||
"""The high level interfaces may return single values which may be
|
"""The high level interfaces may return single values which may be
|
||||||
wrapped in a list to be a message body. Also they may return fds
|
wrapped in a list to be a message body. Also they may return fds
|
||||||
directly for type 'h' which need to be put into an external list."""
|
directly for type 'h' which need to be put into an external list."""
|
||||||
@@ -515,10 +526,14 @@ class ServiceInterface:
|
|||||||
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}"
|
f"Signature and function return mismatch, expected {len(signature_tree.types)} arguments but got {len(result)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not replace_fds:
|
||||||
|
return result, []
|
||||||
return replace_fds_with_idx(signature_tree, result)
|
return replace_fds_with_idx(signature_tree, result)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_signal(interface, signal, result):
|
def _handle_signal(
|
||||||
|
interface: "ServiceInterface", signal: _Signal, result: List[Any]
|
||||||
|
) -> None:
|
||||||
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
|
body, fds = ServiceInterface._fn_result_to_body(result, signal.signature_tree)
|
||||||
for bus in ServiceInterface._get_buses(interface):
|
for bus in ServiceInterface._get_buses(interface):
|
||||||
bus._interface_signal_notify(
|
bus._interface_signal_notify(
|
||||||
@@ -526,7 +541,7 @@ class ServiceInterface:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_property_value(interface, prop, callback):
|
def _get_property_value(interface: "ServiceInterface", prop: _Property, callback):
|
||||||
# XXX MUST CHECK TYPE RETURNED BY GETTER
|
# XXX MUST CHECK TYPE RETURNED BY GETTER
|
||||||
try:
|
try:
|
||||||
if asyncio.iscoroutinefunction(prop.prop_getter):
|
if asyncio.iscoroutinefunction(prop.prop_getter):
|
||||||
@@ -551,7 +566,7 @@ class ServiceInterface:
|
|||||||
callback(interface, prop, None, e)
|
callback(interface, prop, None, e)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_property_value(interface, prop, value, callback):
|
def _set_property_value(interface: "ServiceInterface", prop, value, callback):
|
||||||
# XXX MUST CHECK TYPE TO SET
|
# XXX MUST CHECK TYPE TO SET
|
||||||
try:
|
try:
|
||||||
if asyncio.iscoroutinefunction(prop.prop_setter):
|
if asyncio.iscoroutinefunction(prop.prop_setter):
|
||||||
@@ -575,7 +590,9 @@ class ServiceInterface:
|
|||||||
callback(interface, prop, e)
|
callback(interface, prop, e)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_all_property_values(interface, callback, user_data=None):
|
def _get_all_property_values(
|
||||||
|
interface: "ServiceInterface", callback, user_data=None
|
||||||
|
):
|
||||||
result = {}
|
result = {}
|
||||||
result_error = None
|
result_error = None
|
||||||
|
|
||||||
@@ -588,7 +605,12 @@ class ServiceInterface:
|
|||||||
callback(interface, result, user_data, None)
|
callback(interface, result, user_data, None)
|
||||||
return
|
return
|
||||||
|
|
||||||
def get_property_callback(interface, prop, value, e):
|
def get_property_callback(
|
||||||
|
interface: "ServiceInterface",
|
||||||
|
prop: _Property,
|
||||||
|
value: Any,
|
||||||
|
e: Optional[Exception],
|
||||||
|
) -> None:
|
||||||
nonlocal result_error
|
nonlocal result_error
|
||||||
if e is not None:
|
if e is not None:
|
||||||
result_error = e
|
result_error = e
|
||||||
|
|||||||
Reference in New Issue
Block a user