fix: avoid cythonizing SendReply (#232)

This commit is contained in:
J. Nick Koston 2023-08-24 09:28:53 -05:00 committed by GitHub
parent ed5c87f492
commit d12266ddef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 63 additions and 65 deletions

View File

@ -18,11 +18,6 @@ cdef object assert_bus_name_valid
cdef _expects_reply(Message msg) cdef _expects_reply(Message msg)
cdef class SendReply:
cdef object _bus
cdef object _msg
cdef class BaseMessageBus: cdef class BaseMessageBus:
cdef public object unique_name cdef public object unique_name

View File

@ -3,7 +3,6 @@ import logging
import socket import socket
import traceback import traceback
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
from . import introspection as intr from . import introspection as intr
@ -21,6 +20,7 @@ from .constants import (
from .errors import DBusError, InvalidAddressError from .errors import DBusError, InvalidAddressError
from .message import Message from .message import Message
from .proxy_object import BaseProxyObject from .proxy_object import BaseProxyObject
from .send_reply import SendReply
from .service import ServiceInterface, _Method from .service import ServiceInterface, _Method
from .signature import Variant from .signature import Variant
from .validators import assert_bus_name_valid, assert_object_path_valid from .validators import assert_bus_name_valid, assert_object_path_valid
@ -57,55 +57,6 @@ def _block_unexpected_reply(reply: _Message) -> None:
BLOCK_UNEXPECTED_REPLY = _block_unexpected_reply BLOCK_UNEXPECTED_REPLY = _block_unexpected_reply
class SendReply:
"""A context manager to send a reply to a message."""
__slots__ = ("_bus", "_msg")
def __init__(self, bus: "BaseMessageBus", msg: Message) -> None:
"""Create a new reply context manager."""
self._bus = bus
self._msg = msg
def __enter__(self):
return self
def __call__(self, reply: Message) -> None:
self._bus.send(reply)
def _exit(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> bool:
if exc_value:
if isinstance(exc_value, DBusError):
self(exc_value._as_message(self._msg))
else:
self(
Message.new_error(
self._msg,
ErrorType.SERVICE_ERROR,
f"The service interface raised an error: {exc_value}.\n{traceback.format_tb(tb)}",
)
)
return True
return False
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> bool:
return self._exit(exc_type, exc_value, tb)
def send_error(self, exc: Exception) -> None:
self._exit(exc.__class__, exc, exc.__traceback__)
class BaseMessageBus: class BaseMessageBus:
"""An abstract class to manage a connection to a DBus message bus. """An abstract class to manage a connection to a DBus message bus.

View File

@ -0,0 +1,59 @@
import traceback
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Type
from .constants import ErrorType
from .errors import DBusError
from .message import Message
if TYPE_CHECKING:
from .message_bus import BaseMessageBus
class SendReply:
"""A context manager to send a reply to a message."""
__slots__ = ("_bus", "_msg")
def __init__(self, bus: "BaseMessageBus", msg: Message) -> None:
"""Create a new reply context manager."""
self._bus = bus
self._msg = msg
def __enter__(self):
return self
def __call__(self, reply: Message) -> None:
self._bus.send(reply)
def _exit(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> bool:
if exc_value:
if isinstance(exc_value, DBusError):
self(exc_value._as_message(self._msg))
else:
self(
Message.new_error(
self._msg,
ErrorType.SERVICE_ERROR,
f"The service interface raised an error: {exc_value}.\n{traceback.format_tb(tb)}",
)
)
return True
return False
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
tb: Optional[TracebackType],
) -> bool:
return self._exit(exc_type, exc_value, tb)
def send_error(self, exc: Exception) -> None:
self._exit(exc.__class__, exc, exc.__traceback__)

View File

@ -240,7 +240,6 @@ async def test_standard_interface_properties():
"org.freedesktop.DBus.Peer", "org.freedesktop.DBus.Peer",
"org.freedesktop.DBus.ObjectManager", "org.freedesktop.DBus.ObjectManager",
]: ]:
result = await bus2.call( result = await bus2.call(
Message( Message(
destination=bus1.unique_name, destination=bus1.unique_name,

View File

@ -2,7 +2,6 @@ from dbus_fast._private.address import parse_address
def test_valid_addresses(): def test_valid_addresses():
valid_addresses = { valid_addresses = {
"unix:path=/run/user/1000/bus": [("unix", {"path": "/run/user/1000/bus"})], "unix:path=/run/user/1000/bus": [("unix", {"path": "/run/user/1000/bus"})],
"unix:abstract=/tmp/dbus-ft9sODWpZk,guid=a7b1d5912379c2d471165e9b5cb74a03": [ "unix:abstract=/tmp/dbus-ft9sODWpZk,guid=a7b1d5912379c2d471165e9b5cb74a03": [

View File

@ -104,7 +104,6 @@ def test_unmarshalling_with_table(unmarshall_table):
from dbus_fast._private import unmarshaller from dbus_fast._private import unmarshaller
for item in unmarshall_table: for item in unmarshall_table:
stream = io.BytesIO(bytes.fromhex(item["data"])) stream = io.BytesIO(bytes.fromhex(item["data"]))
unmarshaller = Unmarshaller(stream) unmarshaller = Unmarshaller(stream)
try: try:
@ -486,7 +485,6 @@ def tests_fallback_no_cython():
def test_unmarshall_large_message(): def test_unmarshall_large_message():
stream = io.BytesIO(bytes.fromhex(get_managed_objects_msg)) stream = io.BytesIO(bytes.fromhex(get_managed_objects_msg))
unmarshaller = Unmarshaller(stream) unmarshaller = Unmarshaller(stream)
unmarshaller.unmarshall() unmarshaller.unmarshall()

View File

@ -1,15 +1,12 @@
import os import os
import traceback
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from unittest.mock import Mock
import pytest import pytest
from dbus_fast.constants import ErrorType, MessageType from dbus_fast.constants import ErrorType, MessageType
from dbus_fast.errors import DBusError from dbus_fast.errors import DBusError
from dbus_fast.message import Message from dbus_fast.message import Message
from dbus_fast.message_bus import BaseMessageBus, SendReply from dbus_fast.message_bus import BaseMessageBus
from dbus_fast.send_reply import SendReply
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -44,7 +41,7 @@ def test_send_reply_exception() -> None:
) )
send_reply = SendReply(mock_message_bus, mock_message) send_reply = SendReply(mock_message_bus, mock_message)
with send_reply as reply: with send_reply:
raise DBusError(ErrorType.DISCONNECTED, "Disconnected", None) raise DBusError(ErrorType.DISCONNECTED, "Disconnected", None)
assert len(messages) == 1 assert len(messages) == 1