diff --git a/src/dbus_fast/message_bus.py b/src/dbus_fast/message_bus.py index e54728d..7a8d733 100644 --- a/src/dbus_fast/message_bus.py +++ b/src/dbus_fast/message_bus.py @@ -878,7 +878,10 @@ class BaseMessageBus: ) -> None: """This is the callback that will be called when a method call is.""" args = ServiceInterface._c_msg_body_to_args(msg) if msg.unix_fds else msg.body - result = method.fn(interface, *args) + kwargs = {} + if method.caller_argument is not None: + kwargs = {method.caller_argument: msg.sender} + result = method.fn(interface, *args, **kwargs) if send_reply is BLOCK_UNEXPECTED_REPLY or _expects_reply(msg) is False: return body_fds = ServiceInterface._c_fn_result_to_body( diff --git a/src/dbus_fast/service.pxd b/src/dbus_fast/service.pxd index f1903d6..a376f30 100644 --- a/src/dbus_fast/service.pxd +++ b/src/dbus_fast/service.pxd @@ -16,6 +16,7 @@ cdef class _Method: cdef public str out_signature cdef public SignatureTree in_signature_tree cdef public SignatureTree out_signature_tree + cdef public str caller_argument diff --git a/src/dbus_fast/service.py b/src/dbus_fast/service.py index 0b65b96..fbd829a 100644 --- a/src/dbus_fast/service.py +++ b/src/dbus_fast/service.py @@ -71,6 +71,7 @@ class _Method: for type_ in get_signature_tree(out_signature).types: out_args.append(intr.Arg(type_, intr.ArgDirection.OUT)) + self.caller_argument = caller_arg self.name = name self.fn = fn self.disabled = disabled @@ -79,7 +80,6 @@ class _Method: self.out_signature = out_signature self.in_signature_tree = get_signature_tree(in_signature) self.out_signature_tree = get_signature_tree(out_signature) - self.caller_argument = caller_arg def method( @@ -123,7 +123,7 @@ def method( raise TypeError("name must be a string") if type(disabled) is not bool: raise TypeError("disabled must be a bool") - if type(inject_caller) is not bool or type(caller_arg) is not str: + if type(inject_caller) is not bool and type(inject_caller) is not str: raise TypeError("inject_caller must be a string or bool") caller_arg = None diff --git a/tests/service/test_methods.py b/tests/service/test_methods.py index 9f7d719..ccbadab 100644 --- a/tests/service/test_methods.py +++ b/tests/service/test_methods.py @@ -60,6 +60,16 @@ class ExampleInterface(ServiceInterface): assert type(self) is ExampleInterface raise DBusError("test.error", "an error occurred") + @method(inject_caller=True) + def echo_caller(self, caller: str | None = None) -> "s": + assert type(self) is ExampleInterface + return caller + + @method(inject_caller="source") + def echo_source(self, source: str | None = None) -> "s": + assert type(self) is ExampleInterface + return source + class AsyncInterface(ServiceInterface): def __init__(self, name): @@ -108,6 +118,16 @@ class AsyncInterface(ServiceInterface): assert type(self) is AsyncInterface raise DBusError("test.error", "an error occurred") + @method(inject_caller=True) + async def echo_caller(self, caller: str | None = None) -> "s": + assert type(self) is AsyncInterface + return caller + + @method(inject_caller="source") + async def echo_source(self, source: str | None = None) -> "s": + assert type(self) is AsyncInterface + return source + @pytest.mark.parametrize("interface_class", [ExampleInterface, AsyncInterface]) @pytest.mark.asyncio @@ -216,6 +236,14 @@ async def test_methods(interface_class): 'test.interface.does_not_exist with signature "" could not be found' ] + reply = await call("echo_caller") + assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0] + assert reply.body[0] == bus2.unique_name + + reply = await call("echo_source") + assert reply.message_type == MessageType.METHOD_RETURN, reply.body[0] + assert reply.body[0] == bus2.unique_name + bus1.disconnect() bus2.disconnect() bus1._sock.close()