From f38e08fa7cc8d41e896663ab0f163aa37a472abe Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 3 Oct 2022 15:11:33 -1000 Subject: [PATCH] feat: speed up unmarshall performance (#71) --- bench/unmarshall.py | 9 +- bench/unmarshall_2.py | 9 +- src/dbus_fast/_private/unmarshaller.pxd | 32 ++-- src/dbus_fast/_private/unmarshaller.py | 243 +++++++++++++----------- src/dbus_fast/aio/message_bus.py | 13 +- src/dbus_fast/glib/message_bus.py | 10 + src/dbus_fast/message.py | 20 +- src/dbus_fast/message_bus.py | 8 - 8 files changed, 196 insertions(+), 148 deletions(-) diff --git a/bench/unmarshall.py b/bench/unmarshall.py index 71bfc0f..6c4fb26 100644 --- a/bench/unmarshall.py +++ b/bench/unmarshall.py @@ -16,8 +16,15 @@ bluez_rssi_message = ( ) +stream = io.BytesIO(bytes.fromhex(bluez_rssi_message)) + +unmarshaller = Unmarshaller(stream) + + def unmarhsall_bluez_rssi_message(): - Unmarshaller(io.BytesIO(bytes.fromhex(bluez_rssi_message))).unmarshall() + stream.seek(0) + unmarshaller.reset() + unmarshaller.unmarshall() count = 1000000 diff --git a/bench/unmarshall_2.py b/bench/unmarshall_2.py index 0c308d9..0df9a56 100644 --- a/bench/unmarshall_2.py +++ b/bench/unmarshall_2.py @@ -35,8 +35,15 @@ bluez_properties_message = ( ) +stream = io.BytesIO(bluez_properties_message) + +unmarshaller = Unmarshaller(stream) + + def unmarhsall_bluez_rssi_message(): - Unmarshaller(io.BytesIO(bluez_properties_message)).unmarshall() + stream.seek(0) + unmarshaller.reset() + unmarshaller.unmarshall() count = 1000000 diff --git a/src/dbus_fast/_private/unmarshaller.pxd b/src/dbus_fast/_private/unmarshaller.pxd index 3d68d58..4402ea0 100644 --- a/src/dbus_fast/_private/unmarshaller.pxd +++ b/src/dbus_fast/_private/unmarshaller.pxd @@ -8,25 +8,30 @@ from ..signature import SignatureType cdef unsigned int UINT32_SIZE cdef unsigned int HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION cdef unsigned int HEADER_SIGNATURE_SIZE +cdef unsigned int LITTLE_ENDIAN +cdef unsigned int BIG_ENDIAN +cdef str UINT32_CAST +cdef object UINT32_SIGNATURE cdef class Unmarshaller: - cdef object unix_fds - cdef bytearray buf - cdef object view - cdef unsigned int pos - cdef object stream - cdef object sock + cdef object _unix_fds + cdef bytearray _buf + cdef object _view + cdef unsigned int _pos + cdef object _stream + cdef object _sock cdef object _message - cdef object readers - cdef unsigned int body_len - cdef unsigned int serial - cdef unsigned int header_len - cdef object message_type - cdef object flag - cdef unsigned int msg_len + cdef object _readers + cdef unsigned int _body_len + cdef unsigned int _serial + cdef unsigned int _header_len + cdef unsigned int _message_type + cdef unsigned int _flag + cdef unsigned int _msg_len cdef object _uint32_unpack + cpdef reset(self) @cython.locals( start_len=cython.ulong, @@ -56,6 +61,7 @@ cdef class Unmarshaller: @cython.locals( endian=cython.uint, protocol_version=cython.uint, + can_cast=cython.bint ) cpdef _read_header(self) diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index f8e58d1..3ed53d2 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -65,8 +65,8 @@ def cast_parser_factory(ctype: str, size: int) -> READER_TYPE: """Build a parser that casts the bytes to the given ctype.""" def _cast_parser(self: "Unmarshaller", signature: SignatureType) -> Any: - self.pos += size + (-self.pos & (size - 1)) # align - return self.view[self.pos - size : self.pos].cast(ctype)[0] + self._pos += size + (-self._pos & (size - 1)) # align + return self._view[self._pos - size : self._pos].cast(ctype)[0] return _cast_parser @@ -75,8 +75,8 @@ def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE: """Build a parser that unpacks the bytes using the given unpack_from function.""" def _unpack_from_parser(self: "Unmarshaller", signature: SignatureType) -> Any: - self.pos += size + (-self.pos & (size - 1)) # align - return unpack_from(self.view, self.pos - size)[0] + self._pos += size + (-self._pos & (size - 1)) # align + return unpack_from(self._view, self._pos - size)[0] return _unpack_from_parser @@ -129,41 +129,59 @@ class MarshallerStreamEndError(Exception): class Unmarshaller: __slots__ = ( - "unix_fds", - "buf", - "view", - "pos", - "stream", - "sock", + "_unix_fds", + "_buf", + "_view", + "_pos", + "_stream", + "_sock", "_message", - "readers", - "body_len", - "serial", - "header_len", - "message_type", - "flag", - "msg_len", + "_readers", + "_body_len", + "_serial", + "_header_len", + "_message_type", + "_flag", + "_msg_len", "_uint32_unpack", ) def __init__(self, stream: io.BufferedRWPair, sock=None): - self.unix_fds: List[int] = [] - self.buf = bytearray() # Actual buffer - self.view = None # Memory view of the buffer - self.pos = 0 - self.stream = stream - self.sock = sock + self._unix_fds: List[int] = [] + self._buf = bytearray() # Actual buffer + self._view = None # Memory view of the buffer + self._stream = stream + self._sock = sock self._message: Message | None = None - self.readers: Dict[str, READER_TYPE] = {} - self.body_len = 0 - self.serial = 0 - self.header_len = 0 - self.message_type: MessageType | None = None - self.flag: MessageFlag | None = None - self.msg_len = 0 + self._readers: Dict[str, READER_TYPE] = {} + self._pos = 0 + self._body_len = 0 + self._serial = 0 + self._header_len = 0 + self._message_type = 0 + self._flag = 0 + self._msg_len = 0 # Only set if we cannot cast self._uint32_unpack: Callable | None = None + def reset(self) -> None: + """Reset the unmarshaller to its initial state. + + Call this before processing a new message. + """ + self._unix_fds: List[int] = [] + self._view = None + self._buf.clear() + self._message = None + self._pos = 0 + self._body_len = 0 + self._serial = 0 + self._header_len = 0 + self._message_type = 0 + self._flag = 0 + self._msg_len = 0 + self._uint32_unpack = None + @property def message(self) -> Message: """Return the message that has been unmarshalled.""" @@ -175,7 +193,7 @@ class Unmarshaller: unix_fd_list = array.array("i") try: - msg, ancdata, *_ = self.sock.recvmsg( + msg, ancdata, *_ = self._sock.recvmsg( length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize) ) except BlockingIOError: @@ -187,7 +205,7 @@ class Unmarshaller: unix_fd_list.frombytes( data[: len(data) - (len(data) % unix_fd_list.itemsize)] ) - self.unix_fds.extend(list(unix_fd_list)) + self._unix_fds.extend(list(unix_fd_list)) return msg @@ -204,127 +222,130 @@ class Unmarshaller: :returns: None """ - start_len = len(self.buf) - missing_bytes = pos - (start_len - self.pos) - if self.sock is None: - data = self.stream.read(missing_bytes) + start_len = len(self._buf) + missing_bytes = pos - (start_len - self._pos) + if self._sock is None: + data = self._stream.read(missing_bytes) else: data = self.read_sock(missing_bytes) if data == b"": raise EOFError() if data is None: raise MarshallerStreamEndError() - self.buf.extend(data) + self._buf.extend(data) if len(data) + start_len != pos: raise MarshallerStreamEndError() def read_uint32_cast(self, signature: SignatureType) -> Any: - self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align - return self.view[self.pos - UINT32_SIZE : self.pos].cast(UINT32_CAST)[0] + self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align + return self._view[self._pos - UINT32_SIZE : self._pos].cast(UINT32_CAST)[0] def read_boolean(self, type_=None) -> bool: - return bool(self.readers[UINT32_SIGNATURE.token](self, UINT32_SIGNATURE)) + return bool(self._readers[UINT32_SIGNATURE.token](self, UINT32_SIGNATURE)) def read_string_cast(self, type_=None) -> str: """Read a string using cast.""" - self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align - str_start = self.pos + self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align + str_start = self._pos # read terminating '\0' byte as well (str_length + 1) - start_pos = self.pos - UINT32_SIZE - self.pos += self.view[start_pos : self.pos].cast(UINT32_CAST)[0] + 1 - return self.buf[str_start : self.pos - 1].decode() + start_pos = self._pos - UINT32_SIZE + self._pos += self._view[start_pos : self._pos].cast(UINT32_CAST)[0] + 1 + return self._buf[str_start : self._pos - 1].decode() def read_string_unpack(self, type_=None) -> str: """Read a string using unpack.""" - self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align - str_start = self.pos + self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align + str_start = self._pos # read terminating '\0' byte as well (str_length + 1) - self.pos += self._uint32_unpack(self.view, str_start - UINT32_SIZE)[0] + 1 - return self.buf[str_start : self.pos - 1].decode() + self._pos += self._uint32_unpack(self._view, str_start - UINT32_SIZE)[0] + 1 + return self._buf[str_start : self._pos - 1].decode() def read_signature(self, type_=None) -> str: - signature_len = self.view[self.pos] # byte - o = self.pos + 1 + signature_len = self._view[self._pos] # byte + o = self._pos + 1 # read terminating '\0' byte as well (str_length + 1) - self.pos = o + signature_len + 1 - return self.buf[o : o + signature_len].decode() + self._pos = o + signature_len + 1 + return self._buf[o : o + signature_len].decode() def read_variant(self, type_=None) -> Variant: tree = SignatureTree._get(self.read_signature()) # verify in Variant is only useful on construction not unmarshalling return Variant( - tree, self.readers[tree.types[0].token](self, tree.types[0]), verify=False + tree, self._readers[tree.types[0].token](self, tree.types[0]), verify=False ) def read_struct(self, type_=None) -> List[Any]: - self.pos += -self.pos & 7 # align 8 - readers = self.readers + self._pos += -self._pos & 7 # align 8 + readers = self._readers return [ readers[child_type.token](self, child_type) for child_type in type_.children ] def read_dict_entry(self, type_: SignatureType) -> Dict[Any, Any]: - self.pos += -self.pos & 7 # align 8 - return self.readers[type_.children[0].token]( + self._pos += -self._pos & 7 # align 8 + return self._readers[type_.children[0].token]( self, type_.children[0] - ), self.readers[type_.children[1].token](self, type_.children[1]) + ), self._readers[type_.children[1].token](self, type_.children[1]) def read_array(self, type_: SignatureType) -> List[Any]: - self.pos += -self.pos & 3 # align 4 for the array - self.pos += ( - -self.pos & (UINT32_SIZE - 1) + self._pos += -self._pos & 3 # align 4 for the array + self._pos += ( + -self._pos & (UINT32_SIZE - 1) ) + UINT32_SIZE # align for the uint32 if self._uint32_unpack: - array_length = self._uint32_unpack(self.view, self.pos - UINT32_SIZE)[0] + array_length = self._uint32_unpack(self._view, self._pos - UINT32_SIZE)[0] else: - array_length = self.view[self.pos - UINT32_SIZE : self.pos].cast( + array_length = self._view[self._pos - UINT32_SIZE : self._pos].cast( UINT32_CAST )[0] child_type = type_.children[0] - if child_type.token in "xtd{(": + token = child_type.token + + if token in "xtd{(": # the first alignment is not included in the array size - self.pos += -self.pos & 7 # align 8 + self._pos += -self._pos & 7 # align 8 - if child_type.token == "y": - self.pos += array_length - return self.buf[self.pos - array_length : self.pos] + if token == "y": + self._pos += array_length + return self._buf[self._pos - array_length : self._pos] - beginning_pos = self.pos - readers = self.readers + beginning_pos = self._pos + readers = self._readers - if child_type.token == "{": + if token == "{": result_dict = {} - while self.pos - beginning_pos < array_length: - self.pos += -self.pos & 7 # align 8 - key = readers[child_type.children[0].token]( - self, child_type.children[0] - ) - result_dict[key] = readers[child_type.children[1].token]( - self, child_type.children[1] - ) + child_0 = child_type.children[0] + reader_0 = readers[child_0.token] + child_1 = child_type.children[1] + reader_1 = readers[child_1.token] + while self._pos - beginning_pos < array_length: + self._pos += -self._pos & 7 # align 8 + key = reader_0(self, child_0) + result_dict[key] = reader_1(self, child_1) return result_dict result_list = [] - while self.pos - beginning_pos < array_length: - result_list.append(readers[child_type.token](self, child_type)) + reader = readers[child_type.token] + while self._pos - beginning_pos < array_length: + result_list.append(reader(self, child_type)) return result_list def header_fields(self, header_length) -> Dict[str, Any]: """Header fields are always a(yv).""" - beginning_pos = self.pos + beginning_pos = self._pos headers = {} - while self.pos - beginning_pos < header_length: + while self._pos - beginning_pos < header_length: # Now read the y (byte) of struct (yv) - self.pos += (-self.pos & 7) + 1 # align 8 + 1 for 'y' byte - field_0 = self.view[self.pos - 1] + self._pos += (-self._pos & 7) + 1 # align 8 + 1 for 'y' byte + field_0 = self._view[self._pos - 1] # Now read the v (variant) of struct (yv) - signature_len = self.view[self.pos] # byte - o = self.pos + 1 - self.pos += signature_len + 2 # one for the byte, one for the '\0' - tree = SignatureTree._get(self.buf[o : o + signature_len].decode()) - headers[HEADER_NAME_MAP[field_0]] = self.readers[tree.types[0].token]( + signature_len = self._view[self._pos] # byte + o = self._pos + 1 + self._pos += signature_len + 2 # one for the byte, one for the '\0' + tree = SignatureTree._get(self._buf[o : o + signature_len].decode()) + headers[HEADER_NAME_MAP[field_0]] = self._readers[tree.types[0].token]( self, tree.types[0] ) return headers @@ -334,10 +355,10 @@ class Unmarshaller: # Signature is of the header is # BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT) self.read_to_pos(HEADER_SIGNATURE_SIZE) - buffer = self.buf + buffer = self._buf endian = buffer[0] - self.message_type = MESSAGE_TYPE_MAP[buffer[1]] - self.flag = MESSAGE_FLAG_MAP[buffer[2]] + self._message_type = buffer[1] + self._flag = buffer[2] protocol_version = buffer[3] if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN: @@ -349,44 +370,44 @@ class Unmarshaller: f"got unknown protocol version: {protocol_version}" ) - self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[ + self._body_len, self._serial, self._header_len = UNPACK_LENGTHS[ endian ].unpack_from(buffer, 4) - self.msg_len = ( - self.header_len + (-self.header_len & 7) + self.body_len + self._msg_len = ( + self._header_len + (-self._header_len & 7) + self._body_len ) # align 8 can_cast = bool( (IS_LITTLE_ENDIAN and endian == LITTLE_ENDIAN) or (IS_BIG_ENDIAN and endian == BIG_ENDIAN) ) - self.readers = self._readers_by_type[(endian, can_cast)] + self._readers = self._readers_by_type[(endian, can_cast)] if not can_cast: self._uint32_unpack = UINT32_UNPACK_BY_ENDIAN[endian] def _read_body(self): """Read the body of the message.""" - self.read_to_pos(HEADER_SIGNATURE_SIZE + self.msg_len) - self.view = memoryview(self.buf) - self.pos = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION - header_fields = self.header_fields(self.header_len) - self.pos += -self.pos & 7 # align 8 + self.read_to_pos(HEADER_SIGNATURE_SIZE + self._msg_len) + self._view = memoryview(self._buf) + self._pos = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION + header_fields = self.header_fields(self._header_len) + self._pos += -self._pos & 7 # align 8 tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, "")) self._message = Message( destination=header_fields.get(HEADER_DESTINATION), path=header_fields.get(HEADER_PATH), interface=header_fields.get(HEADER_INTERFACE), member=header_fields.get(HEADER_MEMBER), - message_type=self.message_type, - flags=self.flag, + message_type=MESSAGE_TYPE_MAP[self._message_type], + flags=MESSAGE_FLAG_MAP[self._flag], error_name=header_fields.get(HEADER_ERROR_NAME), reply_serial=header_fields.get(HEADER_REPLY_SERIAL), sender=header_fields.get(HEADER_SENDER), - unix_fds=self.unix_fds, - signature=tree.signature, - body=[self.readers[t.token](self, t) for t in tree.types] - if self.body_len + unix_fds=self._unix_fds, + signature=tree, + body=[self._readers[t.token](self, t) for t in tree.types] + if self._body_len else [], - serial=self.serial, + serial=self._serial, # The D-Bus implementation already validates the message, # so we don't need to do it again. validate=False, @@ -400,7 +421,7 @@ class Unmarshaller: to be resumed when more data comes in over the wire. """ try: - if not self.msg_len: + if not self._msg_len: self._read_header() self._read_body() except MarshallerStreamEndError: diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 8e88996..6406da0 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -2,6 +2,7 @@ import array import asyncio import logging import socket +import traceback from collections import deque from copy import copy from typing import Any, Optional @@ -419,11 +420,17 @@ class MessageBus(BaseMessageBus): return handler def _message_reader(self) -> None: + unmarshaller = self._unmarshaller try: while True: - if self._unmarshaller.unmarshall(): - self._on_message(self._unmarshaller.message) - self._unmarshaller = self._create_unmarshaller() + if unmarshaller.unmarshall(): + try: + self._process_message(unmarshaller.message) + except Exception as e: + logging.error( + f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" + ) + unmarshaller.reset() else: break except Exception as e: diff --git a/src/dbus_fast/glib/message_bus.py b/src/dbus_fast/glib/message_bus.py index fd62dcb..8ceb370 100644 --- a/src/dbus_fast/glib/message_bus.py +++ b/src/dbus_fast/glib/message_bus.py @@ -1,4 +1,6 @@ import io +import logging +import traceback from typing import Callable, Optional from .. import introspection as intr @@ -173,6 +175,14 @@ class MessageBus(BaseMessageBus): else: self._auth = auth + def _on_message(self, msg: Message) -> None: + try: + self._process_message(msg) + except Exception as e: + logging.error( + f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" + ) + def connect( self, connect_notify: Callable[["MessageBus", Optional[Exception]], None] = None ): diff --git a/src/dbus_fast/message.py b/src/dbus_fast/message.py index 100bf40..270453b 100644 --- a/src/dbus_fast/message.py +++ b/src/dbus_fast/message.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Union from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField from ._private.marshaller import Marshaller @@ -105,11 +105,11 @@ class Message: reply_serial: int = None, sender: str = None, unix_fds: List[int] = [], - signature: str = "", + signature: Union[str, SignatureTree] = "", body: List[Any] = [], serial: int = 0, validate: bool = True, - ): + ) -> None: self.destination = destination self.path = path self.interface = interface @@ -124,14 +124,12 @@ class Message: self.reply_serial = reply_serial self.sender = sender self.unix_fds = unix_fds - self.signature = ( - signature.signature if type(signature) is SignatureTree else signature - ) - self.signature_tree = ( - signature - if type(signature) is SignatureTree - else SignatureTree._get(signature) - ) + if type(signature) is SignatureTree: + self.signature = signature.signature + self.signature_tree = signature + else: + self.signature = signature + self.signature_tree = SignatureTree._get(signature) self.body = body self.serial = serial diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index 1f78e2d..1619814 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -690,14 +690,6 @@ class BaseMessageBus: ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg ) - def _on_message(self, msg: Message) -> None: - try: - self._process_message(msg) - except Exception as e: - logging.error( - f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" - ) - def _send_reply(self, msg: Message): bus = self