feat: speed up unmarshaller (#109)

This commit is contained in:
J. Nick Koston 2022-10-19 15:51:36 -05:00 committed by GitHub
parent ba4c66cece
commit 2443cf9990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 14 deletions

View File

@ -22,6 +22,14 @@ cdef object UINT32_UNPACK_BIG_ENDIAN
cdef object INT16_UNPACK_LITTLE_ENDIAN
cdef object INT16_UNPACK_BIG_ENDIAN
cdef object Variant
cdef object Message
cdef object MESSAGE_TYPE_MAP
cdef object MESSAGE_FLAG_MAP
cdef object HEADER_MESSAGE_ARG_NAME
cpdef get_signature_tree
cdef class MarshallerStreamEndError(Exception):
pass
@ -46,13 +54,14 @@ cdef class Unmarshaller:
cpdef reset(self)
cdef read_sock(self, unsigned long length)
cdef bytes _read_sock(self, unsigned long length)
@cython.locals(
start_len=cython.ulong,
missing_bytes=cython.ulong
missing_bytes=cython.ulong,
data=cython.bytes
)
cdef read_to_pos(self, unsigned long pos)
cdef _read_to_pos(self, unsigned long pos)
cpdef read_uint32_unpack(self, object type_)
@ -96,6 +105,9 @@ cdef class Unmarshaller:
)
cdef _read_header(self)
@cython.locals(
body=cython.list
)
cdef _read_body(self)
cpdef unmarshall(self)

View File

@ -176,7 +176,7 @@ class Unmarshaller:
"""Return the message that has been unmarshalled."""
return self._message
def read_sock(self, length: int) -> bytes:
def _read_sock(self, length: int) -> bytes:
"""reads from the socket, storing any fds sent and handling errors
from the read itself"""
unix_fd_list = array.array("i")
@ -198,7 +198,7 @@ class Unmarshaller:
return msg
def read_to_pos(self, pos: int) -> None:
def _read_to_pos(self, pos: int) -> None:
"""
Read from underlying socket into buffer.
@ -216,12 +216,12 @@ class Unmarshaller:
if self._sock is None:
data = self._stream.read(missing_bytes)
else:
data = self.read_sock(missing_bytes)
data = self._read_sock(missing_bytes)
if data == b"":
raise EOFError()
if data is None:
raise MarshallerStreamEndError()
self._buf.extend(data)
self._buf += data
if len(data) + start_len != pos:
raise MarshallerStreamEndError()
@ -385,17 +385,13 @@ class Unmarshaller:
"""Read the header of the message."""
# Signature is of the header is
# BYTE, BYTE, BYTE, BYTE, UINT32, UINT32, ARRAY of STRUCT of (BYTE,VARIANT)
self.read_to_pos(HEADER_SIGNATURE_SIZE)
self._read_to_pos(HEADER_SIGNATURE_SIZE)
buffer = self._buf
endian = buffer[0]
self._message_type = buffer[1]
self._flag = buffer[2]
protocol_version = buffer[3]
if endian != LITTLE_ENDIAN and endian != BIG_ENDIAN:
raise InvalidMessageError(
f"Expecting endianness as the first byte, got {endian} from {buffer}"
)
if protocol_version != PROTOCOL_VERSION:
raise InvalidMessageError(
f"got unknown protocol version: {protocol_version}"
@ -409,12 +405,16 @@ class Unmarshaller:
) = UNPACK_HEADER_LITTLE_ENDIAN(self._buf, 4)
self._uint32_unpack = UINT32_UNPACK_LITTLE_ENDIAN
self._int16_unpack = INT16_UNPACK_LITTLE_ENDIAN
else:
elif endian == BIG_ENDIAN:
self._body_len, self._serial, self._header_len = UNPACK_HEADER_BIG_ENDIAN(
self._buf, 4
)
self._uint32_unpack = UINT32_UNPACK_BIG_ENDIAN
self._int16_unpack = INT16_UNPACK_BIG_ENDIAN
else:
raise InvalidMessageError(
f"Expecting endianness as the first byte, got {endian} from {buffer}"
)
self._msg_len = (
self._header_len + (-self._header_len & 7) + self._body_len
@ -423,7 +423,7 @@ class Unmarshaller:
def _read_body(self) -> None:
"""Read the body of the message."""
self.read_to_pos(HEADER_SIGNATURE_SIZE + self._msg_len)
self._read_to_pos(HEADER_SIGNATURE_SIZE + self._msg_len)
self._pos = HEADER_ARRAY_OF_STRUCT_SIGNATURE_POSITION
header_fields = self.header_fields(self._header_len)
self._pos += -self._pos & 7 # align 8