feat: add unpack variants option (#20)

This commit is contained in:
Mike Degatano
2022-09-19 20:43:58 -04:00
committed by GitHub
parent 1209048551
commit cfad28bd2b
8 changed files with 244 additions and 25 deletions

View File

@@ -9,6 +9,7 @@ from ..message import Message, MessageFlag
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack
class ProxyInterface(BaseProxyInterface):
@@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface):
"""
def _add_method(self, intr_method):
async def method_fn(*args, flags=MessageFlag.NONE):
async def method_fn(
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
):
input_body, unix_fds = replace_fds_with_idx(
intr_method.in_signature, list(args)
)
@@ -103,16 +106,24 @@ class ProxyInterface(BaseProxyInterface):
if not out_len:
return None
elif out_len == 1:
if unpack_variants:
body = unpack(body)
if out_len == 1:
return body[0]
else:
return body
return body
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
setattr(self, method_name, method_fn)
def _add_property(self, intr_property):
async def property_getter():
def _add_property(
self,
intr_property,
):
async def property_getter(
*, flags=MessageFlag.NONE, unpack_variants: bool = False
):
msg = await self.bus.call(
Message(
destination=self.bus_name,
@@ -133,7 +144,11 @@ class ProxyInterface(BaseProxyInterface):
msg,
)
return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
if unpack_variants:
return unpack(body)
return body
async def property_setter(val):
variant = Variant(intr_property.signature, val)

View File

@@ -8,6 +8,7 @@ from ..message import Message
from ..message_bus import BaseMessageBus
from ..proxy_object import BaseProxyInterface, BaseProxyObject
from ..signature import Variant
from ..signature import unpack_variants as unpack
# glib is optional
try:
@@ -113,7 +114,7 @@ class ProxyInterface(BaseProxyInterface):
in_len = len(intr_method.in_args)
out_len = len(intr_method.out_args)
def method_fn(*args):
def method_fn(*args, unpack_variants: bool = False):
if len(args) != in_len + 1:
raise TypeError(
f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)"
@@ -136,7 +137,10 @@ class ProxyInterface(BaseProxyInterface):
except DBusError as e:
err = e
callback(msg.body, err)
if unpack_variants:
callback(unpack(msg.body), err)
else:
callback(msg.body, err)
self.bus.call(
Message(
@@ -150,7 +154,7 @@ class ProxyInterface(BaseProxyInterface):
call_notify,
)
def method_fn_sync(*args):
def method_fn_sync(*args, unpack_variants: bool = False):
main = GLib.MainLoop()
call_error = None
call_body = None
@@ -171,10 +175,13 @@ class ProxyInterface(BaseProxyInterface):
if not out_len:
return None
elif out_len == 1:
if unpack_variants:
call_body = unpack(call_body)
if out_len == 1:
return call_body[0]
else:
return call_body
return call_body
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
method_name_sync = f"{method_name}_sync"
@@ -183,7 +190,7 @@ class ProxyInterface(BaseProxyInterface):
setattr(self, method_name_sync, method_fn_sync)
def _add_property(self, intr_property):
def property_getter(callback):
def property_getter(callback, *, unpack_variants: bool = False):
def call_notify(msg, err):
if err:
callback(None, err)
@@ -204,8 +211,10 @@ class ProxyInterface(BaseProxyInterface):
)
callback(None, err)
return
callback(variant.value, None)
if unpack_variants:
callback(unpack(variant.value), None)
else:
callback(variant.value, None)
self.bus.call(
Message(
@@ -219,7 +228,7 @@ class ProxyInterface(BaseProxyInterface):
call_notify,
)
def property_getter_sync():
def property_getter_sync(*, unpack_variants: bool = False):
property_value = None
reply_error = None
@@ -236,6 +245,8 @@ class ProxyInterface(BaseProxyInterface):
main.run()
if reply_error:
raise reply_error
if unpack_variants:
return unpack(property_value)
return property_value
def property_setter(value, callback):

View File

@@ -3,7 +3,8 @@ import inspect
import logging
import re
import xml.etree.ElementTree as ET
from typing import Coroutine, List, Type, Union
from dataclasses import dataclass
from typing import Callable, Coroutine, Dict, List, Type, Union
from . import introspection as intr
from . import message_bus
@@ -11,9 +12,18 @@ from ._private.util import replace_idx_with_fds
from .constants import ErrorType, MessageType
from .errors import DBusError, InterfaceNotFoundError
from .message import Message
from .signature import unpack_variants as unpack
from .validators import assert_bus_name_valid, assert_object_path_valid
@dataclass
class SignalHandler:
"""Signal handler."""
fn: Callable
unpack_variants: bool
class BaseProxyInterface:
"""An abstract class representing a proxy to an interface exported on the bus by another client.
@@ -46,7 +56,7 @@ class BaseProxyInterface:
self.path = path
self.introspection = introspection
self.bus = bus
self._signal_handlers = {}
self._signal_handlers: Dict[str, List[SignalHandler]] = {}
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"
_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
@@ -110,13 +120,21 @@ class BaseProxyInterface:
return
body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds)
no_sig = None
for handler in self._signal_handlers[msg.member]:
cb_result = handler(*body)
if handler.unpack_variants:
if not no_sig:
no_sig = unpack(body)
data = no_sig
else:
data = body
cb_result = handler.fn(*data)
if isinstance(cb_result, Coroutine):
asyncio.create_task(cb_result)
def _add_signal(self, intr_signal, interface):
def on_signal_fn(fn):
def on_signal_fn(fn, *, unpack_variants: bool = False):
fn_signature = inspect.signature(fn)
if len(fn_signature.parameters) != len(intr_signal.args) and (
inspect.Parameter.VAR_POSITIONAL
@@ -134,11 +152,15 @@ class BaseProxyInterface:
if intr_signal.name not in self._signal_handlers:
self._signal_handlers[intr_signal.name] = []
self._signal_handlers[intr_signal.name].append(fn)
self._signal_handlers[intr_signal.name].append(
SignalHandler(fn, unpack_variants)
)
def off_signal_fn(fn):
def off_signal_fn(fn, *, unpack_variants: bool = False):
try:
i = self._signal_handlers[intr_signal.name].index(fn)
i = self._signal_handlers[intr_signal.name].index(
SignalHandler(fn, unpack_variants)
)
del self._signal_handlers[intr_signal.name][i]
if not self._signal_handlers[intr_signal.name]:
del self._signal_handlers[intr_signal.name]

View File

@@ -5,6 +5,17 @@ from .errors import InvalidSignatureError, SignatureBodyMismatchError
from .validators import is_object_path_valid
def unpack_variants(data: Any):
"""Unpack variants and remove signature info."""
if isinstance(data, Variant):
return unpack_variants(data.value)
if isinstance(data, dict):
return {k: unpack_variants(v) for k, v in data.items()}
if isinstance(data, list):
return [unpack_variants(item) for item in data]
return data
class SignatureType:
"""A class that represents a single complete type within a signature.