feat: speed up unmarshall performance (#71)

This commit is contained in:
J. Nick Koston 2022-10-03 15:11:33 -10:00 committed by GitHub
parent 4ea6eae2cc
commit f38e08fa7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 196 additions and 148 deletions

View File

@ -16,8 +16,15 @@ bluez_rssi_message = (
)
stream = io.BytesIO(bytes.fromhex(bluez_rssi_message))
unmarshaller = Unmarshaller(stream)
def unmarhsall_bluez_rssi_message():
Unmarshaller(io.BytesIO(bytes.fromhex(bluez_rssi_message))).unmarshall()
stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall()
count = 1000000

View File

@ -35,8 +35,15 @@ bluez_properties_message = (
)
stream = io.BytesIO(bluez_properties_message)
unmarshaller = Unmarshaller(stream)
def unmarhsall_bluez_rssi_message():
Unmarshaller(io.BytesIO(bluez_properties_message)).unmarshall()
stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall()
count = 1000000

View File

@ -8,25 +8,30 @@ from ..signature import SignatureType
cdef unsigned int UINT32_SIZE
cdef unsigned int HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION
cdef unsigned int HEADER_SIGNATURE_SIZE
cdef unsigned int LITTLE_ENDIAN
cdef unsigned int BIG_ENDIAN
cdef str UINT32_CAST
cdef object UINT32_SIGNATURE
cdef class Unmarshaller:
cdef object unix_fds
cdef bytearray buf
cdef object view
cdef unsigned int pos
cdef object stream
cdef object sock
cdef object _unix_fds
cdef bytearray _buf
cdef object _view
cdef unsigned int _pos
cdef object _stream
cdef object _sock
cdef object _message
cdef object readers
cdef unsigned int body_len
cdef unsigned int serial
cdef unsigned int header_len
cdef object message_type
cdef object flag
cdef unsigned int msg_len
cdef object _readers
cdef unsigned int _body_len
cdef unsigned int _serial
cdef unsigned int _header_len
cdef unsigned int _message_type
cdef unsigned int _flag
cdef unsigned int _msg_len
cdef object _uint32_unpack
cpdef reset(self)
@cython.locals(
start_len=cython.ulong,
@ -56,6 +61,7 @@ cdef class Unmarshaller:
@cython.locals(
endian=cython.uint,
protocol_version=cython.uint,
can_cast=cython.bint
)
cpdef _read_header(self)

View File

@ -65,8 +65,8 @@ def cast_parser_factory(ctype: str, size: int) -> READER_TYPE:
"""Build a parser that casts the bytes to the given ctype."""
def _cast_parser(self: "Unmarshaller", signature: SignatureType) -> Any:
self.pos += size + (-self.pos & (size - 1)) # align
return self.view[self.pos - size : self.pos].cast(ctype)[0]
self._pos += size + (-self._pos & (size - 1)) # align
return self._view[self._pos - size : self._pos].cast(ctype)[0]
return _cast_parser
@ -75,8 +75,8 @@ def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE:
"""Build a parser that unpacks the bytes using the given unpack_from function."""
def _unpack_from_parser(self: "Unmarshaller", signature: SignatureType) -> Any:
self.pos += size + (-self.pos & (size - 1)) # align
return unpack_from(self.view, self.pos - size)[0]
self._pos += size + (-self._pos & (size - 1)) # align
return unpack_from(self._view, self._pos - size)[0]
return _unpack_from_parser
@ -129,41 +129,59 @@ class MarshallerStreamEndError(Exception):
class Unmarshaller:
__slots__ = (
"unix_fds",
"buf",
"view",
"pos",
"stream",
"sock",
"_unix_fds",
"_buf",
"_view",
"_pos",
"_stream",
"_sock",
"_message",
"readers",
"body_len",
"serial",
"header_len",
"message_type",
"flag",
"msg_len",
"_readers",
"_body_len",
"_serial",
"_header_len",
"_message_type",
"_flag",
"_msg_len",
"_uint32_unpack",
)
def __init__(self, stream: io.BufferedRWPair, sock=None):
self.unix_fds: List[int] = []
self.buf = bytearray() # Actual buffer
self.view = None # Memory view of the buffer
self.pos = 0
self.stream = stream
self.sock = sock
self._unix_fds: List[int] = []
self._buf = bytearray() # Actual buffer
self._view = None # Memory view of the buffer
self._stream = stream
self._sock = sock
self._message: Message | None = None
self.readers: Dict[str, READER_TYPE] = {}
self.body_len = 0
self.serial = 0
self.header_len = 0
self.message_type: MessageType | None = None
self.flag: MessageFlag | None = None
self.msg_len = 0
self._readers: Dict[str, READER_TYPE] = {}
self._pos = 0
self._body_len = 0
self._serial = 0
self._header_len = 0
self._message_type = 0
self._flag = 0
self._msg_len = 0
# Only set if we cannot cast
self._uint32_unpack: Callable | None = None
def reset(self) -> None:
"""Reset the unmarshaller to its initial state.
Call this before processing a new message.
"""
self._unix_fds: List[int] = []
self._view = None
self._buf.clear()
self._message = None
self._pos = 0
self._body_len = 0
self._serial = 0
self._header_len = 0
self._message_type = 0
self._flag = 0
self._msg_len = 0
self._uint32_unpack = None
@property
def message(self) -> Message:
"""Return the message that has been unmarshalled."""
@ -175,7 +193,7 @@ class Unmarshaller:
unix_fd_list = array.array("i")
try:
msg, ancdata, *_ = self.sock.recvmsg(
msg, ancdata, *_ = self._sock.recvmsg(
length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)
)
except BlockingIOError:
@ -187,7 +205,7 @@ class Unmarshaller:
unix_fd_list.frombytes(
data[: len(data) - (len(data) % unix_fd_list.itemsize)]
)
self.unix_fds.extend(list(unix_fd_list))
self._unix_fds.extend(list(unix_fd_list))
return msg
@ -204,127 +222,130 @@ class Unmarshaller:
:returns:
None
"""
start_len = len(self.buf)
missing_bytes = pos - (start_len - self.pos)
if self.sock is None:
data = self.stream.read(missing_bytes)
start_len = len(self._buf)
missing_bytes = pos - (start_len - self._pos)
if self._sock is None:
data = self._stream.read(missing_bytes)
else:
data = self.read_sock(missing_bytes)
if data == b"":
raise EOFError()
if data is None:
raise MarshallerStreamEndError()
self.buf.extend(data)
self._buf.extend(data)
if len(data) + start_len != pos:
raise MarshallerStreamEndError()
def read_uint32_cast(self, signature: SignatureType) -> Any:
self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align
return self.view[self.pos - UINT32_SIZE : self.pos].cast(UINT32_CAST)[0]
self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align
return self._view[self._pos - UINT32_SIZE : self._pos].cast(UINT32_CAST)[0]
def read_boolean(self, type_=None) -> bool:
return bool(self.readers[UINT32_SIGNATURE.token](self, UINT32_SIGNATURE))
return bool(self._readers[UINT32_SIGNATURE.token](self, UINT32_SIGNATURE))
def read_string_cast(self, type_=None) -> str:
"""Read a string using cast."""
self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align
str_start = self.pos
self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align
str_start = self._pos
# read terminating '\0' byte as well (str_length + 1)
start_pos = self.pos - UINT32_SIZE
self.pos += self.view[start_pos : self.pos].cast(UINT32_CAST)[0] + 1
return self.buf[str_start : self.pos - 1].decode()
start_pos = self._pos - UINT32_SIZE
self._pos += self._view[start_pos : self._pos].cast(UINT32_CAST)[0] + 1
return self._buf[str_start : self._pos - 1].decode()
def read_string_unpack(self, type_=None) -> str:
"""Read a string using unpack."""
self.pos += UINT32_SIZE + (-self.pos & (UINT32_SIZE - 1)) # align
str_start = self.pos
self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align
str_start = self._pos
# read terminating '\0' byte as well (str_length + 1)
self.pos += self._uint32_unpack(self.view, str_start - UINT32_SIZE)[0] + 1
return self.buf[str_start : self.pos - 1].decode()
self._pos += self._uint32_unpack(self._view, str_start - UINT32_SIZE)[0] + 1
return self._buf[str_start : self._pos - 1].decode()
def read_signature(self, type_=None) -> str:
signature_len = self.view[self.pos] # byte
o = self.pos + 1
signature_len = self._view[self._pos] # byte
o = self._pos + 1
# read terminating '\0' byte as well (str_length + 1)
self.pos = o + signature_len + 1
return self.buf[o : o + signature_len].decode()
self._pos = o + signature_len + 1
return self._buf[o : o + signature_len].decode()
def read_variant(self, type_=None) -> Variant:
tree = SignatureTree._get(self.read_signature())
# verify in Variant is only useful on construction not unmarshalling
return Variant(
tree, self.readers[tree.types[0].token](self, tree.types[0]), verify=False
tree, self._readers[tree.types[0].token](self, tree.types[0]), verify=False
)
def read_struct(self, type_=None) -> List[Any]:
self.pos += -self.pos & 7 # align 8
readers = self.readers
self._pos += -self._pos & 7 # align 8
readers = self._readers
return [
readers[child_type.token](self, child_type) for child_type in type_.children
]
def read_dict_entry(self, type_: SignatureType) -> Dict[Any, Any]:
self.pos += -self.pos & 7 # align 8
return self.readers[type_.children[0].token](
self._pos += -self._pos & 7 # align 8
return self._readers[type_.children[0].token](
self, type_.children[0]
), self.readers[type_.children[1].token](self, type_.children[1])
), self._readers[type_.children[1].token](self, type_.children[1])
def read_array(self, type_: SignatureType) -> List[Any]:
self.pos += -self.pos & 3 # align 4 for the array
self.pos += (
-self.pos & (UINT32_SIZE - 1)
self._pos += -self._pos & 3 # align 4 for the array
self._pos += (
-self._pos & (UINT32_SIZE - 1)
) + UINT32_SIZE # align for the uint32
if self._uint32_unpack:
array_length = self._uint32_unpack(self.view, self.pos - UINT32_SIZE)[0]
array_length = self._uint32_unpack(self._view, self._pos - UINT32_SIZE)[0]
else:
array_length = self.view[self.pos - UINT32_SIZE : self.pos].cast(
array_length = self._view[self._pos - UINT32_SIZE : self._pos].cast(
UINT32_CAST
)[0]
child_type = type_.children[0]
if child_type.token in "xtd{(":
token = child_type.token
if token in "xtd{(":
# the first alignment is not included in the array size
self.pos += -self.pos & 7 # align 8
self._pos += -self._pos & 7 # align 8
if child_type.token == "y":
self.pos += array_length
return self.buf[self.pos - array_length : self.pos]
if token == "y":
self._pos += array_length
return self._buf[self._pos - array_length : self._pos]
beginning_pos = self.pos
readers = self.readers
beginning_pos = self._pos
readers = self._readers
if child_type.token == "{":
if token == "{":
result_dict = {}
while self.pos - beginning_pos < array_length:
self.pos += -self.pos & 7 # align 8
key = readers[child_type.children[0].token](
self, child_type.children[0]
)
result_dict[key] = readers[child_type.children[1].token](
self, child_type.children[1]
)
child_0 = child_type.children[0]
reader_0 = readers[child_0.token]
child_1 = child_type.children[1]
reader_1 = readers[child_1.token]
while self._pos - beginning_pos < array_length:
self._pos += -self._pos & 7 # align 8
key = reader_0(self, child_0)
result_dict[key] = reader_1(self, child_1)
return result_dict
result_list = []
while self.pos - beginning_pos < array_length:
result_list.append(readers[child_type.token](self, child_type))
reader = readers[child_type.token]
while self._pos - beginning_pos < array_length:
result_list.append(reader(self, child_type))
return result_list
def header_fields(self, header_length) -> Dict[str, Any]:
"""Header fields are always a(yv)."""
beginning_pos = self.pos
beginning_pos = self._pos
headers = {}
while self.pos - beginning_pos < header_length:
while self._pos - beginning_pos < header_length:
# Now read the y (byte) of struct (yv)
self.pos += (-self.pos & 7) + 1 # align 8 + 1 for 'y' byte
field_0 = self.view[self.pos - 1]
self._pos += (-self._pos & 7) + 1 # align 8 + 1 for 'y' byte
field_0 = self._view[self._pos - 1]
# Now read the v (variant) of struct (yv)
signature_len = self.view[self.pos] # byte
o = self.pos + 1
self.pos += signature_len + 2 # one for the byte, one for the '\0'
tree = SignatureTree._get(self.buf[o : o + signature_len].decode())
headers[HEADER_NAME_MAP[field_0]] = self.readers[tree.types[0].token](
signature_len = self._view[self._pos] # byte
o = self._pos + 1
self._pos += signature_len + 2 # one for the byte, one for the '\0'
tree = SignatureTree._get(self._buf[o : o + signature_len].decode())
headers[HEADER_NAME_MAP[field_0]] = self._readers[tree.types[0].token](
self, tree.types[0]
)
return headers
@ -334,10 +355,10 @@ class Unmarshaller:
# Signature is of the header is
# BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT)
self.read_to_pos(HEADER_SIGNATURE_SIZE)
buffer = self.buf
buffer = self._buf
endian = buffer[0]
self.message_type = MESSAGE_TYPE_MAP[buffer[1]]
self.flag = MESSAGE_FLAG_MAP[buffer[2]]
self._message_type = buffer[1]
self._flag = buffer[2]
protocol_version = buffer[3]
if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN:
@ -349,44 +370,44 @@ class Unmarshaller:
f"got unknown protocol version: {protocol_version}"
)
self.body_len, self.serial, self.header_len = UNPACK_LENGTHS[
self._body_len, self._serial, self._header_len = UNPACK_LENGTHS[
endian
].unpack_from(buffer, 4)
self.msg_len = (
self.header_len + (-self.header_len & 7) + self.body_len
self._msg_len = (
self._header_len + (-self._header_len & 7) + self._body_len
) # align 8
can_cast = bool(
(IS_LITTLE_ENDIAN and endian == LITTLE_ENDIAN)
or (IS_BIG_ENDIAN and endian == BIG_ENDIAN)
)
self.readers = self._readers_by_type[(endian, can_cast)]
self._readers = self._readers_by_type[(endian, can_cast)]
if not can_cast:
self._uint32_unpack = UINT32_UNPACK_BY_ENDIAN[endian]
def _read_body(self):
"""Read the body of the message."""
self.read_to_pos(HEADER_SIGNATURE_SIZE + self.msg_len)
self.view = memoryview(self.buf)
self.pos = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION
header_fields = self.header_fields(self.header_len)
self.pos += -self.pos & 7 # align 8
self.read_to_pos(HEADER_SIGNATURE_SIZE + self._msg_len)
self._view = memoryview(self._buf)
self._pos = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION
header_fields = self.header_fields(self._header_len)
self._pos += -self._pos & 7 # align 8
tree = SignatureTree._get(header_fields.get(HeaderField.SIGNATURE.name, ""))
self._message = Message(
destination=header_fields.get(HEADER_DESTINATION),
path=header_fields.get(HEADER_PATH),
interface=header_fields.get(HEADER_INTERFACE),
member=header_fields.get(HEADER_MEMBER),
message_type=self.message_type,
flags=self.flag,
message_type=MESSAGE_TYPE_MAP[self._message_type],
flags=MESSAGE_FLAG_MAP[self._flag],
error_name=header_fields.get(HEADER_ERROR_NAME),
reply_serial=header_fields.get(HEADER_REPLY_SERIAL),
sender=header_fields.get(HEADER_SENDER),
unix_fds=self.unix_fds,
signature=tree.signature,
body=[self.readers[t.token](self, t) for t in tree.types]
if self.body_len
unix_fds=self._unix_fds,
signature=tree,
body=[self._readers[t.token](self, t) for t in tree.types]
if self._body_len
else [],
serial=self.serial,
serial=self._serial,
# The D-Bus implementation already validates the message,
# so we don't need to do it again.
validate=False,
@ -400,7 +421,7 @@ class Unmarshaller:
to be resumed when more data comes in over the wire.
"""
try:
if not self.msg_len:
if not self._msg_len:
self._read_header()
self._read_body()
except MarshallerStreamEndError:

View File

@ -2,6 +2,7 @@ import array
import asyncio
import logging
import socket
import traceback
from collections import deque
from copy import copy
from typing import Any, Optional
@ -419,11 +420,17 @@ class MessageBus(BaseMessageBus):
return handler
def _message_reader(self) -> None:
unmarshaller = self._unmarshaller
try:
while True:
if self._unmarshaller.unmarshall():
self._on_message(self._unmarshaller.message)
self._unmarshaller = self._create_unmarshaller()
if unmarshaller.unmarshall():
try:
self._process_message(unmarshaller.message)
except Exception as e:
logging.error(
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"
)
unmarshaller.reset()
else:
break
except Exception as e:

View File

@ -1,4 +1,6 @@
import io
import logging
import traceback
from typing import Callable, Optional
from .. import introspection as intr
@ -173,6 +175,14 @@ class MessageBus(BaseMessageBus):
else:
self._auth = auth
def _on_message(self, msg: Message) -> None:
try:
self._process_message(msg)
except Exception as e:
logging.error(
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"
)
def connect(
self, connect_notify: Callable[["MessageBus", Optional[Exception]], None] = None
):

View File

@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any, List, Union
from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField
from ._private.marshaller import Marshaller
@ -105,11 +105,11 @@ class Message:
reply_serial: int = None,
sender: str = None,
unix_fds: List[int] = [],
signature: str = "",
signature: Union[str, SignatureTree] = "",
body: List[Any] = [],
serial: int = 0,
validate: bool = True,
):
) -> None:
self.destination = destination
self.path = path
self.interface = interface
@ -124,14 +124,12 @@ class Message:
self.reply_serial = reply_serial
self.sender = sender
self.unix_fds = unix_fds
self.signature = (
signature.signature if type(signature) is SignatureTree else signature
)
self.signature_tree = (
signature
if type(signature) is SignatureTree
else SignatureTree._get(signature)
)
if type(signature) is SignatureTree:
self.signature = signature.signature
self.signature_tree = signature
else:
self.signature = signature
self.signature_tree = SignatureTree._get(signature)
self.body = body
self.serial = serial

View File

@ -690,14 +690,6 @@ class BaseMessageBus:
ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg
)
def _on_message(self, msg: Message) -> None:
try:
self._process_message(msg)
except Exception as e:
logging.error(
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"
)
def _send_reply(self, msg: Message):
bus = self