diff --git a/src/dbus_fast/aio/message_bus.py b/src/dbus_fast/aio/message_bus.py index 9c9a0d3..d9e7614 100644 --- a/src/dbus_fast/aio/message_bus.py +++ b/src/dbus_fast/aio/message_bus.py @@ -45,15 +45,16 @@ class _MessageWriter: self.fd = bus._fd self.offset = 0 self.unix_fds = None - self.fut = None + self.fut: Optional[asyncio.Future] = None - def write_callback(self): + def write_callback(self, remove_writer: bool = True) -> None: try: while True: if self.buf is None: if self.messages.qsize() == 0: # nothing more to write - self.loop.remove_writer(self.fd) + if remove_writer: + self.loop.remove_writer(self.fd) return buf, unix_fds, fut = self.messages.get_nowait() self.unix_fds = unix_fds @@ -97,12 +98,28 @@ class _MessageWriter: ) ) + def _write_without_remove_writer(self): + """Call the write callback without removing the writer.""" + self.write_callback(remove_writer=False) + def schedule_write(self, msg: Message = None, future=None): + queue_is_empty = self.messages.qsize() == 0 if msg is not None: self.buffer_message(msg, future) if self.bus.unique_name: - # don't run the writer until the bus is ready to send messages - self.loop.add_writer(self.fd, self.write_callback) + # Optimization: try to send now if the queue + # is empty. With bleak this usually means we + # can send right away 99% of the time which + # is a huge improvement in latency. + if queue_is_empty: + self._write_without_remove_writer() + if ( + self.buf is not None + or self.messages.qsize() != 0 + or not self.fut + or not self.fut.done() + ): + self.loop.add_writer(self.fd, self.write_callback) class MessageBus(BaseMessageBus): diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index 287d0ba..a81ece7 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -646,12 +646,19 @@ class BaseMessageBus: self._name_owners[msg.destination] = reply.sender callback(reply, err) + no_reply_expected = msg.flags & MessageFlag.NO_REPLY_EXPECTED + + # Make sure the return reply handler is installed + # before sending the message to avoid a race condition + # where the reply is lost in case the backend can + # send it right away. + if not no_reply_expected: + self._method_return_handlers[msg.serial] = reply_notify + self.send(msg) - if msg.flags & MessageFlag.NO_REPLY_EXPECTED: + if no_reply_expected: callback(None, None) - else: - self._method_return_handlers[msg.serial] = reply_notify @staticmethod def _check_callback_type(callback): diff --git a/tests/test_disconnect.py b/tests/test_disconnect.py index 80f92a0..5f32f6b 100644 --- a/tests/test_disconnect.py +++ b/tests/test_disconnect.py @@ -1,5 +1,6 @@ import functools import os +from unittest.mock import patch import pytest @@ -16,19 +17,20 @@ async def test_bus_disconnect_before_reply(event_loop): await bus.connect() assert bus.connected - ping = bus.call( - Message( - destination="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - interface="org.freedesktop.DBus", - member="Ping", + with patch.object(bus._writer, "_write_without_remove_writer"): + ping = bus.call( + Message( + destination="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + interface="org.freedesktop.DBus", + member="Ping", + ) ) - ) - event_loop.call_soon(bus.disconnect) + event_loop.call_soon(bus.disconnect) - with pytest.raises((EOFError, BrokenPipeError)): - await ping + with pytest.raises((EOFError, BrokenPipeError)): + await ping assert bus._disconnected assert not bus.connected @@ -42,22 +44,23 @@ async def test_unexpected_disconnect(event_loop): await bus.connect() assert bus.connected - ping = bus.call( - Message( - destination="org.freedesktop.DBus", - path="/org/freedesktop/DBus", - interface="org.freedesktop.DBus", - member="Ping", + with patch.object(bus._writer, "_write_without_remove_writer"): + ping = bus.call( + Message( + destination="org.freedesktop.DBus", + path="/org/freedesktop/DBus", + interface="org.freedesktop.DBus", + member="Ping", + ) ) - ) - event_loop.call_soon(functools.partial(os.close, bus._fd)) + event_loop.call_soon(functools.partial(os.close, bus._fd)) - with pytest.raises(OSError): - await ping + with pytest.raises(OSError): + await ping - assert bus._disconnected - assert not bus.connected + assert bus._disconnected + assert not bus.connected with pytest.raises(OSError): await bus.wait_for_disconnect()