diff --git a/src/dbus_fast/message_bus.pxd b/src/dbus_fast/message_bus.pxd index db7ef4d..01a182f 100644 --- a/src/dbus_fast/message_bus.pxd +++ b/src/dbus_fast/message_bus.pxd @@ -11,6 +11,11 @@ cdef object ServiceInterface cdef object MESSAGE_TYPE_CALL cdef object MESSAGE_TYPE_SIGNAL +cdef class SendReply: + + cdef object _bus + cdef object _msg + cdef class BaseMessageBus: cdef public object unique_name diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index c6b2663..9327bc2 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -29,6 +29,60 @@ MESSAGE_TYPE_CALL = MessageType.METHOD_CALL MESSAGE_TYPE_SIGNAL = MessageType.SIGNAL +class SendReply: + """A context manager to send a reply to a message.""" + + __slots__ = ("_bus", "_msg") + + def __init__(self, bus: "BaseMessageBus", msg: Message) -> None: + """Create a new reply context manager.""" + self._bus = bus + self._msg = msg + + def __enter__(self): + return self + + def __call__(self, reply: Message) -> None: + if self._msg.flags & MessageFlag.NO_REPLY_EXPECTED: + return + + self._bus.send(reply) + + def _exit( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + tb: Optional[TracebackType], + ) -> bool: + if exc_type is None: + return False + + if issubclass(exc_type, DBusError): + self(exc_value._as_message(self._msg)) # type: ignore[union-attr] + return True + + if issubclass(exc_type, Exception): + self( + Message.new_error( + self._msg, + ErrorType.SERVICE_ERROR, + f"The service interface raised an error: {exc_value}.\n{traceback.format_tb(tb)}", + ) + ) + return True + + def __exit__( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + tb: Optional[TracebackType], + ) -> bool: + return self._exit(exc_type, exc_value, tb) + + def send_error(self, exc: Exception) -> None: + self._exit(exc.__class__, exc, exc.__traceback__) + + class BaseMessageBus: """An abstract class to manage a connection to a DBus message bus. @@ -747,55 +801,6 @@ class BaseMessageBus: ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg ) - def _send_reply(self, msg: Message): - bus = self - - class SendReply: - def __enter__(self) -> "SendReply": - return self - - def __call__(self, reply: Message) -> None: - if msg.flags & MessageFlag.NO_REPLY_EXPECTED: - return - - bus.send(reply) - - def _exit( - self, - exc_type: Optional[Type[Exception]], - exc_value: Optional[Exception], - tb: Optional[TracebackType], - ) -> bool: - if exc_type is None: - return False - - if issubclass(exc_type, DBusError): - self(exc_value._as_message(msg)) # type: ignore[union-attr] - return True - - if issubclass(exc_type, Exception): - self( - Message.new_error( - msg, - ErrorType.SERVICE_ERROR, - f"The service interface raised an error: {exc_value}.\n{traceback.format_tb(tb)}", - ) - ) - return True - - def __exit__( - self, - exc_type: Optional[Type[Exception]], - exc_value: Optional[Exception], - tb: Optional[TracebackType], - ) -> bool: - return self._exit(exc_type, exc_value, tb) - - def send_error(self, exc: Exception) -> None: - self._exit(exc.__class__, exc, exc.__traceback__) - - return SendReply() - def _process_message(self, msg) -> None: handled = False for user_handler in self._user_message_handlers: @@ -848,7 +853,7 @@ class BaseMessageBus: if not handled: handler = self._find_message_handler(msg) - send_reply = self._send_reply(msg) + send_reply = SendReply(self, msg) with send_reply: if handler: