feat: add unpack variants option (#20)
This commit is contained in:
parent
1209048551
commit
cfad28bd2b
@ -9,6 +9,7 @@ from ..message import Message, MessageFlag
|
||||
from ..message_bus import BaseMessageBus
|
||||
from ..proxy_object import BaseProxyInterface, BaseProxyObject
|
||||
from ..signature import Variant
|
||||
from ..signature import unpack_variants as unpack
|
||||
|
||||
|
||||
class ProxyInterface(BaseProxyInterface):
|
||||
@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface):
|
||||
"""
|
||||
|
||||
def _add_method(self, intr_method):
|
||||
async def method_fn(*args, flags=MessageFlag.NONE):
|
||||
async def method_fn(
|
||||
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
|
||||
):
|
||||
input_body, unix_fds = replace_fds_with_idx(
|
||||
intr_method.in_signature, list(args)
|
||||
)
|
||||
@ -103,16 +106,24 @@ class ProxyInterface(BaseProxyInterface):
|
||||
|
||||
if not out_len:
|
||||
return None
|
||||
elif out_len == 1:
|
||||
|
||||
if unpack_variants:
|
||||
body = unpack(body)
|
||||
|
||||
if out_len == 1:
|
||||
return body[0]
|
||||
else:
|
||||
return body
|
||||
return body
|
||||
|
||||
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
|
||||
setattr(self, method_name, method_fn)
|
||||
|
||||
def _add_property(self, intr_property):
|
||||
async def property_getter():
|
||||
def _add_property(
|
||||
self,
|
||||
intr_property,
|
||||
):
|
||||
async def property_getter(
|
||||
*, flags=MessageFlag.NONE, unpack_variants: bool = False
|
||||
):
|
||||
msg = await self.bus.call(
|
||||
Message(
|
||||
destination=self.bus_name,
|
||||
@ -133,7 +144,11 @@ class ProxyInterface(BaseProxyInterface):
|
||||
msg,
|
||||
)
|
||||
|
||||
return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
|
||||
body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
|
||||
|
||||
if unpack_variants:
|
||||
return unpack(body)
|
||||
return body
|
||||
|
||||
async def property_setter(val):
|
||||
variant = Variant(intr_property.signature, val)
|
||||
|
||||
@ -8,6 +8,7 @@ from ..message import Message
|
||||
from ..message_bus import BaseMessageBus
|
||||
from ..proxy_object import BaseProxyInterface, BaseProxyObject
|
||||
from ..signature import Variant
|
||||
from ..signature import unpack_variants as unpack
|
||||
|
||||
# glib is optional
|
||||
try:
|
||||
@ -113,7 +114,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
in_len = len(intr_method.in_args)
|
||||
out_len = len(intr_method.out_args)
|
||||
|
||||
def method_fn(*args):
|
||||
def method_fn(*args, unpack_variants: bool = False):
|
||||
if len(args) != in_len + 1:
|
||||
raise TypeError(
|
||||
f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)"
|
||||
@ -136,7 +137,10 @@ class ProxyInterface(BaseProxyInterface):
|
||||
except DBusError as e:
|
||||
err = e
|
||||
|
||||
callback(msg.body, err)
|
||||
if unpack_variants:
|
||||
callback(unpack(msg.body), err)
|
||||
else:
|
||||
callback(msg.body, err)
|
||||
|
||||
self.bus.call(
|
||||
Message(
|
||||
@ -150,7 +154,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
call_notify,
|
||||
)
|
||||
|
||||
def method_fn_sync(*args):
|
||||
def method_fn_sync(*args, unpack_variants: bool = False):
|
||||
main = GLib.MainLoop()
|
||||
call_error = None
|
||||
call_body = None
|
||||
@ -171,10 +175,13 @@ class ProxyInterface(BaseProxyInterface):
|
||||
|
||||
if not out_len:
|
||||
return None
|
||||
elif out_len == 1:
|
||||
|
||||
if unpack_variants:
|
||||
call_body = unpack(call_body)
|
||||
|
||||
if out_len == 1:
|
||||
return call_body[0]
|
||||
else:
|
||||
return call_body
|
||||
return call_body
|
||||
|
||||
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
|
||||
method_name_sync = f"{method_name}_sync"
|
||||
@ -183,7 +190,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
setattr(self, method_name_sync, method_fn_sync)
|
||||
|
||||
def _add_property(self, intr_property):
|
||||
def property_getter(callback):
|
||||
def property_getter(callback, *, unpack_variants: bool = False):
|
||||
def call_notify(msg, err):
|
||||
if err:
|
||||
callback(None, err)
|
||||
@ -204,8 +211,10 @@ class ProxyInterface(BaseProxyInterface):
|
||||
)
|
||||
callback(None, err)
|
||||
return
|
||||
|
||||
callback(variant.value, None)
|
||||
if unpack_variants:
|
||||
callback(unpack(variant.value), None)
|
||||
else:
|
||||
callback(variant.value, None)
|
||||
|
||||
self.bus.call(
|
||||
Message(
|
||||
@ -219,7 +228,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
call_notify,
|
||||
)
|
||||
|
||||
def property_getter_sync():
|
||||
def property_getter_sync(*, unpack_variants: bool = False):
|
||||
property_value = None
|
||||
reply_error = None
|
||||
|
||||
@ -236,6 +245,8 @@ class ProxyInterface(BaseProxyInterface):
|
||||
main.run()
|
||||
if reply_error:
|
||||
raise reply_error
|
||||
if unpack_variants:
|
||||
return unpack(property_value)
|
||||
return property_value
|
||||
|
||||
def property_setter(value, callback):
|
||||
|
||||
@ -3,7 +3,8 @@ import inspect
|
||||
import logging
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Coroutine, List, Type, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Coroutine, Dict, List, Type, Union
|
||||
|
||||
from . import introspection as intr
|
||||
from . import message_bus
|
||||
@ -11,9 +12,18 @@ from ._private.util import replace_idx_with_fds
|
||||
from .constants import ErrorType, MessageType
|
||||
from .errors import DBusError, InterfaceNotFoundError
|
||||
from .message import Message
|
||||
from .signature import unpack_variants as unpack
|
||||
from .validators import assert_bus_name_valid, assert_object_path_valid
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalHandler:
|
||||
"""Signal handler."""
|
||||
|
||||
fn: Callable
|
||||
unpack_variants: bool
|
||||
|
||||
|
||||
class BaseProxyInterface:
|
||||
"""An abstract class representing a proxy to an interface exported on the bus by another client.
|
||||
|
||||
@ -46,7 +56,7 @@ class BaseProxyInterface:
|
||||
self.path = path
|
||||
self.introspection = introspection
|
||||
self.bus = bus
|
||||
self._signal_handlers = {}
|
||||
self._signal_handlers: Dict[str, List[SignalHandler]] = {}
|
||||
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"
|
||||
|
||||
_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
|
||||
@ -110,13 +120,21 @@ class BaseProxyInterface:
|
||||
return
|
||||
|
||||
body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds)
|
||||
no_sig = None
|
||||
for handler in self._signal_handlers[msg.member]:
|
||||
cb_result = handler(*body)
|
||||
if handler.unpack_variants:
|
||||
if not no_sig:
|
||||
no_sig = unpack(body)
|
||||
data = no_sig
|
||||
else:
|
||||
data = body
|
||||
|
||||
cb_result = handler.fn(*data)
|
||||
if isinstance(cb_result, Coroutine):
|
||||
asyncio.create_task(cb_result)
|
||||
|
||||
def _add_signal(self, intr_signal, interface):
|
||||
def on_signal_fn(fn):
|
||||
def on_signal_fn(fn, *, unpack_variants: bool = False):
|
||||
fn_signature = inspect.signature(fn)
|
||||
if len(fn_signature.parameters) != len(intr_signal.args) and (
|
||||
inspect.Parameter.VAR_POSITIONAL
|
||||
@ -134,11 +152,15 @@ class BaseProxyInterface:
|
||||
if intr_signal.name not in self._signal_handlers:
|
||||
self._signal_handlers[intr_signal.name] = []
|
||||
|
||||
self._signal_handlers[intr_signal.name].append(fn)
|
||||
self._signal_handlers[intr_signal.name].append(
|
||||
SignalHandler(fn, unpack_variants)
|
||||
)
|
||||
|
||||
def off_signal_fn(fn):
|
||||
def off_signal_fn(fn, *, unpack_variants: bool = False):
|
||||
try:
|
||||
i = self._signal_handlers[intr_signal.name].index(fn)
|
||||
i = self._signal_handlers[intr_signal.name].index(
|
||||
SignalHandler(fn, unpack_variants)
|
||||
)
|
||||
del self._signal_handlers[intr_signal.name][i]
|
||||
if not self._signal_handlers[intr_signal.name]:
|
||||
del self._signal_handlers[intr_signal.name]
|
||||
|
||||
@ -5,6 +5,17 @@ from .errors import InvalidSignatureError, SignatureBodyMismatchError
|
||||
from .validators import is_object_path_valid
|
||||
|
||||
|
||||
def unpack_variants(data: Any):
|
||||
"""Unpack variants and remove signature info."""
|
||||
if isinstance(data, Variant):
|
||||
return unpack_variants(data.value)
|
||||
if isinstance(data, dict):
|
||||
return {k: unpack_variants(v) for k, v in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [unpack_variants(item) for item in data]
|
||||
return data
|
||||
|
||||
|
||||
class SignatureType:
|
||||
"""A class that represents a single complete type within a signature.
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import dbus_fast.introspection as intr
|
||||
from dbus_fast import DBusError, aio, glib
|
||||
from dbus_fast.message import MessageFlag
|
||||
from dbus_fast.service import ServiceInterface, method
|
||||
from dbus_fast.signature import Variant
|
||||
from tests.util import check_gi_repository, skip_reason_no_gi
|
||||
|
||||
has_gi = check_gi_repository()
|
||||
@ -33,6 +34,11 @@ class ExampleInterface(ServiceInterface):
|
||||
def EchoThree(self, what1: "s", what2: "s", what3: "s") -> "sss":
|
||||
return [what1, what2, what3]
|
||||
|
||||
@method()
|
||||
def GetComplex(self) -> "a{sv}":
|
||||
"""Return complex output."""
|
||||
return {"hello": Variant("s", "world")}
|
||||
|
||||
@method()
|
||||
def ThrowsError(self):
|
||||
raise DBusError("test.error", "something went wrong")
|
||||
@ -81,6 +87,12 @@ async def test_aio_proxy_object():
|
||||
)
|
||||
assert result is None
|
||||
|
||||
result = await interface.call_get_complex()
|
||||
assert result == {"hello": Variant("s", "world")}
|
||||
|
||||
result = await interface.call_get_complex(unpack_variants=True)
|
||||
assert result == {"hello": "world"}
|
||||
|
||||
with pytest.raises(DBusError):
|
||||
try:
|
||||
await interface.call_throws_error()
|
||||
@ -120,6 +132,12 @@ def test_glib_proxy_object():
|
||||
result = interface.call_echo_three_sync("hello", "there", "world")
|
||||
assert result == ["hello", "there", "world"]
|
||||
|
||||
result = interface.call_get_complex_sync()
|
||||
assert result == {"hello": Variant("s", "world")}
|
||||
|
||||
result = interface.call_get_complex_sync(unpack_variants=True)
|
||||
assert result == {"hello": "world"}
|
||||
|
||||
with pytest.raises(DBusError):
|
||||
try:
|
||||
result = interface.call_throws_error_sync()
|
||||
|
||||
@ -2,6 +2,7 @@ import pytest
|
||||
|
||||
from dbus_fast import DBusError, Message, aio, glib
|
||||
from dbus_fast.service import PropertyAccess, ServiceInterface, dbus_property
|
||||
from dbus_fast.signature import Variant
|
||||
from tests.util import check_gi_repository, skip_reason_no_gi
|
||||
|
||||
has_gi = check_gi_repository()
|
||||
@ -27,6 +28,11 @@ class ExampleInterface(ServiceInterface):
|
||||
def Int64Property(self) -> "x":
|
||||
return self._int64_property
|
||||
|
||||
@dbus_property(access=PropertyAccess.READ)
|
||||
def ComplexProperty(self) -> "a{sv}":
|
||||
"""Return complex output."""
|
||||
return {"hello": Variant("s", "world")}
|
||||
|
||||
@dbus_property()
|
||||
def ErrorThrowingProperty(self) -> "s":
|
||||
raise DBusError(self.error_name, self.error_text)
|
||||
@ -59,6 +65,12 @@ async def test_aio_properties():
|
||||
await interface.set_some_property("different")
|
||||
assert service_interface._some_property == "different"
|
||||
|
||||
prop = await interface.get_complex_property()
|
||||
assert prop == {"hello": Variant("s", "world")}
|
||||
|
||||
prop = await interface.get_complex_property(unpack_variants=True)
|
||||
assert prop == {"hello": "world"}
|
||||
|
||||
with pytest.raises(DBusError):
|
||||
try:
|
||||
prop = await interface.get_error_throwing_property()
|
||||
@ -102,6 +114,12 @@ def test_glib_properties():
|
||||
interface.set_some_property_sync("different")
|
||||
assert service_interface._some_property == "different"
|
||||
|
||||
prop = interface.get_complex_property_sync()
|
||||
assert prop == {"hello": Variant("s", "world")}
|
||||
|
||||
prop = interface.get_complex_property_sync(unpack_variants=True)
|
||||
assert prop == {"hello": "world"}
|
||||
|
||||
with pytest.raises(DBusError):
|
||||
try:
|
||||
prop = interface.get_error_throwing_property_sync()
|
||||
|
||||
@ -2,10 +2,10 @@ import pytest
|
||||
|
||||
from dbus_fast import Message
|
||||
from dbus_fast.aio import MessageBus
|
||||
from dbus_fast.aio.proxy_object import ProxyInterface
|
||||
from dbus_fast.constants import RequestNameReply
|
||||
from dbus_fast.introspection import Node
|
||||
from dbus_fast.service import ServiceInterface, signal
|
||||
from dbus_fast.signature import Variant
|
||||
|
||||
|
||||
class ExampleInterface(ServiceInterface):
|
||||
@ -20,6 +20,11 @@ class ExampleInterface(ServiceInterface):
|
||||
def SignalMultiple(self) -> "ss":
|
||||
return ["hello", "world"]
|
||||
|
||||
@signal()
|
||||
def SignalComplex(self) -> "a{sv}":
|
||||
"""Broadcast a complex signal."""
|
||||
return {"hello": Variant("s", "world")}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_signals():
|
||||
@ -159,6 +164,69 @@ async def test_signals():
|
||||
bus3.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_signals():
|
||||
"""Test complex signals with and without signature removal."""
|
||||
bus1 = await MessageBus().connect()
|
||||
bus2 = await MessageBus().connect()
|
||||
|
||||
await bus1.request_name("test.signals.name")
|
||||
service_interface = ExampleInterface()
|
||||
bus1.export("/test/path", service_interface)
|
||||
|
||||
obj = bus2.get_proxy_object(
|
||||
"test.signals.name", "/test/path", bus1._introspect_export_path("/test/path")
|
||||
)
|
||||
interface = obj.get_interface(service_interface.name)
|
||||
|
||||
async def ping():
|
||||
await bus2.call(
|
||||
Message(
|
||||
destination=bus1.unique_name,
|
||||
interface="org.freedesktop.DBus.Peer",
|
||||
path="/test/path",
|
||||
member="Ping",
|
||||
)
|
||||
)
|
||||
|
||||
sig_handler_counter = 0
|
||||
sig_handler_err = None
|
||||
no_sig_handler_counter = 0
|
||||
no_sig_handler_err = None
|
||||
|
||||
def complex_handler_with_sig(value):
|
||||
nonlocal sig_handler_counter
|
||||
nonlocal sig_handler_err
|
||||
try:
|
||||
assert value == {"hello": Variant("s", "world")}
|
||||
sig_handler_counter += 1
|
||||
except AssertionError as ex:
|
||||
sig_handler_err = ex
|
||||
|
||||
def complex_handler_no_sig(value):
|
||||
nonlocal no_sig_handler_counter
|
||||
nonlocal no_sig_handler_err
|
||||
try:
|
||||
assert value == {"hello": "world"}
|
||||
no_sig_handler_counter += 1
|
||||
except AssertionError as ex:
|
||||
no_sig_handler_err = ex
|
||||
|
||||
interface.on_signal_complex(complex_handler_with_sig)
|
||||
interface.on_signal_complex(complex_handler_no_sig, unpack_variants=True)
|
||||
await ping()
|
||||
|
||||
service_interface.SignalComplex()
|
||||
await ping()
|
||||
assert sig_handler_err is None
|
||||
assert sig_handler_counter == 1
|
||||
assert no_sig_handler_err is None
|
||||
assert no_sig_handler_counter == 1
|
||||
|
||||
bus1.disconnect()
|
||||
bus2.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_varargs_callback():
|
||||
"""Test varargs callback for signal."""
|
||||
|
||||
56
tests/test_unpack_variants.py
Normal file
56
tests/test_unpack_variants.py
Normal file
@ -0,0 +1,56 @@
|
||||
"""Test unpack variants."""
|
||||
import pytest
|
||||
|
||||
from dbus_fast.signature import Variant, unpack_variants
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dictionary():
|
||||
"""Test variants unpacked from dictionary."""
|
||||
assert unpack_variants(
|
||||
{
|
||||
"string": Variant("s", "test"),
|
||||
"boolean": Variant("b", True),
|
||||
"int": Variant("u", 1),
|
||||
"object": Variant("o", "/test/path"),
|
||||
"array": Variant("as", ["test", "value"]),
|
||||
"tuple": Variant("(su)", ["test", 1]),
|
||||
"bytes": Variant("ay", b"\0x62\0x75\0x66"),
|
||||
}
|
||||
) == {
|
||||
"string": "test",
|
||||
"boolean": True,
|
||||
"int": 1,
|
||||
"object": "/test/path",
|
||||
"array": ["test", "value"],
|
||||
"tuple": ["test", 1],
|
||||
"bytes": b"\0x62\0x75\0x66",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_list():
|
||||
"""Test variants unpacked from multiple outputs."""
|
||||
assert unpack_variants(
|
||||
[{"hello": Variant("s", "world")}, {"boolean": Variant("b", True)}, 1]
|
||||
) == [{"hello": "world"}, {"boolean": True}, 1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_variants():
|
||||
"""Test unpack variants handles nesting."""
|
||||
assert unpack_variants(
|
||||
{
|
||||
"dict": Variant("a{sv}", {"hello": Variant("s", "world")}),
|
||||
"array": Variant(
|
||||
"aa{sv}",
|
||||
[
|
||||
{"hello": Variant("s", "world")},
|
||||
{"bytes": Variant("ay", b"\0x62\0x75\0x66")},
|
||||
],
|
||||
),
|
||||
}
|
||||
) == {
|
||||
"dict": {"hello": "world"},
|
||||
"array": [{"hello": "world"}, {"bytes": b"\0x62\0x75\0x66"}],
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user