diff --git a/bench/unmarshall.py b/bench/unmarshall.py index f971683..e071e2f 100644 --- a/bench/unmarshall.py +++ b/bench/unmarshall.py @@ -23,7 +23,6 @@ unmarshaller = Unmarshaller(stream) def unmarhsall_bluez_rssi_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_2.py b/bench/unmarshall_2.py index e0c8191..edc296c 100644 --- a/bench/unmarshall_2.py +++ b/bench/unmarshall_2.py @@ -42,7 +42,6 @@ unmarshaller = Unmarshaller(stream) def unmarhsall_bluez_rssi_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_getmanagedobjects.py b/bench/unmarshall_getmanagedobjects.py index 107ad2a..999b905 100644 --- a/bench/unmarshall_getmanagedobjects.py +++ b/bench/unmarshall_getmanagedobjects.py @@ -16,7 +16,6 @@ unmarshaller = Unmarshaller(stream) def unmarhsall_bluez_get_managed_objects_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_interfaces_added.py b/bench/unmarshall_interfaces_added.py index 40fd8c5..7b0fd6e 100644 --- a/bench/unmarshall_interfaces_added.py +++ b/bench/unmarshall_interfaces_added.py @@ -31,7 +31,6 @@ unmarshaller = Unmarshaller(stream) def unmarshall_interfaces_added_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_interfaces_removed.py b/bench/unmarshall_interfaces_removed.py index d3f4fb7..9b6cdb7 100644 --- a/bench/unmarshall_interfaces_removed.py +++ b/bench/unmarshall_interfaces_removed.py @@ -22,7 +22,6 @@ unmarshaller = Unmarshaller(stream) def unmarshall_interfaces_removed_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_manufacturerdata.py b/bench/unmarshall_manufacturerdata.py index 89cb615..6390ceb 100644 --- a/bench/unmarshall_manufacturerdata.py +++ b/bench/unmarshall_manufacturerdata.py @@ -21,7 +21,6 @@ unmarshaller = Unmarshaller(stream) def 
unmarshall_mfr_data_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_passive.py b/bench/unmarshall_passive.py index ff45bf4..243d3d9 100644 --- a/bench/unmarshall_passive.py +++ b/bench/unmarshall_passive.py @@ -20,7 +20,6 @@ unmarshaller = Unmarshaller(stream) def unmarhsall_bluez_rssi_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/bench/unmarshall_servicedata.py b/bench/unmarshall_servicedata.py index 16eb6d7..8f3fc39 100644 --- a/bench/unmarshall_servicedata.py +++ b/bench/unmarshall_servicedata.py @@ -23,7 +23,6 @@ unmarshaller = Unmarshaller(stream) def unmarshall_properties_changed_message(): stream.seek(0) - unmarshaller.reset() unmarshaller.unmarshall() diff --git a/src/dbus_fast/_private/unmarshaller.pxd b/src/dbus_fast/_private/unmarshaller.pxd index a9e0f6c..f86b143 100644 --- a/src/dbus_fast/_private/unmarshaller.pxd +++ b/src/dbus_fast/_private/unmarshaller.pxd @@ -78,6 +78,10 @@ cdef unsigned int TOKEN_S_AS_INT cdef unsigned int TOKEN_G_AS_INT cdef object MARSHALL_STREAM_END_ERROR +cdef object DEFAULT_BUFFER_SIZE + +cdef cython.uint EAGAIN +cdef cython.uint EWOULDBLOCK cdef get_signature_tree @@ -116,23 +120,32 @@ cdef class Unmarshaller: cdef object _int16_unpack cdef object _uint16_unpack cdef object _stream_reader + cdef bint _negotiate_unix_fd + cdef bint _read_complete - cdef _reset(self) + cdef _next_message(self) - cpdef reset(self) + cdef _has_another_message_in_buffer(self) @cython.locals( msg=cython.bytes, - recv=cython.tuple + recv=cython.tuple, + errno=cython.uint ) - cdef bytes _read_sock(self, object length) + cdef _read_sock_with_fds(self, unsigned int pos, unsigned int missing_bytes) + + @cython.locals( + data=cython.bytes, + errno=cython.uint + ) + cdef _read_sock_without_fds(self, unsigned int pos) @cython.locals( - start_len=cython.ulong, - missing_bytes=cython.ulong, data=cython.bytes ) - cdef _read_to_pos(self, unsigned long pos) + 
cdef _read_stream(self, unsigned int pos, unsigned int missing_bytes) + + cdef _read_to_pos(self, unsigned int pos) cpdef read_boolean(self, SignatureType type_) diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 215990d..fcf9ba9 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -1,4 +1,5 @@ import array +import errno import io import socket import sys @@ -102,6 +103,10 @@ ARRAY = array.array SOL_SOCKET = socket.SOL_SOCKET SCM_RIGHTS = socket.SCM_RIGHTS +EAGAIN = errno.EAGAIN +EWOULDBLOCK = errno.EWOULDBLOCK + + HEADER_MESSAGE_ARG_NAME = { 1: "path", 2: "interface", @@ -121,6 +126,8 @@ READER_TYPE = Callable[["Unmarshaller", SignatureType], Any] MARSHALL_STREAM_END_ERROR = BlockingIOError +DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE + def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE: """Build a parser that unpacks the bytes using the given unpack_from function.""" @@ -169,6 +176,12 @@ except ImportError: # # class Unmarshaller: + """Unmarshall messages from a stream. + + When calling with sock and _negotiate_unix_fd False, the unmarshaller must + be called continuously for each new message as it will buffer the data + until a complete message is available. 
+ """ __slots__ = ( "_unix_fds", @@ -189,9 +202,16 @@ class Unmarshaller: "_uint16_unpack", "_is_native", "_stream_reader", + "_negotiate_unix_fd", + "_read_complete", ) - def __init__(self, stream: io.BufferedRWPair, sock: Optional[socket.socket] = None): + def __init__( + self, + stream: Optional[io.BufferedRWPair] = None, + sock: Optional[socket.socket] = None, + negotiate_unix_fd: bool = True, + ) -> None: self._unix_fds: List[int] = [] self._buf = bytearray() # Actual buffer self._stream = stream @@ -210,25 +230,24 @@ class Unmarshaller: self._int16_unpack: Optional[Callable] = None self._uint16_unpack: Optional[Callable] = None self._stream_reader: Optional[Callable] = None - if self._sock is None: + self._negotiate_unix_fd = negotiate_unix_fd + self._read_complete = False + if stream: if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"): self._stream_reader = stream.reader.read # type: ignore[attr-defined] self._stream_reader = stream.read - def reset(self) -> None: - """Reset the unmarshaller to its initial state. - - Call this before processing a new message. - """ - self._reset() - - def _reset(self) -> None: + def _next_message(self) -> None: """Reset the unmarshaller to its initial state. Call this before processing a new message. """ self._unix_fds = [] - self._buf.clear() + to_clear = HEADER_SIGNATURE_SIZE + self._msg_len + if len(self._buf) == to_clear: + self._buf = bytearray() + else: + del self._buf[:to_clear] self._message = None self._pos = 0 self._body_len = 0 @@ -238,6 +257,7 @@ class Unmarshaller: self._flag = 0 self._msg_len = 0 self._is_native = 0 + self._read_complete = False # No need to reset the unpack functions, they are set in _read_header # every time a new message is processed. 
@@ -246,12 +266,26 @@ class Unmarshaller: """Return the message that has been unmarshalled.""" return self._message - def _read_sock(self, length: _int) -> bytes: + def _has_another_message_in_buffer(self) -> bool: + """Check if there is another message in the buffer.""" + return len(self._buf) > HEADER_SIGNATURE_SIZE + self._msg_len + + def _read_sock_with_fds(self, pos: _int, missing_bytes: _int) -> None: """reads from the socket, storing any fds sent and handling errors - from the read itself""" + from the read itself. + + This function is not greedy: it issues a single recvmsg for at most + missing_bytes and raises if the buffer is still short of pos. + """ # This will raise BlockingIOError if there is no data to read # which we store in the MARSHALL_STREAM_END_ERROR object - recv = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr] + try: + recv = self._sock.recvmsg(missing_bytes, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr] + except OSError as e: + errno = e.errno + if errno == EAGAIN or errno == EWOULDBLOCK: + raise MARSHALL_STREAM_END_ERROR + raise msg = recv[0] ancdata = recv[1] if ancdata: @@ -261,8 +295,44 @@ class Unmarshaller: self._unix_fds.extend( ARRAY("i", data[: len(data) - (len(data) % MAX_UNIX_FDS_SIZE)]) ) + if msg == b"": + raise EOFError() + self._buf += msg + if len(self._buf) < pos: + raise MARSHALL_STREAM_END_ERROR - return msg + def _read_sock_without_fds(self, pos: _int) -> None: + """Read from the socket, handling errors from the read itself. + + This function is greedy and will read as much data as possible + from the underlying socket. 
+ """ + # This will raise BlockingIOError if there is no data to read + # which we store in the MARSHALL_STREAM_END_ERROR object + while True: + try: + data = self._sock.recv(DEFAULT_BUFFER_SIZE) # type: ignore[union-attr] + except OSError as e: + errno = e.errno + if errno == EAGAIN or errno == EWOULDBLOCK: + raise MARSHALL_STREAM_END_ERROR + raise + if data == b"": + raise EOFError() + self._buf += data + if len(self._buf) >= pos: + return + + def _read_stream(self, pos: _int, missing_bytes: _int) -> None: + """Read from the stream into the buffer, raising if pos is not reached.""" + data = self._stream_reader(missing_bytes) # type: ignore[misc] + if data is None: + raise MARSHALL_STREAM_END_ERROR + if data == b"": + raise EOFError() + self._buf += data + if len(self._buf) < pos: + raise MARSHALL_STREAM_END_ERROR def _read_to_pos(self, pos: _int) -> None: """ @@ -277,19 +347,15 @@ class Unmarshaller: :returns: None """ - start_len = len(self._buf) - missing_bytes = pos - (start_len - self._pos) + missing_bytes = pos - len(self._buf) + if missing_bytes <= 0: + return if self._sock is None: - data = self._stream_reader(missing_bytes) # type: ignore[misc] + self._read_stream(pos, missing_bytes) + elif self._negotiate_unix_fd: + self._read_sock_with_fds(pos, missing_bytes) else: - data = self._read_sock(missing_bytes) - if data == b"": - raise EOFError() - if data is None: - raise MARSHALL_STREAM_END_ERROR - self._buf += data - if len(data) + start_len != pos: - raise MARSHALL_STREAM_END_ERROR + self._read_sock_without_fds(pos) def read_uint32_unpack(self, type_: _SignatureType) -> int: return self._read_uint32_unpack() @@ -641,6 +707,7 @@ class Unmarshaller: validate=False, **header_fields, ) + self._read_complete = True def unmarshall(self) -> Optional[Message]: """Unmarshall the message. @@ -658,6 +725,8 @@ class Unmarshaller: if there are not enough bytes in the buffer. This allows unmarshall to be resumed when more data comes in over the wire. 
""" + if self._read_complete: + self._next_message() try: if not self._msg_len: self._read_header() diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 33c6570..795c394 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -201,10 +201,10 @@ class MessageBus(BaseMessageBus): self._loop.add_reader( self._fd, build_message_reader( - self._stream, - self._sock if self._negotiate_unix_fd else None, + self._sock, self._process_message, self._finalize, + self._negotiate_unix_fd, ), ) diff --git a/src/dbus_fast/aio/message_reader.py b/src/dbus_fast/aio/message_reader.py index 35d6883..07957a1 100644 --- a/src/dbus_fast/aio/message_reader.py +++ b/src/dbus_fast/aio/message_reader.py @@ -1,7 +1,5 @@ -import io import logging import socket -import traceback from typing import Callable, Optional from .._private.unmarshaller import Unmarshaller @@ -9,13 +7,13 @@ from ..message import Message def build_message_reader( - stream: io.BufferedRWPair, sock: Optional[socket.socket], process: Callable[[Message], None], finalize: Callable[[Optional[Exception]], None], + negotiate_unix_fd: bool, ) -> None: """Build a callable that reads messages from the unmarshaller and passes them to the process function.""" - unmarshaller = Unmarshaller(stream, sock) + unmarshaller = Unmarshaller(None, sock, negotiate_unix_fd) def _message_reader() -> None: """Reads messages from the unmarshaller and passes them to the process function.""" @@ -28,9 +26,15 @@ def build_message_reader( process(message) except Exception as e: logging.error( - f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" + "Unexpected error processing message: %s", e, exc_info=True ) - unmarshaller._reset() + # If we are not negotiating unix fds, we can stop reading as soon as + # the buffer holds nothing beyond this message, as asyncio will call us again when there is more data. 
+ if ( + not negotiate_unix_fd + and not unmarshaller._has_another_message_in_buffer() + ): + return except Exception as e: finalize(e) diff --git a/tests/test_marshaller.py b/tests/test_marshaller.py index ae69eae..71268fa 100644 --- a/tests/test_marshaller.py +++ b/tests/test_marshaller.py @@ -457,21 +457,18 @@ def test_unmarshall_multiple_messages(): unpacked = unpack_variants(message.body) assert unpacked == ["org.bluez.Device1", {"RSSI": -76}, []] - unmarshaller.reset() assert unmarshaller.unmarshall() message = unmarshaller.message assert message is not None unpacked = unpack_variants(message.body) assert unpacked == ["org.bluez.Device1", {"RSSI": -80}, []] - unmarshaller.reset() assert unmarshaller.unmarshall() message = unmarshaller.message assert message is not None unpacked = unpack_variants(message.body) assert unpacked == ["org.bluez.Device1", {"RSSI": -94}, []] - unmarshaller.reset() with pytest.raises(EOFError): unmarshaller.unmarshall()