fix: ensure the underlying socket is closed on disconnect (#12)

This commit is contained in:
J. Nick Koston 2022-09-09 13:41:08 -05:00 committed by GitHub
parent 2355fa1643
commit 6770a656bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ import logging
import socket
from asyncio import Queue
from copy import copy
from typing import Optional
from typing import Any, Optional
from .. import introspection as intr
from .._private.unmarshaller import Unmarshaller
@ -24,18 +24,18 @@ from ..service import ServiceInterface
from .proxy_object import ProxyObject
def _future_set_exception(fut, exc):
def _future_set_exception(fut: asyncio.Future, exc: Exception) -> None:
if fut is not None and not fut.done():
fut.set_exception(exc)
def _future_set_result(fut, result):
def _future_set_result(fut: asyncio.Future, result: Any) -> None:
if fut is not None and not fut.done():
fut.set_result(result)
class _MessageWriter:
def __init__(self, bus):
def __init__(self, bus: "MessageBus") -> None:
self.messages = Queue()
self.negotiate_unix_fd = bus._negotiate_unix_fd
self.bus = bus
@ -82,7 +82,10 @@ class _MessageWriter:
# wait for writable
return
except Exception as e:
_future_set_exception(self.fut, e)
if self.bus._user_disconnect:
_future_set_result(self.fut, None)
else:
_future_set_exception(self.fut, e)
self.bus._finalize(e)
def buffer_message(self, msg: Message, future=None):
@ -229,7 +232,7 @@ class MessageBus(BaseMessageBus):
"""
future = self._loop.create_future()
def reply_handler(reply, err):
def reply_handler(reply: Any, err: Exception) -> None:
if err:
_future_set_exception(future, err)
else:
@ -436,13 +439,24 @@ class MessageBus(BaseMessageBus):
if response == "BEGIN":
break
def _create_unmarshaller(self):
def disconnect(self):
"""Disconnect the message bus by closing the underlying connection asynchronously.
All pending and future calls will error with a connection error.
"""
super().disconnect()
try:
self._sock.close()
except Exception:
logging.warning("could not close socket", exc_info=True)
def _create_unmarshaller(self) -> Unmarshaller:
sock = None
if self._negotiate_unix_fd:
sock = self._sock
return Unmarshaller(self._stream, sock)
def _finalize(self, err=None):
def _finalize(self, err: Optional[Exception] = None) -> None:
try:
self._loop.remove_reader(self._fd)
except Exception: