fix: ensure the underlying socket is closed on disconnect (#12)

This commit is contained in:
J. Nick Koston 2022-09-09 13:41:08 -05:00 committed by GitHub
parent 2355fa1643
commit 6770a656bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ import logging
import socket import socket
from asyncio import Queue from asyncio import Queue
from copy import copy from copy import copy
from typing import Optional from typing import Any, Optional
from .. import introspection as intr from .. import introspection as intr
from .._private.unmarshaller import Unmarshaller from .._private.unmarshaller import Unmarshaller
@ -24,18 +24,18 @@ from ..service import ServiceInterface
from .proxy_object import ProxyObject from .proxy_object import ProxyObject
def _future_set_exception(fut, exc): def _future_set_exception(fut: asyncio.Future, exc: Exception) -> None:
if fut is not None and not fut.done(): if fut is not None and not fut.done():
fut.set_exception(exc) fut.set_exception(exc)
def _future_set_result(fut, result): def _future_set_result(fut: asyncio.Future, result: Any) -> None:
if fut is not None and not fut.done(): if fut is not None and not fut.done():
fut.set_result(result) fut.set_result(result)
class _MessageWriter: class _MessageWriter:
def __init__(self, bus): def __init__(self, bus: "MessageBus") -> None:
self.messages = Queue() self.messages = Queue()
self.negotiate_unix_fd = bus._negotiate_unix_fd self.negotiate_unix_fd = bus._negotiate_unix_fd
self.bus = bus self.bus = bus
@ -82,7 +82,10 @@ class _MessageWriter:
# wait for writable # wait for writable
return return
except Exception as e: except Exception as e:
_future_set_exception(self.fut, e) if self.bus._user_disconnect:
_future_set_result(self.fut, None)
else:
_future_set_exception(self.fut, e)
self.bus._finalize(e) self.bus._finalize(e)
def buffer_message(self, msg: Message, future=None): def buffer_message(self, msg: Message, future=None):
@ -229,7 +232,7 @@ class MessageBus(BaseMessageBus):
""" """
future = self._loop.create_future() future = self._loop.create_future()
def reply_handler(reply, err): def reply_handler(reply: Any, err: Exception) -> None:
if err: if err:
_future_set_exception(future, err) _future_set_exception(future, err)
else: else:
@ -436,13 +439,24 @@ class MessageBus(BaseMessageBus):
if response == "BEGIN": if response == "BEGIN":
break break
def _create_unmarshaller(self): def disconnect(self):
"""Disconnect the message bus by closing the underlying connection asynchronously.
All pending and future calls will error with a connection error.
"""
super().disconnect()
try:
self._sock.close()
except Exception:
logging.warning("could not close socket", exc_info=True)
def _create_unmarshaller(self) -> Unmarshaller:
sock = None sock = None
if self._negotiate_unix_fd: if self._negotiate_unix_fd:
sock = self._sock sock = self._sock
return Unmarshaller(self._stream, sock) return Unmarshaller(self._stream, sock)
def _finalize(self, err=None): def _finalize(self, err: Optional[Exception] = None) -> None:
try: try:
self._loop.remove_reader(self._fd) self._loop.remove_reader(self._fd)
except Exception: except Exception: