feat: speed up unmarshaller (#1)

This commit is contained in:
J. Nick Koston 2022-09-09 09:58:12 -05:00 committed by GitHub
parent 2c9cdcc173
commit eca1d31781
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 421 additions and 317 deletions

22
bench/unmarshall.py Normal file
View File

@ -0,0 +1,22 @@
import io
import timeit
from dbus_fast._private.unmarshaller import Unmarshaller
bluez_rssi_message = (
"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465"
"765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b"
"746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e"
"67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000"
"110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001"
"6e00a7ff000000000000"
)
def unmarhsall_bluez_rssi_message():
Unmarshaller(io.BytesIO(bytes.fromhex(bluez_rssi_message))).unmarshall()
count = 1000000
time = timeit.Timer(unmarhsall_bluez_rssi_message).timeit(count)
print(f"Unmarshalling {count} bluetooth rssi messages took {time} seconds")

View File

@ -16,3 +16,6 @@ class HeaderField(Enum):
SENDER = 7
SIGNATURE = 8
UNIX_FDS = 9
HEADER_NAME_MAP = {field.value: field.name for field in HeaderField}

View File

@ -1,315 +1,348 @@
import array
import io
import socket
from codecs import decode
from struct import unpack_from
import sys
from struct import Struct
from typing import Any, Callable, Dict, List, Optional, Tuple
from ..constants import MessageFlag, MessageType
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP, MessageFlag, MessageType
from ..errors import InvalidMessageError
from ..message import Message
from ..signature import SignatureTree, Variant
from .constants import BIG_ENDIAN, LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField
from ..signature import SignatureTree, SignatureType, Variant
from .constants import (
BIG_ENDIAN,
HEADER_NAME_MAP,
LITTLE_ENDIAN,
PROTOCOL_VERSION,
HeaderField,
)
MAX_UNIX_FDS = 16
UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"}
UNPACK_LENGTHS = {BIG_ENDIAN: Struct(">III"), LITTLE_ENDIAN: Struct("<III")}
DBUS_TO_CTYPE = {
"y": ("B", 1), # byte
"n": ("h", 2), # int16
"q": ("H", 2), # uint16
"i": ("i", 4), # int32
"u": ("I", 4), # uint32
"x": ("q", 8), # int64
"t": ("Q", 8), # uint64
"d": ("d", 8), # double
"h": ("I", 4), # uint32
}
HEADER_SIGNATURE_SIZE = 16
HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION = 12
UINT32_SIGNATURE = SignatureTree._get("u").types[0]
HEADER_DESTINATION = HeaderField.DESTINATION.name
HEADER_PATH = HeaderField.PATH.name
HEADER_INTERFACE = HeaderField.INTERFACE.name
HEADER_MEMBER = HeaderField.MEMBER.name
HEADER_ERROR_NAME = HeaderField.ERROR_NAME.name
HEADER_REPLY_SERIAL = HeaderField.REPLY_SERIAL.name
HEADER_SENDER = HeaderField.SENDER.name
READER_TYPE = Dict[
str,
Tuple[
Optional[Callable[["Unmarshaller", SignatureType], Any]],
Optional[str],
Optional[int],
Optional[Struct],
],
]
class MarshallerStreamEndError(Exception):
"""This exception is raised when the end of the stream is reached.
This means more data is expected on the wire that has not yet been
received. The caller should call unmarshall later when more data is
available.
"""
pass
#
# Alignment padding is handled with the following formula below
#
# For any align value, the correct padding formula is:
#
# (align - (offset % align)) % align
#
# However, if align is a power of 2 (always the case here), the slow MOD
# operator can be replaced by a bitwise AND:
#
# (align - (offset & (align - 1))) & (align - 1)
#
# Which can be simplified to:
#
# (-offset) & (align - 1)
#
#
class Unmarshaller:
def __init__(self, stream, sock=None):
self.unix_fds = []
self.buf = bytearray()
buf: bytearray
view: memoryview
message: Message
unpack: Dict[str, Struct]
readers: READER_TYPE
def __init__(self, stream: io.BufferedRWPair, sock=None):
self.unix_fds: List[int] = []
self.can_cast = False
self.buf = bytearray() # Actual buffer
self.view = None # Memory view of the buffer
self.offset = 0
self.stream = stream
self.sock = sock
self.endian = None
self.message = None
self.readers = None
self.body_len: int | None = None
self.serial: int | None = None
self.header_len: int | None = None
self.message_type: MessageType | None = None
self.flag: MessageFlag | None = None
self.readers = {
"y": self.read_byte,
"b": self.read_boolean,
"n": self.read_int16,
"q": self.read_uint16,
"i": self.read_int32,
"u": self.read_uint32,
"x": self.read_int64,
"t": self.read_uint64,
"d": self.read_double,
"h": self.read_uint32,
"o": self.read_string,
"s": self.read_string,
"g": self.read_signature,
"a": self.read_array,
"(": self.read_struct,
"{": self.read_dict_entry,
"v": self.read_variant,
}
def read_sock(self, length: int) -> bytes:
"""reads from the socket, storing any fds sent and handling errors
from the read itself"""
unix_fd_list = array.array("i")
def read(self, n, prefetch=False):
try:
msg, ancdata, *_ = self.sock.recvmsg(
length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)
)
except BlockingIOError:
raise MarshallerStreamEndError()
for level, type_, data in ancdata:
if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS):
continue
unix_fd_list.frombytes(
data[: len(data) - (len(data) % unix_fd_list.itemsize)]
)
self.unix_fds.extend(list(unix_fd_list))
return msg
def read_to_offset(self, offset: int) -> None:
"""
Read from underlying socket into buffer and advance offset accordingly.
Read from underlying socket into buffer.
:arg n:
Number of bytes to read. If not enough bytes are available in the
Raises MarshallerStreamEndError if there is not enough data to be read.
:arg offset:
The offset to read to. If not enough bytes are available in the
buffer, read more from it.
:arg prefetch:
Do not update current offset after reading.
:returns:
Previous offset (before reading). To get the actual read bytes,
use the returned value and self.buf.
None
"""
def read_sock(length):
"""reads from the socket, storing any fds sent and handling errors
from the read itself"""
if self.sock is not None:
unix_fd_list = array.array("i")
try:
msg, ancdata, *_ = self.sock.recvmsg(
length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)
)
except BlockingIOError:
raise MarshallerStreamEndError()
for level, type_, data in ancdata:
if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS):
continue
unix_fd_list.frombytes(
data[: len(data) - (len(data) % unix_fd_list.itemsize)]
)
self.unix_fds.extend(list(unix_fd_list))
return msg
else:
return self.stream.read(length)
# store previously read data in a buffer so we can resume on socket
# interruptions
missing_bytes = n - (len(self.buf) - self.offset)
if missing_bytes > 0:
data = read_sock(missing_bytes)
if data == b"":
raise EOFError()
elif data is None:
raise MarshallerStreamEndError()
self.buf.extend(data)
if len(data) != missing_bytes:
raise MarshallerStreamEndError()
prev = self.offset
if not prefetch:
self.offset += n
return prev
@staticmethod
def _padding(offset, align):
"""
Get padding bytes to get to the next align bytes mark.
For any align value, the correct padding formula is:
(align - (offset % align)) % align
However, if align is a power of 2 (always the case here), the slow MOD
operator can be replaced by a bitwise AND:
(align - (offset & (align - 1))) & (align - 1)
Which can be simplified to:
(-offset) & (align - 1)
"""
return (-offset) & (align - 1)
def align(self, n):
padding = self._padding(self.offset, n)
if padding > 0:
self.read(padding)
def read_byte(self, _=None):
return self.buf[self.read(1)]
start_len = len(self.buf)
missing_bytes = offset - (start_len - self.offset)
if self.sock is None:
data = self.stream.read(missing_bytes)
else:
data = self.read_sock(missing_bytes)
if data == b"":
raise EOFError()
if data is None:
raise MarshallerStreamEndError()
self.buf.extend(data)
if len(data) + start_len != offset:
raise MarshallerStreamEndError()
def read_boolean(self, _=None):
data = self.read_uint32()
if data:
return True
else:
return False
def read_int16(self, _=None):
return self.read_ctype("h", 2)
def read_uint16(self, _=None):
return self.read_ctype("H", 2)
def read_int32(self, _=None):
return self.read_ctype("i", 4)
def read_uint32(self, _=None):
return self.read_ctype("I", 4)
def read_int64(self, _=None):
return self.read_ctype("q", 8)
def read_uint64(self, _=None):
return self.read_ctype("Q", 8)
def read_double(self, _=None):
return self.read_ctype("d", 8)
def read_ctype(self, fmt, size):
self.align(size)
if self.endian == LITTLE_ENDIAN:
fmt = "<" + fmt
else:
fmt = ">" + fmt
o = self.read(size)
return unpack_from(fmt, self.buf, o)[0]
return bool(self.read_argument(UINT32_SIGNATURE))
def read_string(self, _=None):
str_length = self.read_uint32()
o = self.read(str_length + 1) # read terminating '\0' byte as well
# avoid buffer copies when slicing
str_mem_slice = memoryview(self.buf)[o : o + str_length]
return decode(str_mem_slice)
str_length = self.read_argument(UINT32_SIGNATURE)
str_start = self.offset
# read terminating '\0' byte as well (str_length + 1)
self.offset += str_length + 1
return self.buf[str_start : str_start + str_length].decode()
def read_signature(self, _=None):
signature_len = self.read_byte()
o = self.read(signature_len + 1) # read terminating '\0' byte as well
# avoid buffer copies when slicing
sig_mem_slice = memoryview(self.buf)[o : o + signature_len]
return decode(sig_mem_slice)
signature_len = self.view[self.offset] # byte
o = self.offset + 1
# read terminating '\0' byte as well (str_length + 1)
self.offset = o + signature_len + 1
return self.buf[o : o + signature_len].decode()
def read_variant(self, _=None):
signature = self.read_signature()
signature_tree = SignatureTree._get(signature)
value = self.read_argument(signature_tree.types[0])
return Variant(signature_tree, value)
tree = SignatureTree._get(self.read_signature())
# verify in Variant is only useful on construction not unmarshalling
return Variant(tree, self.read_argument(tree.types[0]), verify=False)
def read_struct(self, type_):
self.align(8)
def read_struct(self, type_: SignatureType):
self.offset += -self.offset & 7 # align 8
return [self.read_argument(child_type) for child_type in type_.children]
result = []
for child_type in type_.children:
result.append(self.read_argument(child_type))
def read_dict_entry(self, type_: SignatureType):
self.offset += -self.offset & 7 # align 8
return self.read_argument(type_.children[0]), self.read_argument(
type_.children[1]
)
return result
def read_dict_entry(self, type_):
self.align(8)
key = self.read_argument(type_.children[0])
value = self.read_argument(type_.children[1])
return key, value
def read_array(self, type_):
self.align(4)
array_length = self.read_uint32()
def read_array(self, type_: SignatureType):
self.offset += -self.offset & 3 # align 4 for the array
array_length = self.read_argument(UINT32_SIGNATURE)
child_type = type_.children[0]
if child_type.token in "xtd{(":
# the first alignment is not included in the array size
self.align(8)
self.offset += -self.offset & 7 # align 8
if child_type.token == "y":
self.offset += array_length
return self.buf[self.offset - array_length : self.offset]
beginning_offset = self.offset
result = None
if child_type.token == "{":
result = {}
result_dict = {}
while self.offset - beginning_offset < array_length:
key, value = self.read_dict_entry(child_type)
result[key] = value
elif child_type.token == "y":
o = self.read(array_length)
# avoid buffer copies when slicing
array_mem_slice = memoryview(self.buf)[o : o + array_length]
result = array_mem_slice.tobytes()
else:
result = []
while self.offset - beginning_offset < array_length:
result.append(self.read_argument(child_type))
result_dict[key] = value
return result_dict
return result
result_list = []
while self.offset - beginning_offset < array_length:
result_list.append(self.read_argument(child_type))
return result_list
def read_argument(self, type_):
t = type_.token
def read_argument(self, type_: SignatureType) -> Any:
"""Dispatch to an argument reader or cast/unpack a C type."""
token = type_.token
reader, ctype, size, struct = self.readers[token]
if reader: # complex type
return reader(self, type_)
self.offset += size + (-self.offset & (size - 1)) # align
if self.can_cast:
return self.view[self.offset - size : self.offset].cast(ctype)[0]
return struct.unpack_from(self.view, self.offset - size)[0]
if t not in self.readers:
raise Exception(f'dont know how to read yet: "{t}"')
def header_fields(self, header_length):
"""Header fields are always a(yv)."""
beginning_offset = self.offset
headers = {}
while self.offset - beginning_offset < header_length:
# Now read the y (byte) of struct (yv)
self.offset += (-self.offset & 7) + 1 # align 8 + 1 for 'y' byte
field_0 = self.view[self.offset - 1]
return self.readers[t](type_)
# Now read the v (variant) of struct (yv)
signature_len = self.view[self.offset] # byte
o = self.offset + 1
self.offset += signature_len + 2 # one for the byte, one for the '\0'
tree = SignatureTree._get(self.buf[o : o + signature_len].decode())
headers[HEADER_NAME_MAP[field_0]] = self.read_argument(tree.types[0])
return headers
def _unmarshall(self):
self.offset = 0
self.read(16, prefetch=True)
self.endian = self.read_byte()
if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
raise InvalidMessageError("Expecting endianness as the first byte")
message_type = MessageType(self.read_byte())
flags = MessageFlag(self.read_byte())
protocol_version = self.read_byte()
def _read_header(self):
"""Read the header of the message."""
# Signature is of the header is
# BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT)
self.read_to_offset(HEADER_SIGNATURE_SIZE)
buffer = self.buf
endian = buffer[0]
self.message_type = MESSAGE_TYPE_MAP[buffer[1]]
self.flag = MESSAGE_FLAG_MAP[buffer[2]]
protocol_version = buffer[3]
if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN:
raise InvalidMessageError(
f"Expecting endianness as the first byte, got {endian} from {buffer}"
)
if protocol_version != PROTOCOL_VERSION:
raise InvalidMessageError(
f"got unknown protocol version: {protocol_version}"
)
body_len = self.read_uint32()
serial = self.read_uint32()
header_len = self.read_uint32()
msg_len = header_len + self._padding(header_len, 8) + body_len
self.read(msg_len, prefetch=True)
# backtrack offset since header array length needs to be read again
self.offset -= 4
header_fields = {}
for field_struct in self.read_argument(SignatureTree._get("a(yv)").types[0]):
field = HeaderField(field_struct[0])
header_fields[field.name] = field_struct[1].value
self.align(8)
path = header_fields.get(HeaderField.PATH.name)
interface = header_fields.get(HeaderField.INTERFACE.name)
member = header_fields.get(HeaderField.MEMBER.name)
error_name = header_fields.get(HeaderField.ERROR_NAME.name)
reply_serial = header_fields.get(HeaderField.REPLY_SERIAL.name)
destination = header_fields.get(HeaderField.DESTINATION.name)
sender = header_fields.get(HeaderField.SENDER.name)
signature = header_fields.get(HeaderField.SIGNATURE.name, "")
signature_tree = SignatureTree._get(signature)
# unix_fds = header_fields.get(HeaderField.UNIX_FDS.name, 0)
body = []
if body_len:
for type_ in signature_tree.types:
body.append(self.read_argument(type_))
self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[
endian
].unpack_from(buffer, 4)
self.msg_len = (
self.header_len + (-self.header_len & 7) + self.body_len
) # align 8
if (sys.byteorder == "little" and endian == LITTLE_ENDIAN) or (
sys.byteorder == "big" and endian == BIG_ENDIAN
):
self.can_cast = True
self.readers = self._readers_by_type[endian]
def _read_body(self):
"""Read the body of the message."""
self.read_to_offset(HEADER_SIGNATURE_SIZE + self.msg_len)
self.view = memoryview(self.buf)
self.offset = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION
header_fields = self.header_fields(self.header_len)
self.offset += -self.offset & 7 # align 8
tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, ""))
self.message = Message(
destination=destination,
path=path,
interface=interface,
member=member,
message_type=message_type,
flags=flags,
error_name=error_name,
reply_serial=reply_serial,
sender=sender,
destination=header_fields.get(HEADER_DESTINATION),
path=header_fields.get(HEADER_PATH),
interface=header_fields.get(HEADER_INTERFACE),
member=header_fields.get(HEADER_MEMBER),
message_type=self.message_type,
flags=self.flag,
error_name=header_fields.get(HEADER_ERROR_NAME),
reply_serial=header_fields.get(HEADER_REPLY_SERIAL),
sender=header_fields.get(HEADER_SENDER),
unix_fds=self.unix_fds,
signature=signature_tree,
body=body,
serial=serial,
signature=tree.signature,
body=[self.read_argument(t) for t in tree.types] if self.body_len else [],
serial=self.serial,
)
def unmarshall(self):
"""Unmarshall the message.
The underlying read function will raise MarshallerStreamEndError
if there are not enough bytes in the buffer. This allows unmarshall
to be resumed when more data comes in over the wire.
"""
try:
self._unmarshall()
return self.message
if not self.message_type:
self._read_header()
self._read_body()
except MarshallerStreamEndError:
return None
return self.message
_complex_parsers: Dict[
str, Tuple[Callable[["Unmarshaller", SignatureType], Any], None, None, None]
] = {
"b": (read_boolean, None, None, None),
"o": (read_string, None, None, None),
"s": (read_string, None, None, None),
"g": (read_signature, None, None, None),
"a": (read_array, None, None, None),
"(": (read_struct, None, None, None),
"{": (read_dict_entry, None, None, None),
"v": (read_variant, None, None, None),
}
_ctype_by_endian: Dict[int, Dict[str, Tuple[None, str, int, Struct]]] = {
endian: {
dbus_type: (
None,
*ctype_size,
Struct(f"{UNPACK_SYMBOL[endian]}{ctype_size[0]}"),
)
for dbus_type, ctype_size in DBUS_TO_CTYPE.items()
}
for endian in (BIG_ENDIAN, LITTLE_ENDIAN)
}
_readers_by_type: Dict[int, READER_TYPE] = {
BIG_ENDIAN: {**_ctype_by_endian[BIG_ENDIAN], **_complex_parsers},
LITTLE_ENDIAN: {**_ctype_by_endian[LITTLE_ENDIAN], **_complex_parsers},
}

View File

@ -19,6 +19,9 @@ class MessageType(Enum):
SIGNAL = 4 #: A broadcast signal to subscribed connections
MESSAGE_TYPE_MAP = {field.value: field for field in MessageType}
class MessageFlag(IntFlag):
"""Flags that affect the behavior of sent and received messages"""
@ -28,6 +31,9 @@ class MessageFlag(IntFlag):
ALLOW_INTERACTIVE_AUTHORIZATION = 4
MESSAGE_FLAG_MAP = {field.value: field for field in MessageFlag}
class NameFlag(IntFlag):
"""A flag that affects the behavior of a name request."""

View File

@ -12,6 +12,13 @@ from .validators import (
assert_object_path_valid,
)
REQUIRED_FIELDS = {
MessageType.METHOD_CALL: ("path", "member"),
MessageType.SIGNAL: ("path", "member", "interface"),
MessageType.ERROR: ("error_name", "reply_serial"),
MessageType.METHOD_RETURN: ("reply_serial",),
}
class Message:
"""A class for sending and receiving messages through the
@ -112,21 +119,12 @@ class Message:
if self.error_name is not None:
assert_interface_name_valid(self.error_name)
def require_fields(*fields):
for field in fields:
if not getattr(self, field):
raise InvalidMessageError(f"missing required field: {field}")
if self.message_type == MessageType.METHOD_CALL:
require_fields("path", "member")
elif self.message_type == MessageType.SIGNAL:
require_fields("path", "member", "interface")
elif self.message_type == MessageType.ERROR:
require_fields("error_name", "reply_serial")
elif self.message_type == MessageType.METHOD_RETURN:
require_fields("reply_serial")
else:
required_fields = REQUIRED_FIELDS.get(self.message_type)
if not required_fields:
raise InvalidMessageError(f"got unknown message type: {self.message_type}")
for field in required_fields:
if not getattr(self, field):
raise InvalidMessageError(f"missing required field: {field}")
@staticmethod
def new_error(msg: "Message", error_name: str, error_text: str) -> "Message":

View File

@ -1,3 +1,4 @@
from functools import lru_cache
from typing import Any, List, Union
from .errors import InvalidSignatureError, SignatureBodyMismatchError
@ -21,9 +22,9 @@ class SignatureType:
_tokens = "ybnqiuxtdsogavh({"
def __init__(self, token):
def __init__(self, token: str) -> None:
self.token = token
self.children = []
self.children: List[SignatureType] = []
self._signature = None
def __eq__(self, other):
@ -240,7 +241,7 @@ class SignatureType:
child_type.children[0].verify(key)
child_type.children[1].verify(value)
elif child_type.token == "y":
if not isinstance(body, bytes):
if not isinstance(body, (bytearray, bytes)):
raise SignatureBodyMismatchError(
f'DBus ARRAY type "a" with BYTE child must be Python type "bytes", got {type(body)}'
)
@ -284,43 +285,33 @@ class SignatureType:
"""
if body is None:
raise SignatureBodyMismatchError('Cannot serialize Python type "None"')
elif self.token == "y":
self._verify_byte(body)
elif self.token == "b":
self._verify_boolean(body)
elif self.token == "n":
self._verify_int16(body)
elif self.token == "q":
self._verify_uint16(body)
elif self.token == "i":
self._verify_int32(body)
elif self.token == "u":
self._verify_uint32(body)
elif self.token == "x":
self._verify_int64(body)
elif self.token == "t":
self._verify_uint64(body)
elif self.token == "d":
self._verify_double(body)
elif self.token == "h":
self._verify_unix_fd(body)
elif self.token == "o":
self._verify_object_path(body)
elif self.token == "s":
self._verify_string(body)
elif self.token == "g":
self._verify_signature(body)
elif self.token == "a":
self._verify_array(body)
elif self.token == "(":
self._verify_struct(body)
elif self.token == "v":
self._verify_variant(body)
validator = self.validators.get(self.token)
if validator:
validator(self, body)
else:
raise Exception(f"cannot verify type with token {self.token}")
return True
validators = {
"y": _verify_byte,
"b": _verify_boolean,
"n": _verify_int16,
"q": _verify_uint16,
"i": _verify_int32,
"u": _verify_uint32,
"x": _verify_int64,
"t": _verify_uint64,
"d": _verify_double,
"h": _verify_uint32,
"o": _verify_string,
"s": _verify_string,
"g": _verify_signature,
"a": _verify_array,
"(": _verify_struct,
"v": _verify_variant,
}
class SignatureTree:
"""A class that represents a signature as a tree structure for conveniently
@ -338,19 +329,15 @@ class SignatureTree:
:class:`InvalidSignatureError` if the given signature is not valid.
"""
_cache = {}
@staticmethod
def _get(signature: str = ""):
if signature in SignatureTree._cache:
return SignatureTree._cache[signature]
SignatureTree._cache[signature] = SignatureTree(signature)
return SignatureTree._cache[signature]
@lru_cache(maxsize=None)
def _get(signature: str = "") -> "SignatureTree":
return SignatureTree(signature)
def __init__(self, signature: str = ""):
self.signature = signature
self.types = []
self.types: List[SignatureType] = []
if len(signature) > 0xFF:
raise InvalidSignatureError("A signature must be less than 256 characters")
@ -411,7 +398,12 @@ class Variant:
:class:`SignatureBodyMismatchError` if the signature does not match the body.
"""
def __init__(self, signature: Union[str, SignatureTree, SignatureType], value: Any):
def __init__(
self,
signature: Union[str, SignatureTree, SignatureType],
value: Any,
verify: bool = True,
):
signature_str = ""
signature_tree = None
signature_type = None
@ -429,14 +421,15 @@ class Variant:
)
if signature_tree:
if len(signature_tree.types) != 1:
if verify and len(signature_tree.types) != 1:
raise ValueError(
"variants must have a signature for a single complete type"
)
signature_str = signature_tree.signature
signature_type = signature_tree.types[0]
signature_type.verify(value)
if verify:
signature_type.verify(value)
self.type = signature_type
self.signature = signature_str

View File

@ -1,4 +1,5 @@
import re
from functools import lru_cache
from .errors import (
InvalidBusNameError,
@ -13,6 +14,7 @@ _element_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
_member_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_-]*$")
@lru_cache(maxsize=32)
def is_bus_name_valid(name: str) -> bool:
"""Whether this is a valid bus name.
@ -47,6 +49,7 @@ def is_bus_name_valid(name: str) -> bool:
return True
@lru_cache(maxsize=512)
def is_object_path_valid(path: str) -> bool:
"""Whether this is a valid object path.
@ -77,6 +80,7 @@ def is_object_path_valid(path: str) -> bool:
return True
@lru_cache(maxsize=32)
def is_interface_name_valid(name: str) -> bool:
"""Whether this is a valid interface name.
@ -107,6 +111,7 @@ def is_interface_name_valid(name: str) -> bool:
return True
@lru_cache(maxsize=512)
def is_member_name_valid(member: str) -> bool:
"""Whether this is a valid member name.

View File

@ -1,8 +1,11 @@
import io
import json
import os
from typing import Any, Dict
from dbus_fast import Message, SignatureTree, Variant
import pytest
from dbus_fast import Message, MessageFlag, MessageType, SignatureTree, Variant
from dbus_fast._private.unmarshaller import Unmarshaller
@ -20,6 +23,16 @@ def print_buf(buf):
table = json.load(open(os.path.dirname(__file__) + "/data/messages.json"))
def json_to_message(message: Dict[str, Any]) -> Message:
copy = dict(message)
if "message_type" in copy:
copy["message_type"] = MessageType(copy["message_type"])
if "flags" in copy:
copy["flags"] = MessageFlag(copy["flags"])
return Message(**copy)
# variants are an object in the json
def replace_variants(type_, item):
if type_.token == "v" and type(item) is not Variant:
@ -56,7 +69,7 @@ def json_dump(what):
def test_marshalling_with_table():
for item in table:
message = Message(**item["message"])
message = json_to_message(item["message"])
body = []
for i, type_ in enumerate(message.signature_tree.types):
@ -79,8 +92,9 @@ def test_marshalling_with_table():
assert buf == data
def test_unmarshalling_with_table():
for item in table:
@pytest.mark.parametrize("unmarshall_table", (table,))
def test_unmarshalling_with_table(unmarshall_table):
for item in unmarshall_table:
stream = io.BytesIO(bytes.fromhex(item["data"]))
unmarshaller = Unmarshaller(stream)
@ -91,7 +105,7 @@ def test_unmarshalling_with_table():
print(json_dump(item["message"]))
raise e
message = Message(**item["message"])
message = json_to_message(item["message"])
body = []
for i, type_ in enumerate(message.signature_tree.types):
@ -114,6 +128,39 @@ def test_unmarshalling_with_table():
), f"attr doesnt match: {attr}"
def test_unmarshall_can_resume():
"""Verify resume works."""
bluez_rssi_message = (
"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465"
"765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b"
"746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e"
"67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000"
"110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001"
"6e00a7ff000000000000"
)
message_bytes = bytes.fromhex(bluez_rssi_message)
class SlowStream(io.IOBase):
"""A fake stream that will only give us one byte at a time."""
def __init__(self):
self.data = message_bytes
self.pos = 0
def read(self, n) -> bytes:
data = self.data[self.pos : self.pos + 1]
self.pos += 1
return data
stream = SlowStream()
unmarshaller = Unmarshaller(stream)
for _ in range(len(bluez_rssi_message)):
if unmarshaller.unmarshall():
break
assert unmarshaller.message is not None
def test_ay_buffer():
body = [bytes(10000)]
msg = Message(path="/test", member="test", signature="ay", body=body)

View File

@ -10,7 +10,6 @@ def test_object_path_validator():
valid_paths = ["/", "/foo", "/foo/bar", "/foo/bar/bat"]
invalid_paths = [
None,
{},
"",
"foo",
"foo/bar",
@ -37,7 +36,6 @@ def test_bus_name_validator():
]
invalid_names = [
None,
{},
"",
"5foo.bar",
"foo.6bar",
@ -57,7 +55,6 @@ def test_interface_name_validator():
valid_names = ["foo.bar", "foo.bar.bat", "_foo._bar", "foo.bar69"]
invalid_names = [
None,
{},
"",
"5foo.bar",
"foo.6bar",
@ -80,7 +77,7 @@ def test_interface_name_validator():
def test_member_name_validator():
valid_members = ["foo", "FooBar", "Bat_Baz69", "foo-bar"]
invalid_members = [None, {}, "", "foo.bar", "5foo", "foo$bar"]
invalid_members = [None, "", "foo.bar", "5foo", "foo$bar"]
for member in valid_members:
assert is_member_name_valid(member), f'member name should be valid: "{member}"'