fix: avoid double buffering when using asyncio reader without negotiate_unix_fd (#213)
This commit is contained in:
parent
e669a13de2
commit
c933be7095
@ -23,7 +23,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarhsall_bluez_rssi_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -42,7 +42,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarhsall_bluez_rssi_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarhsall_bluez_get_managed_objects_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -31,7 +31,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarshall_interfaces_added_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -22,7 +22,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarshall_interfaces_removed_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarshall_mfr_data_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarhsall_bluez_rssi_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ unmarshaller = Unmarshaller(stream)
|
||||
|
||||
def unmarshall_properties_changed_message():
|
||||
stream.seek(0)
|
||||
unmarshaller.reset()
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
|
||||
@ -78,6 +78,10 @@ cdef unsigned int TOKEN_S_AS_INT
|
||||
cdef unsigned int TOKEN_G_AS_INT
|
||||
|
||||
cdef object MARSHALL_STREAM_END_ERROR
|
||||
cdef object DEFAULT_BUFFER_SIZE
|
||||
|
||||
cdef cython.uint EAGAIN
|
||||
cdef cython.uint EWOULDBLOCK
|
||||
|
||||
cdef get_signature_tree
|
||||
|
||||
@ -116,23 +120,32 @@ cdef class Unmarshaller:
|
||||
cdef object _int16_unpack
|
||||
cdef object _uint16_unpack
|
||||
cdef object _stream_reader
|
||||
cdef bint _negotiate_unix_fd
|
||||
cdef bint _read_complete
|
||||
|
||||
cdef _reset(self)
|
||||
cdef _next_message(self)
|
||||
|
||||
cpdef reset(self)
|
||||
cdef _has_another_message_in_buffer(self)
|
||||
|
||||
@cython.locals(
|
||||
msg=cython.bytes,
|
||||
recv=cython.tuple
|
||||
recv=cython.tuple,
|
||||
errno=cython.uint
|
||||
)
|
||||
cdef bytes _read_sock(self, object length)
|
||||
cdef _read_sock_with_fds(self, unsigned int pos, unsigned int missing_bytes)
|
||||
|
||||
@cython.locals(
|
||||
data=cython.bytes,
|
||||
errno=cython.uint
|
||||
)
|
||||
cdef _read_sock_without_fds(self, unsigned int pos)
|
||||
|
||||
@cython.locals(
|
||||
start_len=cython.ulong,
|
||||
missing_bytes=cython.ulong,
|
||||
data=cython.bytes
|
||||
)
|
||||
cdef _read_to_pos(self, unsigned long pos)
|
||||
cdef _read_stream(self, unsigned int pos, unsigned int missing_bytes)
|
||||
|
||||
cdef _read_to_pos(self, unsigned int pos)
|
||||
|
||||
cpdef read_boolean(self, SignatureType type_)
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import array
|
||||
import errno
|
||||
import io
|
||||
import socket
|
||||
import sys
|
||||
@ -102,6 +103,10 @@ ARRAY = array.array
|
||||
SOL_SOCKET = socket.SOL_SOCKET
|
||||
SCM_RIGHTS = socket.SCM_RIGHTS
|
||||
|
||||
EAGAIN = errno.EAGAIN
|
||||
EWOULDBLOCK = errno.EWOULDBLOCK
|
||||
|
||||
|
||||
HEADER_MESSAGE_ARG_NAME = {
|
||||
1: "path",
|
||||
2: "interface",
|
||||
@ -121,6 +126,8 @@ READER_TYPE = Callable[["Unmarshaller", SignatureType], Any]
|
||||
|
||||
MARSHALL_STREAM_END_ERROR = BlockingIOError
|
||||
|
||||
DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE
|
||||
|
||||
|
||||
def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE:
|
||||
"""Build a parser that unpacks the bytes using the given unpack_from function."""
|
||||
@ -169,6 +176,12 @@ except ImportError:
|
||||
#
|
||||
#
|
||||
class Unmarshaller:
|
||||
"""Unmarshall messages from a stream.
|
||||
|
||||
When calling with sock and _negotiate_unix_fd False, the unmarshaller must
|
||||
be called continuously for each new message as it will buffer the data
|
||||
until a complete message is available.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"_unix_fds",
|
||||
@ -189,9 +202,16 @@ class Unmarshaller:
|
||||
"_uint16_unpack",
|
||||
"_is_native",
|
||||
"_stream_reader",
|
||||
"_negotiate_unix_fd",
|
||||
"_read_complete",
|
||||
)
|
||||
|
||||
def __init__(self, stream: io.BufferedRWPair, sock: Optional[socket.socket] = None):
|
||||
def __init__(
|
||||
self,
|
||||
stream: Optional[io.BufferedRWPair] = None,
|
||||
sock: Optional[socket.socket] = None,
|
||||
negotiate_unix_fd: bool = True,
|
||||
) -> None:
|
||||
self._unix_fds: List[int] = []
|
||||
self._buf = bytearray() # Actual buffer
|
||||
self._stream = stream
|
||||
@ -210,25 +230,24 @@ class Unmarshaller:
|
||||
self._int16_unpack: Optional[Callable] = None
|
||||
self._uint16_unpack: Optional[Callable] = None
|
||||
self._stream_reader: Optional[Callable] = None
|
||||
if self._sock is None:
|
||||
self._negotiate_unix_fd = negotiate_unix_fd
|
||||
self._read_complete = False
|
||||
if stream:
|
||||
if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"):
|
||||
self._stream_reader = stream.reader.read # type: ignore[attr-defined]
|
||||
self._stream_reader = stream.read
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the unmarshaller to its initial state.
|
||||
|
||||
Call this before processing a new message.
|
||||
"""
|
||||
self._reset()
|
||||
|
||||
def _reset(self) -> None:
|
||||
def _next_message(self) -> None:
|
||||
"""Reset the unmarshaller to its initial state.
|
||||
|
||||
Call this before processing a new message.
|
||||
"""
|
||||
self._unix_fds = []
|
||||
self._buf.clear()
|
||||
to_clear = HEADER_SIGNATURE_SIZE + self._msg_len
|
||||
if len(self._buf) == to_clear:
|
||||
self._buf = bytearray()
|
||||
else:
|
||||
del self._buf[:to_clear]
|
||||
self._message = None
|
||||
self._pos = 0
|
||||
self._body_len = 0
|
||||
@ -238,6 +257,7 @@ class Unmarshaller:
|
||||
self._flag = 0
|
||||
self._msg_len = 0
|
||||
self._is_native = 0
|
||||
self._read_complete = False
|
||||
# No need to reset the unpack functions, they are set in _read_header
|
||||
# every time a new message is processed.
|
||||
|
||||
@ -246,12 +266,26 @@ class Unmarshaller:
|
||||
"""Return the message that has been unmarshalled."""
|
||||
return self._message
|
||||
|
||||
def _read_sock(self, length: _int) -> bytes:
|
||||
def _has_another_message_in_buffer(self) -> bool:
    """Report whether the buffer holds bytes beyond the current message.

    The current message accounts for ``HEADER_SIGNATURE_SIZE + self._msg_len``
    bytes of ``self._buf``; any surplus is treated as the start of a
    following message.
    """
    current_message_size = HEADER_SIGNATURE_SIZE + self._msg_len
    return current_message_size < len(self._buf)
|
||||
|
||||
def _read_sock_with_fds(self, pos: _int, missing_bytes: _int) -> None:
|
||||
"""reads from the socket, storing any fds sent and handling errors
|
||||
from the read itself"""
|
||||
from the read itself.
|
||||
|
||||
This function is greedy and will read as much data as possible
|
||||
from the underlying socket.
|
||||
"""
|
||||
# This will raise BlockingIOError if there is no data to read
|
||||
# which we store in the MARSHALL_STREAM_END_ERROR object
|
||||
recv = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
|
||||
try:
|
||||
recv = self._sock.recvmsg(missing_bytes, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
|
||||
except OSError as e:
|
||||
errno = e.errno
|
||||
if errno == EAGAIN or errno == EWOULDBLOCK:
|
||||
raise MARSHALL_STREAM_END_ERROR
|
||||
raise
|
||||
msg = recv[0]
|
||||
ancdata = recv[1]
|
||||
if ancdata:
|
||||
@ -261,8 +295,44 @@ class Unmarshaller:
|
||||
self._unix_fds.extend(
|
||||
ARRAY("i", data[: len(data) - (len(data) % MAX_UNIX_FDS_SIZE)])
|
||||
)
|
||||
if msg == b"":
|
||||
raise EOFError()
|
||||
self._buf += msg
|
||||
if len(self._buf) < pos:
|
||||
raise MARSHALL_STREAM_END_ERROR
|
||||
|
||||
return msg
|
||||
def _read_sock_without_fds(self, pos: _int) -> None:
    """Fill the buffer from the socket until it holds at least ``pos`` bytes.

    Every ``recv`` call asks for up to ``DEFAULT_BUFFER_SIZE`` bytes, so the
    buffer may end up holding more than ``pos`` bytes when this returns.

    :param pos: Number of bytes the buffer must contain before returning.
    :raises MARSHALL_STREAM_END_ERROR: The socket reported ``EAGAIN`` or
        ``EWOULDBLOCK`` before enough data was buffered.
    :raises EOFError: ``recv`` returned an empty bytes object.
    :raises OSError: Any other socket error is re-raised unchanged.
    """
    # Always read at least once; keep reading until ``pos`` bytes are buffered.
    while True:
        try:
            data = self._sock.recv(DEFAULT_BUFFER_SIZE)  # type: ignore[union-attr]
        except OSError as e:
            errno = e.errno
            if errno not in (EAGAIN, EWOULDBLOCK):
                raise
            # Nothing to read right now: signal "resume later" to the caller.
            raise MARSHALL_STREAM_END_ERROR
        if not data:
            raise EOFError()
        self._buf += data
        if len(self._buf) >= pos:
            break
|
||||
|
||||
def _read_stream(self, pos: _int, missing_bytes: _int) -> None:
    """Read from the stream into the internal buffer.

    The data is appended to ``self._buf``; nothing is returned, so the
    return annotation is ``None`` (it was previously declared as ``bytes``).

    :param pos: Number of bytes the buffer must contain after the read.
    :param missing_bytes: Number of bytes requested from the stream reader.
    :raises MARSHALL_STREAM_END_ERROR: The reader returned ``None`` or the
        buffer is still shorter than ``pos`` after the read.
    :raises EOFError: The reader returned an empty bytes object.
    """
    data = self._stream_reader(missing_bytes)  # type: ignore[misc]
    if data is None:
        raise MARSHALL_STREAM_END_ERROR
    if data == b"":
        raise EOFError()
    self._buf += data
    # A short read leaves the message incomplete; the caller resumes later.
    if len(self._buf) < pos:
        raise MARSHALL_STREAM_END_ERROR
|
||||
|
||||
def _read_to_pos(self, pos: _int) -> None:
|
||||
"""
|
||||
@ -277,19 +347,15 @@ class Unmarshaller:
|
||||
:returns:
|
||||
None
|
||||
"""
|
||||
start_len = len(self._buf)
|
||||
missing_bytes = pos - (start_len - self._pos)
|
||||
missing_bytes = pos - len(self._buf)
|
||||
if missing_bytes <= 0:
|
||||
return
|
||||
if self._sock is None:
|
||||
data = self._stream_reader(missing_bytes) # type: ignore[misc]
|
||||
self._read_stream(pos, missing_bytes)
|
||||
elif self._negotiate_unix_fd:
|
||||
self._read_sock_with_fds(pos, missing_bytes)
|
||||
else:
|
||||
data = self._read_sock(missing_bytes)
|
||||
if data == b"":
|
||||
raise EOFError()
|
||||
if data is None:
|
||||
raise MARSHALL_STREAM_END_ERROR
|
||||
self._buf += data
|
||||
if len(data) + start_len != pos:
|
||||
raise MARSHALL_STREAM_END_ERROR
|
||||
self._read_sock_without_fds(pos)
|
||||
|
||||
def read_uint32_unpack(self, type_: _SignatureType) -> int:
|
||||
return self._read_uint32_unpack()
|
||||
@ -641,6 +707,7 @@ class Unmarshaller:
|
||||
validate=False,
|
||||
**header_fields,
|
||||
)
|
||||
self._read_complete = True
|
||||
|
||||
def unmarshall(self) -> Optional[Message]:
|
||||
"""Unmarshall the message.
|
||||
@ -658,6 +725,8 @@ class Unmarshaller:
|
||||
if there are not enough bytes in the buffer. This allows unmarshall
|
||||
to be resumed when more data comes in over the wire.
|
||||
"""
|
||||
if self._read_complete:
|
||||
self._next_message()
|
||||
try:
|
||||
if not self._msg_len:
|
||||
self._read_header()
|
||||
|
||||
@ -201,10 +201,10 @@ class MessageBus(BaseMessageBus):
|
||||
self._loop.add_reader(
|
||||
self._fd,
|
||||
build_message_reader(
|
||||
self._stream,
|
||||
self._sock if self._negotiate_unix_fd else None,
|
||||
self._sock,
|
||||
self._process_message,
|
||||
self._finalize,
|
||||
self._negotiate_unix_fd,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
import io
|
||||
import logging
|
||||
import socket
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
from .._private.unmarshaller import Unmarshaller
|
||||
@ -9,13 +7,13 @@ from ..message import Message
|
||||
|
||||
|
||||
def build_message_reader(
|
||||
stream: io.BufferedRWPair,
|
||||
sock: Optional[socket.socket],
|
||||
process: Callable[[Message], None],
|
||||
finalize: Callable[[Optional[Exception]], None],
|
||||
negotiate_unix_fd: bool,
|
||||
) -> None:
|
||||
"""Build a callable that reads messages from the unmarshaller and passes them to the process function."""
|
||||
unmarshaller = Unmarshaller(stream, sock)
|
||||
unmarshaller = Unmarshaller(None, sock, negotiate_unix_fd)
|
||||
|
||||
def _message_reader() -> None:
|
||||
"""Reads messages from the unmarshaller and passes them to the process function."""
|
||||
@ -28,9 +26,15 @@ def build_message_reader(
|
||||
process(message)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"
|
||||
"Unexpected error processing message: %s", exc_info=True
|
||||
)
|
||||
unmarshaller._reset()
|
||||
# If we are not negotiating unix fds, we can stop reading as soon as
|
||||
# the buffer is empty, as asyncio will call us again when there is more data.
|
||||
if (
|
||||
not negotiate_unix_fd
|
||||
and not unmarshaller._has_another_message_in_buffer()
|
||||
):
|
||||
return
|
||||
except Exception as e:
|
||||
finalize(e)
|
||||
|
||||
|
||||
@ -457,21 +457,18 @@ def test_unmarshall_multiple_messages():
|
||||
unpacked = unpack_variants(message.body)
|
||||
assert unpacked == ["org.bluez.Device1", {"RSSI": -76}, []]
|
||||
|
||||
unmarshaller.reset()
|
||||
assert unmarshaller.unmarshall()
|
||||
message = unmarshaller.message
|
||||
assert message is not None
|
||||
unpacked = unpack_variants(message.body)
|
||||
assert unpacked == ["org.bluez.Device1", {"RSSI": -80}, []]
|
||||
|
||||
unmarshaller.reset()
|
||||
assert unmarshaller.unmarshall()
|
||||
message = unmarshaller.message
|
||||
assert message is not None
|
||||
unpacked = unpack_variants(message.body)
|
||||
assert unpacked == ["org.bluez.Device1", {"RSSI": -94}, []]
|
||||
|
||||
unmarshaller.reset()
|
||||
with pytest.raises(EOFError):
|
||||
unmarshaller.unmarshall()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user