feat: add unpack variants option (#20)
This commit is contained in:
@@ -9,6 +9,7 @@ from ..message import Message, MessageFlag
|
||||
from ..message_bus import BaseMessageBus
|
||||
from ..proxy_object import BaseProxyInterface, BaseProxyObject
|
||||
from ..signature import Variant
|
||||
from ..signature import unpack_variants as unpack
|
||||
|
||||
|
||||
class ProxyInterface(BaseProxyInterface):
|
||||
@@ -74,7 +75,9 @@ class ProxyInterface(BaseProxyInterface):
|
||||
"""
|
||||
|
||||
def _add_method(self, intr_method):
|
||||
async def method_fn(*args, flags=MessageFlag.NONE):
|
||||
async def method_fn(
|
||||
*args, flags=MessageFlag.NONE, unpack_variants: bool = False
|
||||
):
|
||||
input_body, unix_fds = replace_fds_with_idx(
|
||||
intr_method.in_signature, list(args)
|
||||
)
|
||||
@@ -103,16 +106,24 @@ class ProxyInterface(BaseProxyInterface):
|
||||
|
||||
if not out_len:
|
||||
return None
|
||||
elif out_len == 1:
|
||||
|
||||
if unpack_variants:
|
||||
body = unpack(body)
|
||||
|
||||
if out_len == 1:
|
||||
return body[0]
|
||||
else:
|
||||
return body
|
||||
return body
|
||||
|
||||
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
|
||||
setattr(self, method_name, method_fn)
|
||||
|
||||
def _add_property(self, intr_property):
|
||||
async def property_getter():
|
||||
def _add_property(
|
||||
self,
|
||||
intr_property,
|
||||
):
|
||||
async def property_getter(
|
||||
*, flags=MessageFlag.NONE, unpack_variants: bool = False
|
||||
):
|
||||
msg = await self.bus.call(
|
||||
Message(
|
||||
destination=self.bus_name,
|
||||
@@ -133,7 +144,11 @@ class ProxyInterface(BaseProxyInterface):
|
||||
msg,
|
||||
)
|
||||
|
||||
return replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
|
||||
body = replace_idx_with_fds("v", msg.body, msg.unix_fds)[0].value
|
||||
|
||||
if unpack_variants:
|
||||
return unpack(body)
|
||||
return body
|
||||
|
||||
async def property_setter(val):
|
||||
variant = Variant(intr_property.signature, val)
|
||||
|
||||
@@ -8,6 +8,7 @@ from ..message import Message
|
||||
from ..message_bus import BaseMessageBus
|
||||
from ..proxy_object import BaseProxyInterface, BaseProxyObject
|
||||
from ..signature import Variant
|
||||
from ..signature import unpack_variants as unpack
|
||||
|
||||
# glib is optional
|
||||
try:
|
||||
@@ -113,7 +114,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
in_len = len(intr_method.in_args)
|
||||
out_len = len(intr_method.out_args)
|
||||
|
||||
def method_fn(*args):
|
||||
def method_fn(*args, unpack_variants: bool = False):
|
||||
if len(args) != in_len + 1:
|
||||
raise TypeError(
|
||||
f"method {intr_method.name} expects {in_len} arguments and a callback (got {len(args)} args)"
|
||||
@@ -136,7 +137,10 @@ class ProxyInterface(BaseProxyInterface):
|
||||
except DBusError as e:
|
||||
err = e
|
||||
|
||||
callback(msg.body, err)
|
||||
if unpack_variants:
|
||||
callback(unpack(msg.body), err)
|
||||
else:
|
||||
callback(msg.body, err)
|
||||
|
||||
self.bus.call(
|
||||
Message(
|
||||
@@ -150,7 +154,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
call_notify,
|
||||
)
|
||||
|
||||
def method_fn_sync(*args):
|
||||
def method_fn_sync(*args, unpack_variants: bool = False):
|
||||
main = GLib.MainLoop()
|
||||
call_error = None
|
||||
call_body = None
|
||||
@@ -171,10 +175,13 @@ class ProxyInterface(BaseProxyInterface):
|
||||
|
||||
if not out_len:
|
||||
return None
|
||||
elif out_len == 1:
|
||||
|
||||
if unpack_variants:
|
||||
call_body = unpack(call_body)
|
||||
|
||||
if out_len == 1:
|
||||
return call_body[0]
|
||||
else:
|
||||
return call_body
|
||||
return call_body
|
||||
|
||||
method_name = f"call_{BaseProxyInterface._to_snake_case(intr_method.name)}"
|
||||
method_name_sync = f"{method_name}_sync"
|
||||
@@ -183,7 +190,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
setattr(self, method_name_sync, method_fn_sync)
|
||||
|
||||
def _add_property(self, intr_property):
|
||||
def property_getter(callback):
|
||||
def property_getter(callback, *, unpack_variants: bool = False):
|
||||
def call_notify(msg, err):
|
||||
if err:
|
||||
callback(None, err)
|
||||
@@ -204,8 +211,10 @@ class ProxyInterface(BaseProxyInterface):
|
||||
)
|
||||
callback(None, err)
|
||||
return
|
||||
|
||||
callback(variant.value, None)
|
||||
if unpack_variants:
|
||||
callback(unpack(variant.value), None)
|
||||
else:
|
||||
callback(variant.value, None)
|
||||
|
||||
self.bus.call(
|
||||
Message(
|
||||
@@ -219,7 +228,7 @@ class ProxyInterface(BaseProxyInterface):
|
||||
call_notify,
|
||||
)
|
||||
|
||||
def property_getter_sync():
|
||||
def property_getter_sync(*, unpack_variants: bool = False):
|
||||
property_value = None
|
||||
reply_error = None
|
||||
|
||||
@@ -236,6 +245,8 @@ class ProxyInterface(BaseProxyInterface):
|
||||
main.run()
|
||||
if reply_error:
|
||||
raise reply_error
|
||||
if unpack_variants:
|
||||
return unpack(property_value)
|
||||
return property_value
|
||||
|
||||
def property_setter(value, callback):
|
||||
|
||||
@@ -3,7 +3,8 @@ import inspect
|
||||
import logging
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Coroutine, List, Type, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Coroutine, Dict, List, Type, Union
|
||||
|
||||
from . import introspection as intr
|
||||
from . import message_bus
|
||||
@@ -11,9 +12,18 @@ from ._private.util import replace_idx_with_fds
|
||||
from .constants import ErrorType, MessageType
|
||||
from .errors import DBusError, InterfaceNotFoundError
|
||||
from .message import Message
|
||||
from .signature import unpack_variants as unpack
|
||||
from .validators import assert_bus_name_valid, assert_object_path_valid
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalHandler:
|
||||
"""Signal handler."""
|
||||
|
||||
fn: Callable
|
||||
unpack_variants: bool
|
||||
|
||||
|
||||
class BaseProxyInterface:
|
||||
"""An abstract class representing a proxy to an interface exported on the bus by another client.
|
||||
|
||||
@@ -46,7 +56,7 @@ class BaseProxyInterface:
|
||||
self.path = path
|
||||
self.introspection = introspection
|
||||
self.bus = bus
|
||||
self._signal_handlers = {}
|
||||
self._signal_handlers: Dict[str, List[SignalHandler]] = {}
|
||||
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"
|
||||
|
||||
_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
|
||||
@@ -110,13 +120,21 @@ class BaseProxyInterface:
|
||||
return
|
||||
|
||||
body = replace_idx_with_fds(msg.signature, msg.body, msg.unix_fds)
|
||||
no_sig = None
|
||||
for handler in self._signal_handlers[msg.member]:
|
||||
cb_result = handler(*body)
|
||||
if handler.unpack_variants:
|
||||
if not no_sig:
|
||||
no_sig = unpack(body)
|
||||
data = no_sig
|
||||
else:
|
||||
data = body
|
||||
|
||||
cb_result = handler.fn(*data)
|
||||
if isinstance(cb_result, Coroutine):
|
||||
asyncio.create_task(cb_result)
|
||||
|
||||
def _add_signal(self, intr_signal, interface):
|
||||
def on_signal_fn(fn):
|
||||
def on_signal_fn(fn, *, unpack_variants: bool = False):
|
||||
fn_signature = inspect.signature(fn)
|
||||
if len(fn_signature.parameters) != len(intr_signal.args) and (
|
||||
inspect.Parameter.VAR_POSITIONAL
|
||||
@@ -134,11 +152,15 @@ class BaseProxyInterface:
|
||||
if intr_signal.name not in self._signal_handlers:
|
||||
self._signal_handlers[intr_signal.name] = []
|
||||
|
||||
self._signal_handlers[intr_signal.name].append(fn)
|
||||
self._signal_handlers[intr_signal.name].append(
|
||||
SignalHandler(fn, unpack_variants)
|
||||
)
|
||||
|
||||
def off_signal_fn(fn):
|
||||
def off_signal_fn(fn, *, unpack_variants: bool = False):
|
||||
try:
|
||||
i = self._signal_handlers[intr_signal.name].index(fn)
|
||||
i = self._signal_handlers[intr_signal.name].index(
|
||||
SignalHandler(fn, unpack_variants)
|
||||
)
|
||||
del self._signal_handlers[intr_signal.name][i]
|
||||
if not self._signal_handlers[intr_signal.name]:
|
||||
del self._signal_handlers[intr_signal.name]
|
||||
|
||||
@@ -5,6 +5,17 @@ from .errors import InvalidSignatureError, SignatureBodyMismatchError
|
||||
from .validators import is_object_path_valid
|
||||
|
||||
|
||||
def unpack_variants(data: Any):
|
||||
"""Unpack variants and remove signature info."""
|
||||
if isinstance(data, Variant):
|
||||
return unpack_variants(data.value)
|
||||
if isinstance(data, dict):
|
||||
return {k: unpack_variants(v) for k, v in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [unpack_variants(item) for item in data]
|
||||
return data
|
||||
|
||||
|
||||
class SignatureType:
|
||||
"""A class that represents a single complete type within a signature.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user