fix: avoid double buffering when using asyncio reader without negotiate_unix_fd (#213)

This commit is contained in:
J. Nick Koston 2023-07-27 22:46:29 -05:00 committed by GitHub
parent e669a13de2
commit c933be7095
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 127 additions and 52 deletions

View File

@ -23,7 +23,6 @@ unmarshaller = Unmarshaller(stream)
def unmarhsall_bluez_rssi_message(): def unmarhsall_bluez_rssi_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -42,7 +42,6 @@ unmarshaller = Unmarshaller(stream)
def unmarhsall_bluez_rssi_message(): def unmarhsall_bluez_rssi_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -16,7 +16,6 @@ unmarshaller = Unmarshaller(stream)
def unmarhsall_bluez_get_managed_objects_message(): def unmarhsall_bluez_get_managed_objects_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -31,7 +31,6 @@ unmarshaller = Unmarshaller(stream)
def unmarshall_interfaces_added_message(): def unmarshall_interfaces_added_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -22,7 +22,6 @@ unmarshaller = Unmarshaller(stream)
def unmarshall_interfaces_removed_message(): def unmarshall_interfaces_removed_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -21,7 +21,6 @@ unmarshaller = Unmarshaller(stream)
def unmarshall_mfr_data_message(): def unmarshall_mfr_data_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -20,7 +20,6 @@ unmarshaller = Unmarshaller(stream)
def unmarhsall_bluez_rssi_message(): def unmarhsall_bluez_rssi_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -23,7 +23,6 @@ unmarshaller = Unmarshaller(stream)
def unmarshall_properties_changed_message(): def unmarshall_properties_changed_message():
stream.seek(0) stream.seek(0)
unmarshaller.reset()
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -78,6 +78,10 @@ cdef unsigned int TOKEN_S_AS_INT
cdef unsigned int TOKEN_G_AS_INT cdef unsigned int TOKEN_G_AS_INT
cdef object MARSHALL_STREAM_END_ERROR cdef object MARSHALL_STREAM_END_ERROR
cdef object DEFAULT_BUFFER_SIZE
cdef cython.uint EAGAIN
cdef cython.uint EWOULDBLOCK
cdef get_signature_tree cdef get_signature_tree
@ -116,23 +120,32 @@ cdef class Unmarshaller:
cdef object _int16_unpack cdef object _int16_unpack
cdef object _uint16_unpack cdef object _uint16_unpack
cdef object _stream_reader cdef object _stream_reader
cdef bint _negotiate_unix_fd
cdef bint _read_complete
cdef _reset(self) cdef _next_message(self)
cpdef reset(self) cdef _has_another_message_in_buffer(self)
@cython.locals( @cython.locals(
msg=cython.bytes, msg=cython.bytes,
recv=cython.tuple recv=cython.tuple,
errno=cython.uint
) )
cdef bytes _read_sock(self, object length) cdef _read_sock_with_fds(self, unsigned int pos, unsigned int missing_bytes)
@cython.locals(
data=cython.bytes,
errno=cython.uint
)
cdef _read_sock_without_fds(self, unsigned int pos)
@cython.locals( @cython.locals(
start_len=cython.ulong,
missing_bytes=cython.ulong,
data=cython.bytes data=cython.bytes
) )
cdef _read_to_pos(self, unsigned long pos) cdef _read_stream(self, unsigned int pos, unsigned int missing_bytes)
cdef _read_to_pos(self, unsigned int pos)
cpdef read_boolean(self, SignatureType type_) cpdef read_boolean(self, SignatureType type_)

View File

@ -1,4 +1,5 @@
import array import array
import errno
import io import io
import socket import socket
import sys import sys
@ -102,6 +103,10 @@ ARRAY = array.array
SOL_SOCKET = socket.SOL_SOCKET SOL_SOCKET = socket.SOL_SOCKET
SCM_RIGHTS = socket.SCM_RIGHTS SCM_RIGHTS = socket.SCM_RIGHTS
EAGAIN = errno.EAGAIN
EWOULDBLOCK = errno.EWOULDBLOCK
HEADER_MESSAGE_ARG_NAME = { HEADER_MESSAGE_ARG_NAME = {
1: "path", 1: "path",
2: "interface", 2: "interface",
@ -121,6 +126,8 @@ READER_TYPE = Callable[["Unmarshaller", SignatureType], Any]
MARSHALL_STREAM_END_ERROR = BlockingIOError MARSHALL_STREAM_END_ERROR = BlockingIOError
DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE
def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE: def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE:
"""Build a parser that unpacks the bytes using the given unpack_from function.""" """Build a parser that unpacks the bytes using the given unpack_from function."""
@ -169,6 +176,12 @@ except ImportError:
# #
# #
class Unmarshaller: class Unmarshaller:
"""Unmarshall messages from a stream.
When calling with sock and _negotiate_unix_fd False, the unmarshaller must
be called continuously for each new message as it will buffer the data
until a complete message is available.
"""
__slots__ = ( __slots__ = (
"_unix_fds", "_unix_fds",
@ -189,9 +202,16 @@ class Unmarshaller:
"_uint16_unpack", "_uint16_unpack",
"_is_native", "_is_native",
"_stream_reader", "_stream_reader",
"_negotiate_unix_fd",
"_read_complete",
) )
def __init__(self, stream: io.BufferedRWPair, sock: Optional[socket.socket] = None): def __init__(
self,
stream: Optional[io.BufferedRWPair] = None,
sock: Optional[socket.socket] = None,
negotiate_unix_fd: bool = True,
) -> None:
self._unix_fds: List[int] = [] self._unix_fds: List[int] = []
self._buf = bytearray() # Actual buffer self._buf = bytearray() # Actual buffer
self._stream = stream self._stream = stream
@ -210,25 +230,24 @@ class Unmarshaller:
self._int16_unpack: Optional[Callable] = None self._int16_unpack: Optional[Callable] = None
self._uint16_unpack: Optional[Callable] = None self._uint16_unpack: Optional[Callable] = None
self._stream_reader: Optional[Callable] = None self._stream_reader: Optional[Callable] = None
if self._sock is None: self._negotiate_unix_fd = negotiate_unix_fd
self._read_complete = False
if stream:
if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"): if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"):
self._stream_reader = stream.reader.read # type: ignore[attr-defined] self._stream_reader = stream.reader.read # type: ignore[attr-defined]
self._stream_reader = stream.read self._stream_reader = stream.read
def reset(self) -> None: def _next_message(self) -> None:
"""Reset the unmarshaller to its initial state.
Call this before processing a new message.
"""
self._reset()
def _reset(self) -> None:
"""Reset the unmarshaller to its initial state. """Reset the unmarshaller to its initial state.
Call this before processing a new message. Call this before processing a new message.
""" """
self._unix_fds = [] self._unix_fds = []
self._buf.clear() to_clear = HEADER_SIGNATURE_SIZE + self._msg_len
if len(self._buf) == to_clear:
self._buf = bytearray()
else:
del self._buf[:to_clear]
self._message = None self._message = None
self._pos = 0 self._pos = 0
self._body_len = 0 self._body_len = 0
@ -238,6 +257,7 @@ class Unmarshaller:
self._flag = 0 self._flag = 0
self._msg_len = 0 self._msg_len = 0
self._is_native = 0 self._is_native = 0
self._read_complete = False
# No need to reset the unpack functions, they are set in _read_header # No need to reset the unpack functions, they are set in _read_header
# every time a new message is processed. # every time a new message is processed.
@ -246,12 +266,26 @@ class Unmarshaller:
"""Return the message that has been unmarshalled.""" """Return the message that has been unmarshalled."""
return self._message return self._message
def _read_sock(self, length: _int) -> bytes: def _has_another_message_in_buffer(self) -> bool:
"""Check if there is another message in the buffer."""
return len(self._buf) > HEADER_SIGNATURE_SIZE + self._msg_len
def _read_sock_with_fds(self, pos: _int, missing_bytes: _int) -> None:
"""reads from the socket, storing any fds sent and handling errors """reads from the socket, storing any fds sent and handling errors
from the read itself""" from the read itself.
This function is greedy and will read as much data as possible
from the underlying socket.
"""
# This will raise BlockingIOError if there is no data to read # This will raise BlockingIOError if there is no data to read
# which we store in the MARSHALL_STREAM_END_ERROR object # which we store in the MARSHALL_STREAM_END_ERROR object
recv = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr] try:
recv = self._sock.recvmsg(missing_bytes, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
except OSError as e:
errno = e.errno
if errno == EAGAIN or errno == EWOULDBLOCK:
raise MARSHALL_STREAM_END_ERROR
raise
msg = recv[0] msg = recv[0]
ancdata = recv[1] ancdata = recv[1]
if ancdata: if ancdata:
@ -261,8 +295,44 @@ class Unmarshaller:
self._unix_fds.extend( self._unix_fds.extend(
ARRAY("i", data[: len(data) - (len(data) % MAX_UNIX_FDS_SIZE)]) ARRAY("i", data[: len(data) - (len(data) % MAX_UNIX_FDS_SIZE)])
) )
if msg == b"":
raise EOFError()
self._buf += msg
if len(self._buf) < pos:
raise MARSHALL_STREAM_END_ERROR
return msg def _read_sock_without_fds(self, pos: _int) -> None:
"""reads from the socket and handles errors from the read itself.
This function is greedy and will read as much data as possible
from the underlying socket.
"""
# This will raise BlockingIOError if there is no data to read
# which we store in the MARSHALL_STREAM_END_ERROR object
while True:
try:
data = self._sock.recv(DEFAULT_BUFFER_SIZE) # type: ignore[union-attr]
except OSError as e:
errno = e.errno
if errno == EAGAIN or errno == EWOULDBLOCK:
raise MARSHALL_STREAM_END_ERROR
raise
if data == b"":
raise EOFError()
self._buf += data
if len(self._buf) >= pos:
return
def _read_stream(self, pos: _int, missing_bytes: _int) -> bytes:
"""Read from the stream."""
data = self._stream_reader(missing_bytes) # type: ignore[misc]
if data is None:
raise MARSHALL_STREAM_END_ERROR
if data == b"":
raise EOFError()
self._buf += data
if len(self._buf) < pos:
raise MARSHALL_STREAM_END_ERROR
def _read_to_pos(self, pos: _int) -> None: def _read_to_pos(self, pos: _int) -> None:
""" """
@ -277,19 +347,15 @@ class Unmarshaller:
:returns: :returns:
None None
""" """
start_len = len(self._buf) missing_bytes = pos - len(self._buf)
missing_bytes = pos - (start_len - self._pos) if missing_bytes <= 0:
return
if self._sock is None: if self._sock is None:
data = self._stream_reader(missing_bytes) # type: ignore[misc] self._read_stream(pos, missing_bytes)
elif self._negotiate_unix_fd:
self._read_sock_with_fds(pos, missing_bytes)
else: else:
data = self._read_sock(missing_bytes) self._read_sock_without_fds(pos)
if data == b"":
raise EOFError()
if data is None:
raise MARSHALL_STREAM_END_ERROR
self._buf += data
if len(data) + start_len != pos:
raise MARSHALL_STREAM_END_ERROR
def read_uint32_unpack(self, type_: _SignatureType) -> int: def read_uint32_unpack(self, type_: _SignatureType) -> int:
return self._read_uint32_unpack() return self._read_uint32_unpack()
@ -641,6 +707,7 @@ class Unmarshaller:
validate=False, validate=False,
**header_fields, **header_fields,
) )
self._read_complete = True
def unmarshall(self) -> Optional[Message]: def unmarshall(self) -> Optional[Message]:
"""Unmarshall the message. """Unmarshall the message.
@ -658,6 +725,8 @@ class Unmarshaller:
if there are not enough bytes in the buffer. This allows unmarshall if there are not enough bytes in the buffer. This allows unmarshall
to be resumed when more data comes in over the wire. to be resumed when more data comes in over the wire.
""" """
if self._read_complete:
self._next_message()
try: try:
if not self._msg_len: if not self._msg_len:
self._read_header() self._read_header()

View File

@ -201,10 +201,10 @@ class MessageBus(BaseMessageBus):
self._loop.add_reader( self._loop.add_reader(
self._fd, self._fd,
build_message_reader( build_message_reader(
self._stream, self._sock,
self._sock if self._negotiate_unix_fd else None,
self._process_message, self._process_message,
self._finalize, self._finalize,
self._negotiate_unix_fd,
), ),
) )

View File

@ -1,7 +1,5 @@
import io
import logging import logging
import socket import socket
import traceback
from typing import Callable, Optional from typing import Callable, Optional
from .._private.unmarshaller import Unmarshaller from .._private.unmarshaller import Unmarshaller
@ -9,13 +7,13 @@ from ..message import Message
def build_message_reader( def build_message_reader(
stream: io.BufferedRWPair,
sock: Optional[socket.socket], sock: Optional[socket.socket],
process: Callable[[Message], None], process: Callable[[Message], None],
finalize: Callable[[Optional[Exception]], None], finalize: Callable[[Optional[Exception]], None],
negotiate_unix_fd: bool,
) -> None: ) -> None:
"""Build a callable that reads messages from the unmarshaller and passes them to the process function.""" """Build a callable that reads messages from the unmarshaller and passes them to the process function."""
unmarshaller = Unmarshaller(stream, sock) unmarshaller = Unmarshaller(None, sock, negotiate_unix_fd)
def _message_reader() -> None: def _message_reader() -> None:
"""Reads messages from the unmarshaller and passes them to the process function.""" """Reads messages from the unmarshaller and passes them to the process function."""
@ -28,9 +26,15 @@ def build_message_reader(
process(message) process(message)
except Exception as e: except Exception as e:
logging.error( logging.error(
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" "Unexpected error processing message: %s", exc_info=True
) )
unmarshaller._reset() # If we are not negotiating unix fds, we can stop reading as soon as
# the buffer is empty, as asyncio will call us again when there is more data.
if (
not negotiate_unix_fd
and not unmarshaller._has_another_message_in_buffer()
):
return
except Exception as e: except Exception as e:
finalize(e) finalize(e)

View File

@ -457,21 +457,18 @@ def test_unmarshall_multiple_messages():
unpacked = unpack_variants(message.body) unpacked = unpack_variants(message.body)
assert unpacked == ["org.bluez.Device1", {"RSSI": -76}, []] assert unpacked == ["org.bluez.Device1", {"RSSI": -76}, []]
unmarshaller.reset()
assert unmarshaller.unmarshall() assert unmarshaller.unmarshall()
message = unmarshaller.message message = unmarshaller.message
assert message is not None assert message is not None
unpacked = unpack_variants(message.body) unpacked = unpack_variants(message.body)
assert unpacked == ["org.bluez.Device1", {"RSSI": -80}, []] assert unpacked == ["org.bluez.Device1", {"RSSI": -80}, []]
unmarshaller.reset()
assert unmarshaller.unmarshall() assert unmarshaller.unmarshall()
message = unmarshaller.message message = unmarshaller.message
assert message is not None assert message is not None
unpacked = unpack_variants(message.body) unpacked = unpack_variants(message.body)
assert unpacked == ["org.bluez.Device1", {"RSSI": -94}, []] assert unpacked == ["org.bluez.Device1", {"RSSI": -94}, []]
unmarshaller.reset()
with pytest.raises(EOFError): with pytest.raises(EOFError):
unmarshaller.unmarshall() unmarshaller.unmarshall()