feat: speed up unmarshaller (#1)
This commit is contained in:
parent
2c9cdcc173
commit
eca1d31781
22
bench/unmarshall.py
Normal file
22
bench/unmarshall.py
Normal file
@ -0,0 +1,22 @@
|
||||
import io
|
||||
import timeit
|
||||
|
||||
from dbus_fast._private.unmarshaller import Unmarshaller
|
||||
|
||||
bluez_rssi_message = (
|
||||
"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465"
|
||||
"765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b"
|
||||
"746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e"
|
||||
"67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000"
|
||||
"110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001"
|
||||
"6e00a7ff000000000000"
|
||||
)
|
||||
|
||||
|
||||
def unmarhsall_bluez_rssi_message():
|
||||
Unmarshaller(io.BytesIO(bytes.fromhex(bluez_rssi_message))).unmarshall()
|
||||
|
||||
|
||||
count = 1000000
|
||||
time = timeit.Timer(unmarhsall_bluez_rssi_message).timeit(count)
|
||||
print(f"Unmarshalling {count} bluetooth rssi messages took {time} seconds")
|
||||
@ -16,3 +16,6 @@ class HeaderField(Enum):
|
||||
SENDER = 7
|
||||
SIGNATURE = 8
|
||||
UNIX_FDS = 9
|
||||
|
||||
|
||||
HEADER_NAME_MAP = {field.value: field.name for field in HeaderField}
|
||||
|
||||
@ -1,315 +1,348 @@
|
||||
import array
|
||||
import io
|
||||
import socket
|
||||
from codecs import decode
|
||||
from struct import unpack_from
|
||||
import sys
|
||||
from struct import Struct
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from ..constants import MessageFlag, MessageType
|
||||
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP, MessageFlag, MessageType
|
||||
from ..errors import InvalidMessageError
|
||||
from ..message import Message
|
||||
from ..signature import SignatureTree, Variant
|
||||
from .constants import BIG_ENDIAN, LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField
|
||||
from ..signature import SignatureTree, SignatureType, Variant
|
||||
from .constants import (
|
||||
BIG_ENDIAN,
|
||||
HEADER_NAME_MAP,
|
||||
LITTLE_ENDIAN,
|
||||
PROTOCOL_VERSION,
|
||||
HeaderField,
|
||||
)
|
||||
|
||||
MAX_UNIX_FDS = 16
|
||||
|
||||
UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"}
|
||||
UNPACK_LENGTHS = {BIG_ENDIAN: Struct(">III"), LITTLE_ENDIAN: Struct("<III")}
|
||||
|
||||
DBUS_TO_CTYPE = {
|
||||
"y": ("B", 1), # byte
|
||||
"n": ("h", 2), # int16
|
||||
"q": ("H", 2), # uint16
|
||||
"i": ("i", 4), # int32
|
||||
"u": ("I", 4), # uint32
|
||||
"x": ("q", 8), # int64
|
||||
"t": ("Q", 8), # uint64
|
||||
"d": ("d", 8), # double
|
||||
"h": ("I", 4), # uint32
|
||||
}
|
||||
|
||||
HEADER_SIGNATURE_SIZE = 16
|
||||
HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION = 12
|
||||
|
||||
UINT32_SIGNATURE = SignatureTree._get("u").types[0]
|
||||
|
||||
HEADER_DESTINATION = HeaderField.DESTINATION.name
|
||||
HEADER_PATH = HeaderField.PATH.name
|
||||
HEADER_INTERFACE = HeaderField.INTERFACE.name
|
||||
HEADER_MEMBER = HeaderField.MEMBER.name
|
||||
HEADER_ERROR_NAME = HeaderField.ERROR_NAME.name
|
||||
HEADER_REPLY_SERIAL = HeaderField.REPLY_SERIAL.name
|
||||
HEADER_SENDER = HeaderField.SENDER.name
|
||||
|
||||
READER_TYPE = Dict[
|
||||
str,
|
||||
Tuple[
|
||||
Optional[Callable[["Unmarshaller", SignatureType], Any]],
|
||||
Optional[str],
|
||||
Optional[int],
|
||||
Optional[Struct],
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class MarshallerStreamEndError(Exception):
|
||||
"""This exception is raised when the end of the stream is reached.
|
||||
|
||||
This means more data is expected on the wire that has not yet been
|
||||
received. The caller should call unmarshall later when more data is
|
||||
available.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
#
|
||||
# Alignment padding is handled with the following formula below
|
||||
#
|
||||
# For any align value, the correct padding formula is:
|
||||
#
|
||||
# (align - (offset % align)) % align
|
||||
#
|
||||
# However, if align is a power of 2 (always the case here), the slow MOD
|
||||
# operator can be replaced by a bitwise AND:
|
||||
#
|
||||
# (align - (offset & (align - 1))) & (align - 1)
|
||||
#
|
||||
# Which can be simplified to:
|
||||
#
|
||||
# (-offset) & (align - 1)
|
||||
#
|
||||
#
|
||||
class Unmarshaller:
|
||||
def __init__(self, stream, sock=None):
|
||||
self.unix_fds = []
|
||||
self.buf = bytearray()
|
||||
|
||||
buf: bytearray
|
||||
view: memoryview
|
||||
message: Message
|
||||
unpack: Dict[str, Struct]
|
||||
readers: READER_TYPE
|
||||
|
||||
def __init__(self, stream: io.BufferedRWPair, sock=None):
|
||||
self.unix_fds: List[int] = []
|
||||
self.can_cast = False
|
||||
self.buf = bytearray() # Actual buffer
|
||||
self.view = None # Memory view of the buffer
|
||||
self.offset = 0
|
||||
self.stream = stream
|
||||
self.sock = sock
|
||||
self.endian = None
|
||||
self.message = None
|
||||
self.readers = None
|
||||
self.body_len: int | None = None
|
||||
self.serial: int | None = None
|
||||
self.header_len: int | None = None
|
||||
self.message_type: MessageType | None = None
|
||||
self.flag: MessageFlag | None = None
|
||||
|
||||
self.readers = {
|
||||
"y": self.read_byte,
|
||||
"b": self.read_boolean,
|
||||
"n": self.read_int16,
|
||||
"q": self.read_uint16,
|
||||
"i": self.read_int32,
|
||||
"u": self.read_uint32,
|
||||
"x": self.read_int64,
|
||||
"t": self.read_uint64,
|
||||
"d": self.read_double,
|
||||
"h": self.read_uint32,
|
||||
"o": self.read_string,
|
||||
"s": self.read_string,
|
||||
"g": self.read_signature,
|
||||
"a": self.read_array,
|
||||
"(": self.read_struct,
|
||||
"{": self.read_dict_entry,
|
||||
"v": self.read_variant,
|
||||
}
|
||||
def read_sock(self, length: int) -> bytes:
|
||||
"""reads from the socket, storing any fds sent and handling errors
|
||||
from the read itself"""
|
||||
unix_fd_list = array.array("i")
|
||||
|
||||
def read(self, n, prefetch=False):
|
||||
try:
|
||||
msg, ancdata, *_ = self.sock.recvmsg(
|
||||
length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)
|
||||
)
|
||||
except BlockingIOError:
|
||||
raise MarshallerStreamEndError()
|
||||
|
||||
for level, type_, data in ancdata:
|
||||
if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS):
|
||||
continue
|
||||
unix_fd_list.frombytes(
|
||||
data[: len(data) - (len(data) % unix_fd_list.itemsize)]
|
||||
)
|
||||
self.unix_fds.extend(list(unix_fd_list))
|
||||
|
||||
return msg
|
||||
|
||||
def read_to_offset(self, offset: int) -> None:
|
||||
"""
|
||||
Read from underlying socket into buffer and advance offset accordingly.
|
||||
Read from underlying socket into buffer.
|
||||
|
||||
:arg n:
|
||||
Number of bytes to read. If not enough bytes are available in the
|
||||
Raises MarshallerStreamEndError if there is not enough data to be read.
|
||||
|
||||
:arg offset:
|
||||
The offset to read to. If not enough bytes are available in the
|
||||
buffer, read more from it.
|
||||
:arg prefetch:
|
||||
Do not update current offset after reading.
|
||||
|
||||
:returns:
|
||||
Previous offset (before reading). To get the actual read bytes,
|
||||
use the returned value and self.buf.
|
||||
None
|
||||
"""
|
||||
|
||||
def read_sock(length):
|
||||
"""reads from the socket, storing any fds sent and handling errors
|
||||
from the read itself"""
|
||||
if self.sock is not None:
|
||||
unix_fd_list = array.array("i")
|
||||
|
||||
try:
|
||||
msg, ancdata, *_ = self.sock.recvmsg(
|
||||
length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)
|
||||
)
|
||||
except BlockingIOError:
|
||||
raise MarshallerStreamEndError()
|
||||
|
||||
for level, type_, data in ancdata:
|
||||
if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS):
|
||||
continue
|
||||
unix_fd_list.frombytes(
|
||||
data[: len(data) - (len(data) % unix_fd_list.itemsize)]
|
||||
)
|
||||
self.unix_fds.extend(list(unix_fd_list))
|
||||
|
||||
return msg
|
||||
else:
|
||||
return self.stream.read(length)
|
||||
|
||||
# store previously read data in a buffer so we can resume on socket
|
||||
# interruptions
|
||||
missing_bytes = n - (len(self.buf) - self.offset)
|
||||
if missing_bytes > 0:
|
||||
data = read_sock(missing_bytes)
|
||||
if data == b"":
|
||||
raise EOFError()
|
||||
elif data is None:
|
||||
raise MarshallerStreamEndError()
|
||||
self.buf.extend(data)
|
||||
if len(data) != missing_bytes:
|
||||
raise MarshallerStreamEndError()
|
||||
prev = self.offset
|
||||
if not prefetch:
|
||||
self.offset += n
|
||||
return prev
|
||||
|
||||
@staticmethod
|
||||
def _padding(offset, align):
|
||||
"""
|
||||
Get padding bytes to get to the next align bytes mark.
|
||||
|
||||
For any align value, the correct padding formula is:
|
||||
|
||||
(align - (offset % align)) % align
|
||||
|
||||
However, if align is a power of 2 (always the case here), the slow MOD
|
||||
operator can be replaced by a bitwise AND:
|
||||
|
||||
(align - (offset & (align - 1))) & (align - 1)
|
||||
|
||||
Which can be simplified to:
|
||||
|
||||
(-offset) & (align - 1)
|
||||
"""
|
||||
return (-offset) & (align - 1)
|
||||
|
||||
def align(self, n):
|
||||
padding = self._padding(self.offset, n)
|
||||
if padding > 0:
|
||||
self.read(padding)
|
||||
|
||||
def read_byte(self, _=None):
|
||||
return self.buf[self.read(1)]
|
||||
start_len = len(self.buf)
|
||||
missing_bytes = offset - (start_len - self.offset)
|
||||
if self.sock is None:
|
||||
data = self.stream.read(missing_bytes)
|
||||
else:
|
||||
data = self.read_sock(missing_bytes)
|
||||
if data == b"":
|
||||
raise EOFError()
|
||||
if data is None:
|
||||
raise MarshallerStreamEndError()
|
||||
self.buf.extend(data)
|
||||
if len(data) + start_len != offset:
|
||||
raise MarshallerStreamEndError()
|
||||
|
||||
def read_boolean(self, _=None):
|
||||
data = self.read_uint32()
|
||||
if data:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def read_int16(self, _=None):
|
||||
return self.read_ctype("h", 2)
|
||||
|
||||
def read_uint16(self, _=None):
|
||||
return self.read_ctype("H", 2)
|
||||
|
||||
def read_int32(self, _=None):
|
||||
return self.read_ctype("i", 4)
|
||||
|
||||
def read_uint32(self, _=None):
|
||||
return self.read_ctype("I", 4)
|
||||
|
||||
def read_int64(self, _=None):
|
||||
return self.read_ctype("q", 8)
|
||||
|
||||
def read_uint64(self, _=None):
|
||||
return self.read_ctype("Q", 8)
|
||||
|
||||
def read_double(self, _=None):
|
||||
return self.read_ctype("d", 8)
|
||||
|
||||
def read_ctype(self, fmt, size):
|
||||
self.align(size)
|
||||
if self.endian == LITTLE_ENDIAN:
|
||||
fmt = "<" + fmt
|
||||
else:
|
||||
fmt = ">" + fmt
|
||||
o = self.read(size)
|
||||
return unpack_from(fmt, self.buf, o)[0]
|
||||
return bool(self.read_argument(UINT32_SIGNATURE))
|
||||
|
||||
def read_string(self, _=None):
|
||||
str_length = self.read_uint32()
|
||||
o = self.read(str_length + 1) # read terminating '\0' byte as well
|
||||
# avoid buffer copies when slicing
|
||||
str_mem_slice = memoryview(self.buf)[o : o + str_length]
|
||||
return decode(str_mem_slice)
|
||||
str_length = self.read_argument(UINT32_SIGNATURE)
|
||||
str_start = self.offset
|
||||
# read terminating '\0' byte as well (str_length + 1)
|
||||
self.offset += str_length + 1
|
||||
return self.buf[str_start : str_start + str_length].decode()
|
||||
|
||||
def read_signature(self, _=None):
|
||||
signature_len = self.read_byte()
|
||||
o = self.read(signature_len + 1) # read terminating '\0' byte as well
|
||||
# avoid buffer copies when slicing
|
||||
sig_mem_slice = memoryview(self.buf)[o : o + signature_len]
|
||||
return decode(sig_mem_slice)
|
||||
signature_len = self.view[self.offset] # byte
|
||||
o = self.offset + 1
|
||||
# read terminating '\0' byte as well (str_length + 1)
|
||||
self.offset = o + signature_len + 1
|
||||
return self.buf[o : o + signature_len].decode()
|
||||
|
||||
def read_variant(self, _=None):
|
||||
signature = self.read_signature()
|
||||
signature_tree = SignatureTree._get(signature)
|
||||
value = self.read_argument(signature_tree.types[0])
|
||||
return Variant(signature_tree, value)
|
||||
tree = SignatureTree._get(self.read_signature())
|
||||
# verify in Variant is only useful on construction not unmarshalling
|
||||
return Variant(tree, self.read_argument(tree.types[0]), verify=False)
|
||||
|
||||
def read_struct(self, type_):
|
||||
self.align(8)
|
||||
def read_struct(self, type_: SignatureType):
|
||||
self.offset += -self.offset & 7 # align 8
|
||||
return [self.read_argument(child_type) for child_type in type_.children]
|
||||
|
||||
result = []
|
||||
for child_type in type_.children:
|
||||
result.append(self.read_argument(child_type))
|
||||
def read_dict_entry(self, type_: SignatureType):
|
||||
self.offset += -self.offset & 7 # align 8
|
||||
return self.read_argument(type_.children[0]), self.read_argument(
|
||||
type_.children[1]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def read_dict_entry(self, type_):
|
||||
self.align(8)
|
||||
|
||||
key = self.read_argument(type_.children[0])
|
||||
value = self.read_argument(type_.children[1])
|
||||
|
||||
return key, value
|
||||
|
||||
def read_array(self, type_):
|
||||
self.align(4)
|
||||
array_length = self.read_uint32()
|
||||
def read_array(self, type_: SignatureType):
|
||||
self.offset += -self.offset & 3 # align 4 for the array
|
||||
array_length = self.read_argument(UINT32_SIGNATURE)
|
||||
|
||||
child_type = type_.children[0]
|
||||
if child_type.token in "xtd{(":
|
||||
# the first alignment is not included in the array size
|
||||
self.align(8)
|
||||
self.offset += -self.offset & 7 # align 8
|
||||
|
||||
if child_type.token == "y":
|
||||
self.offset += array_length
|
||||
return self.buf[self.offset - array_length : self.offset]
|
||||
|
||||
beginning_offset = self.offset
|
||||
|
||||
result = None
|
||||
if child_type.token == "{":
|
||||
result = {}
|
||||
result_dict = {}
|
||||
while self.offset - beginning_offset < array_length:
|
||||
key, value = self.read_dict_entry(child_type)
|
||||
result[key] = value
|
||||
elif child_type.token == "y":
|
||||
o = self.read(array_length)
|
||||
# avoid buffer copies when slicing
|
||||
array_mem_slice = memoryview(self.buf)[o : o + array_length]
|
||||
result = array_mem_slice.tobytes()
|
||||
else:
|
||||
result = []
|
||||
while self.offset - beginning_offset < array_length:
|
||||
result.append(self.read_argument(child_type))
|
||||
result_dict[key] = value
|
||||
return result_dict
|
||||
|
||||
return result
|
||||
result_list = []
|
||||
while self.offset - beginning_offset < array_length:
|
||||
result_list.append(self.read_argument(child_type))
|
||||
return result_list
|
||||
|
||||
def read_argument(self, type_):
|
||||
t = type_.token
|
||||
def read_argument(self, type_: SignatureType) -> Any:
|
||||
"""Dispatch to an argument reader or cast/unpack a C type."""
|
||||
token = type_.token
|
||||
reader, ctype, size, struct = self.readers[token]
|
||||
if reader: # complex type
|
||||
return reader(self, type_)
|
||||
self.offset += size + (-self.offset & (size - 1)) # align
|
||||
if self.can_cast:
|
||||
return self.view[self.offset - size : self.offset].cast(ctype)[0]
|
||||
return struct.unpack_from(self.view, self.offset - size)[0]
|
||||
|
||||
if t not in self.readers:
|
||||
raise Exception(f'dont know how to read yet: "{t}"')
|
||||
def header_fields(self, header_length):
|
||||
"""Header fields are always a(yv)."""
|
||||
beginning_offset = self.offset
|
||||
headers = {}
|
||||
while self.offset - beginning_offset < header_length:
|
||||
# Now read the y (byte) of struct (yv)
|
||||
self.offset += (-self.offset & 7) + 1 # align 8 + 1 for 'y' byte
|
||||
field_0 = self.view[self.offset - 1]
|
||||
|
||||
return self.readers[t](type_)
|
||||
# Now read the v (variant) of struct (yv)
|
||||
signature_len = self.view[self.offset] # byte
|
||||
o = self.offset + 1
|
||||
self.offset += signature_len + 2 # one for the byte, one for the '\0'
|
||||
tree = SignatureTree._get(self.buf[o : o + signature_len].decode())
|
||||
headers[HEADER_NAME_MAP[field_0]] = self.read_argument(tree.types[0])
|
||||
return headers
|
||||
|
||||
def _unmarshall(self):
|
||||
self.offset = 0
|
||||
self.read(16, prefetch=True)
|
||||
self.endian = self.read_byte()
|
||||
if self.endian != LITTLE_ENDIAN and self.endian != BIG_ENDIAN:
|
||||
raise InvalidMessageError("Expecting endianness as the first byte")
|
||||
message_type = MessageType(self.read_byte())
|
||||
flags = MessageFlag(self.read_byte())
|
||||
|
||||
protocol_version = self.read_byte()
|
||||
def _read_header(self):
|
||||
"""Read the header of the message."""
|
||||
# Signature is of the header is
|
||||
# BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT)
|
||||
self.read_to_offset(HEADER_SIGNATURE_SIZE)
|
||||
buffer = self.buf
|
||||
endian = buffer[0]
|
||||
self.message_type = MESSAGE_TYPE_MAP[buffer[1]]
|
||||
self.flag = MESSAGE_FLAG_MAP[buffer[2]]
|
||||
protocol_version = buffer[3]
|
||||
|
||||
if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN:
|
||||
raise InvalidMessageError(
|
||||
f"Expecting endianness as the first byte, got {endian} from {buffer}"
|
||||
)
|
||||
if protocol_version != PROTOCOL_VERSION:
|
||||
raise InvalidMessageError(
|
||||
f"got unknown protocol version: {protocol_version}"
|
||||
)
|
||||
|
||||
body_len = self.read_uint32()
|
||||
serial = self.read_uint32()
|
||||
|
||||
header_len = self.read_uint32()
|
||||
msg_len = header_len + self._padding(header_len, 8) + body_len
|
||||
self.read(msg_len, prefetch=True)
|
||||
# backtrack offset since header array length needs to be read again
|
||||
self.offset -= 4
|
||||
|
||||
header_fields = {}
|
||||
for field_struct in self.read_argument(SignatureTree._get("a(yv)").types[0]):
|
||||
field = HeaderField(field_struct[0])
|
||||
header_fields[field.name] = field_struct[1].value
|
||||
|
||||
self.align(8)
|
||||
|
||||
path = header_fields.get(HeaderField.PATH.name)
|
||||
interface = header_fields.get(HeaderField.INTERFACE.name)
|
||||
member = header_fields.get(HeaderField.MEMBER.name)
|
||||
error_name = header_fields.get(HeaderField.ERROR_NAME.name)
|
||||
reply_serial = header_fields.get(HeaderField.REPLY_SERIAL.name)
|
||||
destination = header_fields.get(HeaderField.DESTINATION.name)
|
||||
sender = header_fields.get(HeaderField.SENDER.name)
|
||||
signature = header_fields.get(HeaderField.SIGNATURE.name, "")
|
||||
signature_tree = SignatureTree._get(signature)
|
||||
# unix_fds = header_fields.get(HeaderField.UNIX_FDS.name, 0)
|
||||
|
||||
body = []
|
||||
|
||||
if body_len:
|
||||
for type_ in signature_tree.types:
|
||||
body.append(self.read_argument(type_))
|
||||
self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[
|
||||
endian
|
||||
].unpack_from(buffer, 4)
|
||||
self.msg_len = (
|
||||
self.header_len + (-self.header_len & 7) + self.body_len
|
||||
) # align 8
|
||||
if (sys.byteorder == "little" and endian == LITTLE_ENDIAN) or (
|
||||
sys.byteorder == "big" and endian == BIG_ENDIAN
|
||||
):
|
||||
self.can_cast = True
|
||||
self.readers = self._readers_by_type[endian]
|
||||
|
||||
def _read_body(self):
|
||||
"""Read the body of the message."""
|
||||
self.read_to_offset(HEADER_SIGNATURE_SIZE + self.msg_len)
|
||||
self.view = memoryview(self.buf)
|
||||
self.offset = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION
|
||||
header_fields = self.header_fields(self.header_len)
|
||||
self.offset += -self.offset & 7 # align 8
|
||||
tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, ""))
|
||||
self.message = Message(
|
||||
destination=destination,
|
||||
path=path,
|
||||
interface=interface,
|
||||
member=member,
|
||||
message_type=message_type,
|
||||
flags=flags,
|
||||
error_name=error_name,
|
||||
reply_serial=reply_serial,
|
||||
sender=sender,
|
||||
destination=header_fields.get(HEADER_DESTINATION),
|
||||
path=header_fields.get(HEADER_PATH),
|
||||
interface=header_fields.get(HEADER_INTERFACE),
|
||||
member=header_fields.get(HEADER_MEMBER),
|
||||
message_type=self.message_type,
|
||||
flags=self.flag,
|
||||
error_name=header_fields.get(HEADER_ERROR_NAME),
|
||||
reply_serial=header_fields.get(HEADER_REPLY_SERIAL),
|
||||
sender=header_fields.get(HEADER_SENDER),
|
||||
unix_fds=self.unix_fds,
|
||||
signature=signature_tree,
|
||||
body=body,
|
||||
serial=serial,
|
||||
signature=tree.signature,
|
||||
body=[self.read_argument(t) for t in tree.types] if self.body_len else [],
|
||||
serial=self.serial,
|
||||
)
|
||||
|
||||
def unmarshall(self):
|
||||
"""Unmarshall the message.
|
||||
|
||||
The underlying read function will raise MarshallerStreamEndError
|
||||
if there are not enough bytes in the buffer. This allows unmarshall
|
||||
to be resumed when more data comes in over the wire.
|
||||
"""
|
||||
try:
|
||||
self._unmarshall()
|
||||
return self.message
|
||||
if not self.message_type:
|
||||
self._read_header()
|
||||
self._read_body()
|
||||
except MarshallerStreamEndError:
|
||||
return None
|
||||
return self.message
|
||||
|
||||
_complex_parsers: Dict[
|
||||
str, Tuple[Callable[["Unmarshaller", SignatureType], Any], None, None, None]
|
||||
] = {
|
||||
"b": (read_boolean, None, None, None),
|
||||
"o": (read_string, None, None, None),
|
||||
"s": (read_string, None, None, None),
|
||||
"g": (read_signature, None, None, None),
|
||||
"a": (read_array, None, None, None),
|
||||
"(": (read_struct, None, None, None),
|
||||
"{": (read_dict_entry, None, None, None),
|
||||
"v": (read_variant, None, None, None),
|
||||
}
|
||||
|
||||
_ctype_by_endian: Dict[int, Dict[str, Tuple[None, str, int, Struct]]] = {
|
||||
endian: {
|
||||
dbus_type: (
|
||||
None,
|
||||
*ctype_size,
|
||||
Struct(f"{UNPACK_SYMBOL[endian]}{ctype_size[0]}"),
|
||||
)
|
||||
for dbus_type, ctype_size in DBUS_TO_CTYPE.items()
|
||||
}
|
||||
for endian in (BIG_ENDIAN, LITTLE_ENDIAN)
|
||||
}
|
||||
|
||||
_readers_by_type: Dict[int, READER_TYPE] = {
|
||||
BIG_ENDIAN: {**_ctype_by_endian[BIG_ENDIAN], **_complex_parsers},
|
||||
LITTLE_ENDIAN: {**_ctype_by_endian[LITTLE_ENDIAN], **_complex_parsers},
|
||||
}
|
||||
|
||||
@ -19,6 +19,9 @@ class MessageType(Enum):
|
||||
SIGNAL = 4 #: A broadcast signal to subscribed connections
|
||||
|
||||
|
||||
MESSAGE_TYPE_MAP = {field.value: field for field in MessageType}
|
||||
|
||||
|
||||
class MessageFlag(IntFlag):
|
||||
"""Flags that affect the behavior of sent and received messages"""
|
||||
|
||||
@ -28,6 +31,9 @@ class MessageFlag(IntFlag):
|
||||
ALLOW_INTERACTIVE_AUTHORIZATION = 4
|
||||
|
||||
|
||||
MESSAGE_FLAG_MAP = {field.value: field for field in MessageFlag}
|
||||
|
||||
|
||||
class NameFlag(IntFlag):
|
||||
"""A flag that affects the behavior of a name request."""
|
||||
|
||||
|
||||
@ -12,6 +12,13 @@ from .validators import (
|
||||
assert_object_path_valid,
|
||||
)
|
||||
|
||||
REQUIRED_FIELDS = {
|
||||
MessageType.METHOD_CALL: ("path", "member"),
|
||||
MessageType.SIGNAL: ("path", "member", "interface"),
|
||||
MessageType.ERROR: ("error_name", "reply_serial"),
|
||||
MessageType.METHOD_RETURN: ("reply_serial",),
|
||||
}
|
||||
|
||||
|
||||
class Message:
|
||||
"""A class for sending and receiving messages through the
|
||||
@ -112,21 +119,12 @@ class Message:
|
||||
if self.error_name is not None:
|
||||
assert_interface_name_valid(self.error_name)
|
||||
|
||||
def require_fields(*fields):
|
||||
for field in fields:
|
||||
if not getattr(self, field):
|
||||
raise InvalidMessageError(f"missing required field: {field}")
|
||||
|
||||
if self.message_type == MessageType.METHOD_CALL:
|
||||
require_fields("path", "member")
|
||||
elif self.message_type == MessageType.SIGNAL:
|
||||
require_fields("path", "member", "interface")
|
||||
elif self.message_type == MessageType.ERROR:
|
||||
require_fields("error_name", "reply_serial")
|
||||
elif self.message_type == MessageType.METHOD_RETURN:
|
||||
require_fields("reply_serial")
|
||||
else:
|
||||
required_fields = REQUIRED_FIELDS.get(self.message_type)
|
||||
if not required_fields:
|
||||
raise InvalidMessageError(f"got unknown message type: {self.message_type}")
|
||||
for field in required_fields:
|
||||
if not getattr(self, field):
|
||||
raise InvalidMessageError(f"missing required field: {field}")
|
||||
|
||||
@staticmethod
|
||||
def new_error(msg: "Message", error_name: str, error_text: str) -> "Message":
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Union
|
||||
|
||||
from .errors import InvalidSignatureError, SignatureBodyMismatchError
|
||||
@ -21,9 +22,9 @@ class SignatureType:
|
||||
|
||||
_tokens = "ybnqiuxtdsogavh({"
|
||||
|
||||
def __init__(self, token):
|
||||
def __init__(self, token: str) -> None:
|
||||
self.token = token
|
||||
self.children = []
|
||||
self.children: List[SignatureType] = []
|
||||
self._signature = None
|
||||
|
||||
def __eq__(self, other):
|
||||
@ -240,7 +241,7 @@ class SignatureType:
|
||||
child_type.children[0].verify(key)
|
||||
child_type.children[1].verify(value)
|
||||
elif child_type.token == "y":
|
||||
if not isinstance(body, bytes):
|
||||
if not isinstance(body, (bytearray, bytes)):
|
||||
raise SignatureBodyMismatchError(
|
||||
f'DBus ARRAY type "a" with BYTE child must be Python type "bytes", got {type(body)}'
|
||||
)
|
||||
@ -284,43 +285,33 @@ class SignatureType:
|
||||
"""
|
||||
if body is None:
|
||||
raise SignatureBodyMismatchError('Cannot serialize Python type "None"')
|
||||
elif self.token == "y":
|
||||
self._verify_byte(body)
|
||||
elif self.token == "b":
|
||||
self._verify_boolean(body)
|
||||
elif self.token == "n":
|
||||
self._verify_int16(body)
|
||||
elif self.token == "q":
|
||||
self._verify_uint16(body)
|
||||
elif self.token == "i":
|
||||
self._verify_int32(body)
|
||||
elif self.token == "u":
|
||||
self._verify_uint32(body)
|
||||
elif self.token == "x":
|
||||
self._verify_int64(body)
|
||||
elif self.token == "t":
|
||||
self._verify_uint64(body)
|
||||
elif self.token == "d":
|
||||
self._verify_double(body)
|
||||
elif self.token == "h":
|
||||
self._verify_unix_fd(body)
|
||||
elif self.token == "o":
|
||||
self._verify_object_path(body)
|
||||
elif self.token == "s":
|
||||
self._verify_string(body)
|
||||
elif self.token == "g":
|
||||
self._verify_signature(body)
|
||||
elif self.token == "a":
|
||||
self._verify_array(body)
|
||||
elif self.token == "(":
|
||||
self._verify_struct(body)
|
||||
elif self.token == "v":
|
||||
self._verify_variant(body)
|
||||
validator = self.validators.get(self.token)
|
||||
if validator:
|
||||
validator(self, body)
|
||||
else:
|
||||
raise Exception(f"cannot verify type with token {self.token}")
|
||||
|
||||
return True
|
||||
|
||||
validators = {
|
||||
"y": _verify_byte,
|
||||
"b": _verify_boolean,
|
||||
"n": _verify_int16,
|
||||
"q": _verify_uint16,
|
||||
"i": _verify_int32,
|
||||
"u": _verify_uint32,
|
||||
"x": _verify_int64,
|
||||
"t": _verify_uint64,
|
||||
"d": _verify_double,
|
||||
"h": _verify_uint32,
|
||||
"o": _verify_string,
|
||||
"s": _verify_string,
|
||||
"g": _verify_signature,
|
||||
"a": _verify_array,
|
||||
"(": _verify_struct,
|
||||
"v": _verify_variant,
|
||||
}
|
||||
|
||||
|
||||
class SignatureTree:
|
||||
"""A class that represents a signature as a tree structure for conveniently
|
||||
@ -338,19 +329,15 @@ class SignatureTree:
|
||||
:class:`InvalidSignatureError` if the given signature is not valid.
|
||||
"""
|
||||
|
||||
_cache = {}
|
||||
|
||||
@staticmethod
|
||||
def _get(signature: str = ""):
|
||||
if signature in SignatureTree._cache:
|
||||
return SignatureTree._cache[signature]
|
||||
SignatureTree._cache[signature] = SignatureTree(signature)
|
||||
return SignatureTree._cache[signature]
|
||||
@lru_cache(maxsize=None)
|
||||
def _get(signature: str = "") -> "SignatureTree":
|
||||
return SignatureTree(signature)
|
||||
|
||||
def __init__(self, signature: str = ""):
|
||||
self.signature = signature
|
||||
|
||||
self.types = []
|
||||
self.types: List[SignatureType] = []
|
||||
|
||||
if len(signature) > 0xFF:
|
||||
raise InvalidSignatureError("A signature must be less than 256 characters")
|
||||
@ -411,7 +398,12 @@ class Variant:
|
||||
:class:`SignatureBodyMismatchError` if the signature does not match the body.
|
||||
"""
|
||||
|
||||
def __init__(self, signature: Union[str, SignatureTree, SignatureType], value: Any):
|
||||
def __init__(
|
||||
self,
|
||||
signature: Union[str, SignatureTree, SignatureType],
|
||||
value: Any,
|
||||
verify: bool = True,
|
||||
):
|
||||
signature_str = ""
|
||||
signature_tree = None
|
||||
signature_type = None
|
||||
@ -429,14 +421,15 @@ class Variant:
|
||||
)
|
||||
|
||||
if signature_tree:
|
||||
if len(signature_tree.types) != 1:
|
||||
if verify and len(signature_tree.types) != 1:
|
||||
raise ValueError(
|
||||
"variants must have a signature for a single complete type"
|
||||
)
|
||||
signature_str = signature_tree.signature
|
||||
signature_type = signature_tree.types[0]
|
||||
|
||||
signature_type.verify(value)
|
||||
if verify:
|
||||
signature_type.verify(value)
|
||||
|
||||
self.type = signature_type
|
||||
self.signature = signature_str
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
from .errors import (
|
||||
InvalidBusNameError,
|
||||
@ -13,6 +14,7 @@ _element_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
_member_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_-]*$")
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def is_bus_name_valid(name: str) -> bool:
|
||||
"""Whether this is a valid bus name.
|
||||
|
||||
@ -47,6 +49,7 @@ def is_bus_name_valid(name: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def is_object_path_valid(path: str) -> bool:
|
||||
"""Whether this is a valid object path.
|
||||
|
||||
@ -77,6 +80,7 @@ def is_object_path_valid(path: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def is_interface_name_valid(name: str) -> bool:
|
||||
"""Whether this is a valid interface name.
|
||||
|
||||
@ -107,6 +111,7 @@ def is_interface_name_valid(name: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def is_member_name_valid(member: str) -> bool:
|
||||
"""Whether this is a valid member name.
|
||||
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from dbus_fast import Message, SignatureTree, Variant
|
||||
import pytest
|
||||
|
||||
from dbus_fast import Message, MessageFlag, MessageType, SignatureTree, Variant
|
||||
from dbus_fast._private.unmarshaller import Unmarshaller
|
||||
|
||||
|
||||
@ -20,6 +23,16 @@ def print_buf(buf):
|
||||
table = json.load(open(os.path.dirname(__file__) + "/data/messages.json"))
|
||||
|
||||
|
||||
def json_to_message(message: Dict[str, Any]) -> Message:
|
||||
copy = dict(message)
|
||||
if "message_type" in copy:
|
||||
copy["message_type"] = MessageType(copy["message_type"])
|
||||
if "flags" in copy:
|
||||
copy["flags"] = MessageFlag(copy["flags"])
|
||||
|
||||
return Message(**copy)
|
||||
|
||||
|
||||
# variants are an object in the json
|
||||
def replace_variants(type_, item):
|
||||
if type_.token == "v" and type(item) is not Variant:
|
||||
@ -56,7 +69,7 @@ def json_dump(what):
|
||||
|
||||
def test_marshalling_with_table():
|
||||
for item in table:
|
||||
message = Message(**item["message"])
|
||||
message = json_to_message(item["message"])
|
||||
|
||||
body = []
|
||||
for i, type_ in enumerate(message.signature_tree.types):
|
||||
@ -79,8 +92,9 @@ def test_marshalling_with_table():
|
||||
assert buf == data
|
||||
|
||||
|
||||
def test_unmarshalling_with_table():
|
||||
for item in table:
|
||||
@pytest.mark.parametrize("unmarshall_table", (table,))
|
||||
def test_unmarshalling_with_table(unmarshall_table):
|
||||
for item in unmarshall_table:
|
||||
|
||||
stream = io.BytesIO(bytes.fromhex(item["data"]))
|
||||
unmarshaller = Unmarshaller(stream)
|
||||
@ -91,7 +105,7 @@ def test_unmarshalling_with_table():
|
||||
print(json_dump(item["message"]))
|
||||
raise e
|
||||
|
||||
message = Message(**item["message"])
|
||||
message = json_to_message(item["message"])
|
||||
|
||||
body = []
|
||||
for i, type_ in enumerate(message.signature_tree.types):
|
||||
@ -114,6 +128,39 @@ def test_unmarshalling_with_table():
|
||||
), f"attr doesnt match: {attr}"
|
||||
|
||||
|
||||
def test_unmarshall_can_resume():
|
||||
"""Verify resume works."""
|
||||
bluez_rssi_message = (
|
||||
"6c04010134000000e25389019500000001016f00250000002f6f72672f626c75657a2f686369302f6465"
|
||||
"765f30385f33415f46325f31455f32425f3631000000020173001f0000006f72672e667265656465736b"
|
||||
"746f702e444275732e50726f7065727469657300030173001100000050726f706572746965734368616e"
|
||||
"67656400000000000000080167000873617b73767d617300000007017300040000003a312e3400000000"
|
||||
"110000006f72672e626c75657a2e446576696365310000000e0000000000000004000000525353490001"
|
||||
"6e00a7ff000000000000"
|
||||
)
|
||||
message_bytes = bytes.fromhex(bluez_rssi_message)
|
||||
|
||||
class SlowStream(io.IOBase):
|
||||
"""A fake stream that will only give us one byte at a time."""
|
||||
|
||||
def __init__(self):
|
||||
self.data = message_bytes
|
||||
self.pos = 0
|
||||
|
||||
def read(self, n) -> bytes:
|
||||
data = self.data[self.pos : self.pos + 1]
|
||||
self.pos += 1
|
||||
return data
|
||||
|
||||
stream = SlowStream()
|
||||
unmarshaller = Unmarshaller(stream)
|
||||
|
||||
for _ in range(len(bluez_rssi_message)):
|
||||
if unmarshaller.unmarshall():
|
||||
break
|
||||
assert unmarshaller.message is not None
|
||||
|
||||
|
||||
def test_ay_buffer():
|
||||
body = [bytes(10000)]
|
||||
msg = Message(path="/test", member="test", signature="ay", body=body)
|
||||
|
||||
@ -10,7 +10,6 @@ def test_object_path_validator():
|
||||
valid_paths = ["/", "/foo", "/foo/bar", "/foo/bar/bat"]
|
||||
invalid_paths = [
|
||||
None,
|
||||
{},
|
||||
"",
|
||||
"foo",
|
||||
"foo/bar",
|
||||
@ -37,7 +36,6 @@ def test_bus_name_validator():
|
||||
]
|
||||
invalid_names = [
|
||||
None,
|
||||
{},
|
||||
"",
|
||||
"5foo.bar",
|
||||
"foo.6bar",
|
||||
@ -57,7 +55,6 @@ def test_interface_name_validator():
|
||||
valid_names = ["foo.bar", "foo.bar.bat", "_foo._bar", "foo.bar69"]
|
||||
invalid_names = [
|
||||
None,
|
||||
{},
|
||||
"",
|
||||
"5foo.bar",
|
||||
"foo.6bar",
|
||||
@ -80,7 +77,7 @@ def test_interface_name_validator():
|
||||
|
||||
def test_member_name_validator():
|
||||
valid_members = ["foo", "FooBar", "Bat_Baz69", "foo-bar"]
|
||||
invalid_members = [None, {}, "", "foo.bar", "5foo", "foo$bar"]
|
||||
invalid_members = [None, "", "foo.bar", "5foo", "foo$bar"]
|
||||
|
||||
for member in valid_members:
|
||||
assert is_member_name_valid(member), f'member name should be valid: "{member}"'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user