fix: ensure proxy object tasks do not get garbage collected prematurely (#409)
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
@@ -6,7 +8,7 @@ import xml.etree.ElementTree as ET
|
|||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable
|
||||||
|
|
||||||
from . import introspection as intr
|
from . import introspection as intr
|
||||||
from . import message_bus
|
from . import message_bus
|
||||||
@@ -22,7 +24,7 @@ from .validators import assert_bus_name_valid, assert_object_path_valid
|
|||||||
class SignalHandler:
|
class SignalHandler:
|
||||||
"""Signal handler."""
|
"""Signal handler."""
|
||||||
|
|
||||||
fn: Callable
|
fn: Callable | Coroutine
|
||||||
unpack_variants: bool
|
unpack_variants: bool
|
||||||
|
|
||||||
|
|
||||||
@@ -57,7 +59,7 @@ class BaseProxyInterface:
|
|||||||
bus_name: str,
|
bus_name: str,
|
||||||
path: str,
|
path: str,
|
||||||
introspection: intr.Interface,
|
introspection: intr.Interface,
|
||||||
bus: "message_bus.BaseMessageBus",
|
bus: message_bus.BaseMessageBus,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bus_name = bus_name
|
self.bus_name = bus_name
|
||||||
self.path = path
|
self.path = path
|
||||||
@@ -65,6 +67,7 @@ class BaseProxyInterface:
|
|||||||
self.bus = bus
|
self.bus = bus
|
||||||
self._signal_handlers: dict[str, list[SignalHandler]] = {}
|
self._signal_handlers: dict[str, list[SignalHandler]] = {}
|
||||||
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"
|
self._signal_match_rule = f"type='signal',sender={bus_name},interface={introspection.name},path={path}"
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
|
_underscorer1 = re.compile(r"(.)([A-Z][a-z]+)")
|
||||||
_underscorer2 = re.compile(r"([a-z0-9])([A-Z])")
|
_underscorer2 = re.compile(r"([a-z0-9])([A-Z])")
|
||||||
@@ -76,7 +79,7 @@ class BaseProxyInterface:
|
|||||||
return BaseProxyInterface._underscorer2.sub(r"\1_\2", subbed).lower()
|
return BaseProxyInterface._underscorer2.sub(r"\1_\2", subbed).lower()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_method_return(msg: Message, signature: Optional[str] = None):
|
def _check_method_return(msg: Message, signature: str | None = None):
|
||||||
if msg.message_type == MessageType.ERROR:
|
if msg.message_type == MessageType.ERROR:
|
||||||
raise DBusError._from_message(msg)
|
raise DBusError._from_message(msg)
|
||||||
if msg.message_type != MessageType.METHOD_RETURN:
|
if msg.message_type != MessageType.METHOD_RETURN:
|
||||||
@@ -137,10 +140,14 @@ class BaseProxyInterface:
|
|||||||
|
|
||||||
cb_result = handler.fn(*data)
|
cb_result = handler.fn(*data)
|
||||||
if isinstance(cb_result, Coroutine):
|
if isinstance(cb_result, Coroutine):
|
||||||
asyncio.create_task(cb_result) # noqa: RUF006
|
# Save a strong reference to the task so it doesn't get garbage
|
||||||
|
# collected before it finishes.
|
||||||
|
task = asyncio.create_task(cb_result)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._background_tasks.remove)
|
||||||
|
|
||||||
def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None:
|
def _add_signal(self, intr_signal: intr.Signal, interface: intr.Interface) -> None:
|
||||||
def on_signal_fn(fn: Callable, *, unpack_variants: bool = False):
|
def on_signal_fn(fn: Callable | Coroutine, *, unpack_variants: bool = False):
|
||||||
fn_signature = inspect.signature(fn)
|
fn_signature = inspect.signature(fn)
|
||||||
if (
|
if (
|
||||||
len(
|
len(
|
||||||
@@ -182,7 +189,9 @@ class BaseProxyInterface:
|
|||||||
SignalHandler(fn, unpack_variants)
|
SignalHandler(fn, unpack_variants)
|
||||||
)
|
)
|
||||||
|
|
||||||
def off_signal_fn(fn: Callable, *, unpack_variants: bool = False) -> None:
|
def off_signal_fn(
|
||||||
|
fn: Callable | Coroutine, *, unpack_variants: bool = False
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
i = self._signal_handlers[intr_signal.name].index(
|
i = self._signal_handlers[intr_signal.name].index(
|
||||||
SignalHandler(fn, unpack_variants)
|
SignalHandler(fn, unpack_variants)
|
||||||
@@ -241,8 +250,8 @@ class BaseProxyObject:
|
|||||||
self,
|
self,
|
||||||
bus_name: str,
|
bus_name: str,
|
||||||
path: str,
|
path: str,
|
||||||
introspection: Union[intr.Node, str, ET.Element],
|
introspection: intr.Node | str | ET.Element,
|
||||||
bus: "message_bus.BaseMessageBus",
|
bus: message_bus.BaseMessageBus,
|
||||||
ProxyInterface: type[BaseProxyInterface],
|
ProxyInterface: type[BaseProxyInterface],
|
||||||
) -> None:
|
) -> None:
|
||||||
assert_object_path_valid(path)
|
assert_object_path_valid(path)
|
||||||
@@ -305,7 +314,7 @@ class BaseProxyObject:
|
|||||||
for intr_signal in intr_interface.signals:
|
for intr_signal in intr_interface.signals:
|
||||||
interface._add_signal(intr_signal, interface)
|
interface._add_signal(intr_signal, interface)
|
||||||
|
|
||||||
def get_owner_notify(msg: Message, err: Optional[Exception]) -> None:
|
def get_owner_notify(msg: Message, err: Exception | None) -> None:
|
||||||
if err:
|
if err:
|
||||||
logging.error(f'getting name owner for "{name}" failed, {err}')
|
logging.error(f'getting name owner for "{name}" failed, {err}')
|
||||||
return
|
return
|
||||||
@@ -334,7 +343,7 @@ class BaseProxyObject:
|
|||||||
self._interfaces[name] = interface
|
self._interfaces[name] = interface
|
||||||
return interface
|
return interface
|
||||||
|
|
||||||
def get_children(self) -> list["BaseProxyObject"]:
|
def get_children(self) -> list[BaseProxyObject]:
|
||||||
"""Get the child nodes of this proxy object according to the introspection data."""
|
"""Get the child nodes of this proxy object according to the introspection data."""
|
||||||
if self._children is None:
|
if self._children is None:
|
||||||
self._children = [
|
self._children = [
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from dbus_fast import Message
|
from dbus_fast import Message
|
||||||
@@ -359,6 +361,76 @@ async def test_kwargs_callback():
|
|||||||
bus2.disconnect()
|
bus2.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_coro_callback():
|
||||||
|
"""Test callback for signal with a coroutine."""
|
||||||
|
bus1 = await MessageBus().connect()
|
||||||
|
bus2 = await MessageBus().connect()
|
||||||
|
|
||||||
|
await bus1.request_name("test.signals.name")
|
||||||
|
service_interface = ExampleInterface()
|
||||||
|
bus1.export("/test/path", service_interface)
|
||||||
|
|
||||||
|
obj = bus2.get_proxy_object(
|
||||||
|
"test.signals.name", "/test/path", bus1._introspect_export_path("/test/path")
|
||||||
|
)
|
||||||
|
interface = obj.get_interface(service_interface.name)
|
||||||
|
|
||||||
|
async def ping():
|
||||||
|
await bus2.call(
|
||||||
|
Message(
|
||||||
|
destination=bus1.unique_name,
|
||||||
|
interface="org.freedesktop.DBus.Peer",
|
||||||
|
path="/test/path",
|
||||||
|
member="Ping",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs_handler_counter = 0
|
||||||
|
kwargs_handler_err = None
|
||||||
|
kwarg_default_handler_counter = 0
|
||||||
|
kwarg_default_handler_err = None
|
||||||
|
|
||||||
|
async def kwargs_handler(value, **_):
|
||||||
|
nonlocal kwargs_handler_counter
|
||||||
|
nonlocal kwargs_handler_err
|
||||||
|
try:
|
||||||
|
assert value == "hello"
|
||||||
|
kwargs_handler_counter += 1
|
||||||
|
except AssertionError as ex:
|
||||||
|
kwargs_handler_err = ex
|
||||||
|
|
||||||
|
async def kwarg_default_handler(value, *, _=True):
|
||||||
|
nonlocal kwarg_default_handler_counter
|
||||||
|
nonlocal kwarg_default_handler_err
|
||||||
|
try:
|
||||||
|
assert value == "hello"
|
||||||
|
kwarg_default_handler_counter += 1
|
||||||
|
except AssertionError as ex:
|
||||||
|
kwarg_default_handler_err = ex
|
||||||
|
|
||||||
|
interface.on_some_signal(kwargs_handler)
|
||||||
|
interface.on_some_signal(kwarg_default_handler)
|
||||||
|
await ping()
|
||||||
|
|
||||||
|
service_interface.SomeSignal()
|
||||||
|
await ping()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert kwargs_handler_err is None
|
||||||
|
assert kwargs_handler_counter == 1
|
||||||
|
assert kwarg_default_handler_err is None
|
||||||
|
assert kwarg_default_handler_counter == 1
|
||||||
|
|
||||||
|
def kwarg_bad_handler(value, *, bad_kwarg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
interface.on_some_signal(kwarg_bad_handler)
|
||||||
|
|
||||||
|
bus1.disconnect()
|
||||||
|
bus2.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_signal_type_error():
|
async def test_on_signal_type_error():
|
||||||
"""Test on callback raises type errors for invalid callbacks."""
|
"""Test on callback raises type errors for invalid callbacks."""
|
||||||
|
|||||||
Reference in New Issue
Block a user