feat: add unpack variants option (#20)

This commit is contained in:
Mike Degatano 2022-09-19 20:43:58 -04:00 committed by GitHub
parent 1209048551
commit cfad28bd2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 244 additions and 25 deletions

View File

@ -9,6 +9,7 @@ from ..message import Message, MessageFlag
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack
class ProxyInterface(BaseProxyInterface):
@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface):
"""
def _add_method(self, intr_method):
async def method_fn(*args, flags=MessageFlag.NONE):
async def method_fn(
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
):
input_body, unix_fds = replace_fds_with_idx(
intr_method.in_signature, list(args)
)
@ -103,16 +106,24 @@ class ProxyInterface(BaseProxyInterface):
if not out_len:
return None
elif out_len == 1:
if unpack_variants:
body = unpack(body)
if out_len == 1:
return body[0]
else:
return body
return body
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
setattr(self, method_name, method_fn)
def _add_property(self, intr_property):
async def property_getter():
def _add_property(
self,
intr_property,
):
async def property_getter(
*, flags=MessageFlag.NONE, unpack_variants: bool = False
):
msg = await self.bus.call(
Message(
destination=self.bus_name,
@ -133,7 +144,11 @@ class ProxyInterface(BaseProxyInterface):
msg,
)
return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
if unpack_variants:
return unpack(body)
return body
async def property_setter(val):
variant = Variant(intr_property.signature, val)

View File

@ -8,6 +8,7 @@ from ..message import Message
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack
# glib is optional
try:
@ -113,7 +114,7 @@ class ProxyInterface(BaseProxyInterface):
in_len = len(intr_method.in_args)
out_len = len(intr_method.out_args)
def method_fn(*args):
def method_fn(*args, unpack_variants: bool = False):
if len(args) != in_len + 1:
raise TypeError(
f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)"
@ -136,7 +137,10 @@ class ProxyInterface(BaseProxyInterface):
except DBusError as e:
err = e
callback(msg.body, err)
if unpack_variants:
callback(unpack(msg.body), err)
else:
callback(msg.body, err)
self.bus.call(
Message(
@ -150,7 +154,7 @@ class ProxyInterface(BaseProxyInterface):
call_notify,
)
def method_fn_sync(*args):
def method_fn_sync(*args, unpack_variants: bool = False):
main = GLib.MainLoop()
call_error = None
call_body = None
@ -171,10 +175,13 @@ class ProxyInterface(BaseProxyInterface):
if not out_len:
return None
elif out_len == 1:
if unpack_variants:
call_body = unpack(call_body)
if out_len == 1:
return call_body[0]
else:
return call_body
return call_body
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
method_name_sync = f"{method_name}_sync"
@ -183,7 +190,7 @@ class ProxyInterface(BaseProxyInterface):
setattr(self, method_name_sync, method_fn_sync)
def _add_property(self, intr_property):
def property_getter(callback):
def property_getter(callback, *, unpack_variants: bool = False):
def call_notify(msg, err):
if err:
callback(None, err)
@ -204,8 +211,10 @@ class ProxyInterface(BaseProxyInterface):
)
callback(None, err)
return
callback(variant.value, None)
if unpack_variants:
callback(unpack(variant.value), None)
else:
callback(variant.value, None)
self.bus.call(
Message(
@ -219,7 +228,7 @@ class ProxyInterface(BaseProxyInterface):
call_notify,
)
def property_getter_sync():
def property_getter_sync(*, unpack_variants: bool = False):
property_value = None
reply_error = None
@ -236,6 +245,8 @@ class ProxyInterface(BaseProxyInterface):
main.run()
if reply_error:
raise reply_error
if unpack_variants:
return unpack(property_value)
return property_value
def property_setter(value, callback):

View File

@ -3,7 +3,8 @@ import inspect
import logging
import re
import xml.etree.ElementTree as ET
from typing import Coroutine, List, Type, Union
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Type, Union
from . import introspection as intr
from . import message_bus
@ -11,9 +12,18 @@ from ._private.util import replace_idx_with_fds
from .constants import ErrorType, MessageType
from .errors import DBusError, InterfaceNotFoundError
from .message import Message
from .signature import unpack_variants as unpack
from .validators import assert_bus_name_valid, assert_object_path_valid
@dataclass
class SignalHandler:
"""Signal handler."""
fn: Callable
unpack_variants: bool
class BaseProxyInterface:
"""An abstract class representing a proxy to an interface exported on the bus by another client.
@ -46,7 +56,7 @@ class BaseProxyInterface:
self.path = path
self.introspection = introspection
self.bus = bus
self._signal_handlers = {}
self._signal_handlers: Dict[str, List[SignalHandler]] = {}
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"
_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
@ -110,13 +120,21 @@ class BaseProxyInterface:
return
body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds)
no_sig = None
for handler in self._signal_handlers[msg.member]:
cb_result = handler(*body)
if handler.unpack_variants:
if not no_sig:
no_sig = unpack(body)
data = no_sig
else:
data = body
cb_result = handler.fn(*data)
if isinstance(cb_result, Coroutine):
asyncio.create_task(cb_result)
def _add_signal(self, intr_signal, interface):
def on_signal_fn(fn):
def on_signal_fn(fn, *, unpack_variants: bool = False):
fn_signature = inspect.signature(fn)
if len(fn_signature.parameters) != len(intr_signal.args) and (
inspect.Parameter.VAR_POSITIONAL
@ -134,11 +152,15 @@ class BaseProxyInterface:
if intr_signal.name not in self._signal_handlers:
self._signal_handlers[intr_signal.name] = []
self._signal_handlers[intr_signal.name].append(fn)
self._signal_handlers[intr_signal.name].append(
SignalHandler(fn, unpack_variants)
)
def off_signal_fn(fn):
def off_signal_fn(fn, *, unpack_variants: bool = False):
try:
i = self._signal_handlers[intr_signal.name].index(fn)
i = self._signal_handlers[intr_signal.name].index(
SignalHandler(fn, unpack_variants)
)
del self._signal_handlers[intr_signal.name][i]
if not self._signal_handlers[intr_signal.name]:
del self._signal_handlers[intr_signal.name]

View File

@ -5,6 +5,17 @@ from .errors import InvalidSignatureError, SignatureBodyMismatchError
from .validators import is_object_path_valid
def unpack_variants(data: Any):
"""Unpack variants and remove signature info."""
if isinstance(data, Variant):
return unpack_variants(data.value)
if isinstance(data, dict):
return {k: unpack_variants(v) for k, v in data.items()}
if isinstance(data, list):
return [unpack_variants(item) for item in data]
return data
class SignatureType:
"""A class that represents a single complete type within a signature.

View File

@ -4,6 +4,7 @@ import dbus_fast.introspection as intr
from dbus_fast import DBusError, aio, glib
from dbus_fast.message import MessageFlag
from dbus_fast.service import ServiceInterface, method
from dbus_fast.signature import Variant
from tests.util import check_gi_repository, skip_reason_no_gi
has_gi = check_gi_repository()
@ -33,6 +34,11 @@ class ExampleInterface(ServiceInterface):
def EchoThree(self, what1: "s", what2: "s", what3: "s") -> "sss":
return [what1, what2, what3]
@method()
def GetComplex(self) -> "a{sv}":
"""Return complex output."""
return {"hello": Variant("s", "world")}
@method()
def ThrowsError(self):
raise DBusError("test.error", "something went wrong")
@ -81,6 +87,12 @@ async def test_aio_proxy_object():
)
assert result is None
result = await interface.call_get_complex()
assert result == {"hello": Variant("s", "world")}
result = await interface.call_get_complex(unpack_variants=True)
assert result == {"hello": "world"}
with pytest.raises(DBusError):
try:
await interface.call_throws_error()
@ -120,6 +132,12 @@ def test_glib_proxy_object():
result = interface.call_echo_three_sync("hello", "there", "world")
assert result == ["hello", "there", "world"]
result = interface.call_get_complex_sync()
assert result == {"hello": Variant("s", "world")}
result = interface.call_get_complex_sync(unpack_variants=True)
assert result == {"hello": "world"}
with pytest.raises(DBusError):
try:
result = interface.call_throws_error_sync()

View File

@ -2,6 +2,7 @@ import pytest
from dbus_fast import DBusError, Message, aio, glib
from dbus_fast.service import PropertyAccess, ServiceInterface, dbus_property
from dbus_fast.signature import Variant
from tests.util import check_gi_repository, skip_reason_no_gi
has_gi = check_gi_repository()
@ -27,6 +28,11 @@ class ExampleInterface(ServiceInterface):
def Int64Property(self) -> "x":
return self._int64_property
@dbus_property(access=PropertyAccess.READ)
def ComplexProperty(self) -> "a{sv}":
"""Return complex output."""
return {"hello": Variant("s", "world")}
@dbus_property()
def ErrorThrowingProperty(self) -> "s":
raise DBusError(self.error_name, self.error_text)
@ -59,6 +65,12 @@ async def test_aio_properties():
await interface.set_some_property("different")
assert service_interface._some_property == "different"
prop = await interface.get_complex_property()
assert prop == {"hello": Variant("s", "world")}
prop = await interface.get_complex_property(unpack_variants=True)
assert prop == {"hello": "world"}
with pytest.raises(DBusError):
try:
prop = await interface.get_error_throwing_property()
@ -102,6 +114,12 @@ def test_glib_properties():
interface.set_some_property_sync("different")
assert service_interface._some_property == "different"
prop = interface.get_complex_property_sync()
assert prop == {"hello": Variant("s", "world")}
prop = interface.get_complex_property_sync(unpack_variants=True)
assert prop == {"hello": "world"}
with pytest.raises(DBusError):
try:
prop = interface.get_error_throwing_property_sync()

View File

@ -2,10 +2,10 @@ import pytest
from dbus_fast import Message
from dbus_fast.aio import MessageBus
from dbus_fast.aio.proxy_object import ProxyInterface
from dbus_fast.constants import RequestNameReply
from dbus_fast.introspection import Node
from dbus_fast.service import ServiceInterface, signal
from dbus_fast.signature import Variant
class ExampleInterface(ServiceInterface):
@ -20,6 +20,11 @@ class ExampleInterface(ServiceInterface):
def SignalMultiple(self) -> "ss":
return ["hello", "world"]
@signal()
def SignalComplex(self) -> "a{sv}":
"""Broadcast a complex signal."""
return {"hello": Variant("s", "world")}
@pytest.mark.asyncio
async def test_signals():
@ -159,6 +164,69 @@ async def test_signals():
bus3.disconnect()
@pytest.mark.asyncio
async def test_complex_signals():
"""Test complex signals with and without signature removal."""
bus1 = await MessageBus().connect()
bus2 = await MessageBus().connect()
await bus1.request_name("test.signals.name")
service_interface = ExampleInterface()
bus1.export("/test/path", service_interface)
obj = bus2.get_proxy_object(
"test.signals.name", "/test/path", bus1._introspect_export_path("/test/path")
)
interface = obj.get_interface(service_interface.name)
async def ping():
await bus2.call(
Message(
destination=bus1.unique_name,
interface="org.freedesktop.DBus.Peer",
path="/test/path",
member="Ping",
)
)
sig_handler_counter = 0
sig_handler_err = None
no_sig_handler_counter = 0
no_sig_handler_err = None
def complex_handler_with_sig(value):
nonlocal sig_handler_counter
nonlocal sig_handler_err
try:
assert value == {"hello": Variant("s", "world")}
sig_handler_counter += 1
except AssertionError as ex:
sig_handler_err = ex
def complex_handler_no_sig(value):
nonlocal no_sig_handler_counter
nonlocal no_sig_handler_err
try:
assert value == {"hello": "world"}
no_sig_handler_counter += 1
except AssertionError as ex:
no_sig_handler_err = ex
interface.on_signal_complex(complex_handler_with_sig)
interface.on_signal_complex(complex_handler_no_sig, unpack_variants=True)
await ping()
service_interface.SignalComplex()
await ping()
assert sig_handler_err is None
assert sig_handler_counter == 1
assert no_sig_handler_err is None
assert no_sig_handler_counter == 1
bus1.disconnect()
bus2.disconnect()
@pytest.mark.asyncio
async def test_varargs_callback():
"""Test varargs callback for signal."""

View File

@ -0,0 +1,56 @@
"""Test unpack variants."""
import pytest
from dbus_fast.signature import Variant, unpack_variants
@pytest.mark.asyncio
async def test_dictionary():
"""Test variants unpacked from dictionary."""
assert unpack_variants(
{
"string": Variant("s", "test"),
"boolean": Variant("b", True),
"int": Variant("u", 1),
"object": Variant("o", "/test/path"),
"array": Variant("as", ["test", "value"]),
"tuple": Variant("(su)", ["test", 1]),
"bytes": Variant("ay", b"\0x62\0x75\0x66"),
}
) == {
"string": "test",
"boolean": True,
"int": 1,
"object": "/test/path",
"array": ["test", "value"],
"tuple": ["test", 1],
"bytes": b"\0x62\0x75\0x66",
}
@pytest.mark.asyncio
async def test_output_list():
"""Test variants unpacked from multiple outputs."""
assert unpack_variants(
[{"hello": Variant("s", "world")}, {"boolean": Variant("b", True)}, 1]
) == [{"hello": "world"}, {"boolean": True}, 1]
@pytest.mark.asyncio
async def test_nested_variants():
"""Test unpack variants handles nesting."""
assert unpack_variants(
{
"dict": Variant("a{sv}", {"hello": Variant("s", "world")}),
"array": Variant(
"aa{sv}",
[
{"hello": Variant("s", "world")},
{"bytes": Variant("ay", b"\0x62\0x75\0x66")},
],
),
}
) == {
"dict": {"hello": "world"},
"array": [{"hello": "world"}, {"bytes": b"\0x62\0x75\0x66"}],
}