diff --git a/src/dbus_fast/_private/unmarshaller.pxd b/src/dbus_fast/_private/unmarshaller.pxd index 4073cf0..667f8d7 100644 --- a/src/dbus_fast/_private/unmarshaller.pxd +++ b/src/dbus_fast/_private/unmarshaller.pxd @@ -106,6 +106,8 @@ cdef class Unmarshaller: cdef object _int16_unpack cdef object _uint16_unpack + cdef _reset(self) + cpdef reset(self) @cython.locals( @@ -173,6 +175,8 @@ cdef class Unmarshaller: ) cdef _read_body(self) + cdef _unmarshall(self) + cpdef unmarshall(self) @cython.locals( diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 06b0535..d5695d2 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -202,6 +202,13 @@ class Unmarshaller: def reset(self) -> None: """Reset the unmarshaller to its initial state. + Call this before processing a new message. + """ + self._reset() + + def _reset(self) -> None: + """Reset the unmarshaller to its initial state. + Call this before processing a new message. """ self._unix_fds = [] @@ -596,6 +603,15 @@ class Unmarshaller: def unmarshall(self) -> Optional[Message]: """Unmarshall the message. + The underlying read function will raise BlockingIOError if the + if there are not enough bytes in the buffer. This allows unmarshall + to be resumed when more data comes in over the wire. + """ + return self._unmarshall() + + def _unmarshall(self) -> Optional[Message]: + """Unmarshall the message. + The underlying read function will raise BlockingIOError if the if there are not enough bytes in the buffer. This allows unmarshall to be resumed when more data comes in over the wire. diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 3f4ab52..090d0cb 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -170,7 +170,6 @@ class MessageBus(BaseMessageBus): super().__init__(bus_address, bus_type, ProxyObject) self._negotiate_unix_fd = negotiate_unix_fd self._loop = asyncio.get_running_loop() - self._unmarshaller = self._create_unmarshaller() self._writer = _MessageWriter(self) @@ -201,7 +200,8 @@ class MessageBus(BaseMessageBus): self._loop.add_reader( self._fd, build_message_reader( - self._unmarshaller, + self._stream, + self._sock if self._negotiate_unix_fd else None, self._process_message, self._finalize, ), @@ -477,12 +477,6 @@ class MessageBus(BaseMessageBus): except Exception: logging.warning("could not close socket", exc_info=True) - def _create_unmarshaller(self) -> Unmarshaller: - sock = None - if self._negotiate_unix_fd: - sock = self._sock - return Unmarshaller(self._stream, sock) - def _finalize(self, err: Optional[Exception] = None) -> None: try: self._loop.remove_reader(self._fd) diff --git a/src/dbus_fast/aio/message_reader.pxd b/src/dbus_fast/aio/message_reader.pxd index a53a8e3..e76627b 100644 --- a/src/dbus_fast/aio/message_reader.pxd +++ b/src/dbus_fast/aio/message_reader.pxd @@ -1,3 +1,5 @@ """cdefs for message_reader.py""" import cython + +from .._private.unmarshaller cimport Unmarshaller diff --git a/src/dbus_fast/aio/message_reader.py b/src/dbus_fast/aio/message_reader.py index 4da1f08..35d6883 100644 --- a/src/dbus_fast/aio/message_reader.py +++ b/src/dbus_fast/aio/message_reader.py @@ -1,4 +1,6 @@ +import io import logging +import socket import traceback from typing import Callable, Optional @@ -7,19 +9,19 @@ from ..message import Message def build_message_reader( - unmarshaller: Unmarshaller, + stream: io.BufferedRWPair, + sock: Optional[socket.socket], process: Callable[[Message], None], finalize: Callable[[Optional[Exception]], None], ) -> None: """Build a callable that reads messages from the unmarshaller and passes them to the process function.""" - unmarshall = unmarshaller.unmarshall - reset = unmarshaller.reset + unmarshaller = Unmarshaller(stream, sock) def _message_reader() -> None: """Reads messages from the unmarshaller and passes them to the process function.""" try: while True: - message = unmarshall() + message = unmarshaller._unmarshall() if not message: return try: @@ -28,7 +30,7 @@ def build_message_reader( logging.error( f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" ) - reset() + unmarshaller._reset() except Exception as e: finalize(e)