fix: more typing fixes (#40)

This commit is contained in:
J. Nick Koston
2022-09-27 09:47:40 -10:00
committed by GitHub
parent b26c5f8e28
commit a6b9581d62
2 changed files with 43 additions and 29 deletions

View File

@@ -91,7 +91,7 @@ class _MessageWriter:
_future_set_exception(self.fut, e) _future_set_exception(self.fut, e)
self.bus._finalize(e) self.bus._finalize(e)
def buffer_message(self, msg: Message, future=None): def buffer_message(self, msg: Message, future=None) -> None:
self.messages.append( self.messages.append(
( (
msg._marshall(negotiate_unix_fd=self.negotiate_unix_fd), msg._marshall(negotiate_unix_fd=self.negotiate_unix_fd),
@@ -100,11 +100,11 @@ class _MessageWriter:
) )
) )
def _write_without_remove_writer(self): def _write_without_remove_writer(self) -> None:
"""Call the write callback without removing the writer.""" """Call the write callback without removing the writer."""
self.write_callback(remove_writer=False) self.write_callback(remove_writer=False)
def schedule_write(self, msg: Message = None, future=None): def schedule_write(self, msg: Message = None, future=None) -> None:
queue_is_empty = not self.messages queue_is_empty = not self.messages
if msg is not None: if msg is not None:
self.buffer_message(msg, future) self.buffer_message(msg, future)
@@ -358,7 +358,7 @@ class MessageBus(BaseMessageBus):
return future.result() return future.result()
def send(self, msg: Message): def send(self, msg: Message) -> asyncio.Future:
"""Asynchronously send a message on the message bus. """Asynchronously send a message on the message bus.
.. note:: This method may change to a couroutine function in the 1.0 .. note:: This method may change to a couroutine function in the 1.0
@@ -418,7 +418,7 @@ class MessageBus(BaseMessageBus):
return handler return handler
def _message_reader(self): def _message_reader(self) -> None:
try: try:
while True: while True:
if self._unmarshaller.unmarshall(): if self._unmarshaller.unmarshall():
@@ -429,13 +429,13 @@ class MessageBus(BaseMessageBus):
except Exception as e: except Exception as e:
self._finalize(e) self._finalize(e)
async def _auth_readline(self): async def _auth_readline(self) -> str:
buf = b"" buf = b""
while buf[-2:] != b"\r\n": while buf[-2:] != b"\r\n":
buf += await self._loop.sock_recv(self._sock, 2) buf += await self._loop.sock_recv(self._sock, 2)
return buf[:-2].decode() return buf[:-2].decode()
async def _authenticate(self): async def _authenticate(self) -> None:
await self._loop.sock_sendall(self._sock, b"\0") await self._loop.sock_sendall(self._sock, b"\0")
first_line = self._auth._authentication_start( first_line = self._auth._authentication_start(
@@ -459,7 +459,7 @@ class MessageBus(BaseMessageBus):
if response == "BEGIN": if response == "BEGIN":
break break
def disconnect(self): def disconnect(self) -> None:
"""Disconnect the message bus by closing the underlying connection asynchronously. """Disconnect the message bus by closing the underlying connection asynchronously.
All pending and future calls will error with a connection error. All pending and future calls will error with a connection error.

View File

@@ -60,7 +60,7 @@ class BaseMessageBus:
bus_address: Optional[str] = None, bus_address: Optional[str] = None,
bus_type: BusType = BusType.SESSION, bus_type: BusType = BusType.SESSION,
ProxyObject: Optional[Type[BaseProxyObject]] = None, ProxyObject: Optional[Type[BaseProxyObject]] = None,
): ) -> None:
self.unique_name = None self.unique_name = None
self._disconnected = False self._disconnected = False
@@ -98,12 +98,12 @@ class BaseMessageBus:
self._setup_socket() self._setup_socket()
@property @property
def connected(self): def connected(self) -> bool:
if self.unique_name is None or self._disconnected or self._user_disconnect: if self.unique_name is None or self._disconnected or self._user_disconnect:
return False return False
return True return True
def export(self, path: str, interface: ServiceInterface): def export(self, path: str, interface: ServiceInterface) -> None:
"""Export the service interface on this message bus to make it available """Export the service interface on this message bus to make it available
to other clients. to other clients.
@@ -136,7 +136,7 @@ class BaseMessageBus:
def unexport( def unexport(
self, path: str, interface: Optional[Union[ServiceInterface, str]] = None self, path: str, interface: Optional[Union[ServiceInterface, str]] = None
): ) -> None:
"""Unexport the path or service interface to make it no longer """Unexport the path or service interface to make it no longer
available to clients. available to clients.
@@ -190,7 +190,7 @@ class BaseMessageBus:
bus_name: str, bus_name: str,
path: str, path: str,
callback: Callable[[Optional[intr.Node], Optional[Exception]], None], callback: Callable[[Optional[intr.Node], Optional[Exception]], None],
): ) -> None:
"""Get introspection data for the node at the given path from the given """Get introspection data for the node at the given path from the given
bus name. bus name.
@@ -231,7 +231,7 @@ class BaseMessageBus:
reply_notify, reply_notify,
) )
def _emit_interface_added(self, path, interface): def _emit_interface_added(self, path: str, interface: ServiceInterface) -> None:
"""Emit the ``org.freedesktop.DBus.ObjectManager.InterfacesAdded`` signal. """Emit the ``org.freedesktop.DBus.ObjectManager.InterfacesAdded`` signal.
This signal is intended to be used to alert clients when This signal is intended to be used to alert clients when
@@ -427,7 +427,7 @@ class BaseMessageBus:
return self._ProxyObject(bus_name, path, introspection, self) return self._ProxyObject(bus_name, path, introspection, self)
def disconnect(self): def disconnect(self) -> None:
"""Disconnect the message bus by closing the underlying connection asynchronously. """Disconnect the message bus by closing the underlying connection asynchronously.
All pending and future calls will error with a connection error. All pending and future calls will error with a connection error.
@@ -634,7 +634,9 @@ class BaseMessageBus:
if err: if err:
raise err raise err
def _call(self, msg, callback, check_callback: bool = True) -> None: def _call(
self, msg: Message, callback: Callable, check_callback: bool = True
) -> None:
if check_callback: if check_callback:
BaseMessageBus._check_callback_type(callback) BaseMessageBus._check_callback_type(callback)
@@ -661,7 +663,7 @@ class BaseMessageBus:
callback(None, None) callback(None, None)
@staticmethod @staticmethod
def _check_callback_type(callback): def _check_callback_type(callback: Callable) -> None:
"""Raise a TypeError if the user gives an invalid callback as a parameter""" """Raise a TypeError if the user gives an invalid callback as a parameter"""
text = "a callback must be callable with two parameters" text = "a callback must be callable with two parameters"
@@ -674,7 +676,7 @@ class BaseMessageBus:
raise TypeError(text) raise TypeError(text)
@staticmethod @staticmethod
def _check_method_return(msg, err, signature): def _check_method_return(msg: Message, err: Exception, signature: str) -> None:
if err: if err:
raise err raise err
elif ( elif (
@@ -688,7 +690,7 @@ class BaseMessageBus:
ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg ErrorType.INTERNAL_ERROR, "invalid message type for method call", msg
) )
def _on_message(self, msg): def _on_message(self, msg: Message) -> None:
try: try:
self._process_message(msg) self._process_message(msg)
except Exception as e: except Exception as e:
@@ -696,7 +698,7 @@ class BaseMessageBus:
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"
) )
def _send_reply(self, msg): def _send_reply(self, msg: Message):
bus = self bus = self
class SendReply: class SendReply:
@@ -735,7 +737,7 @@ class BaseMessageBus:
return SendReply() return SendReply()
def _process_message(self, msg): def _process_message(self, msg: Message) -> None:
handled = False handled = False
for handler in self._user_message_handlers: for handler in self._user_message_handlers:
@@ -820,7 +822,9 @@ class BaseMessageBus:
return handler return handler
def _find_message_handler(self, msg): def _find_message_handler(
self, msg: Message
) -> Optional[Callable[[Message, Callable], None]]:
handler = None handler = None
if msg._matches( if msg._matches(
@@ -860,14 +864,20 @@ class BaseMessageBus:
return handler return handler
def _default_introspect_handler(self, msg, send_reply): def _default_introspect_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
introspection = self._introspect_export_path(msg.path).tostring() introspection = self._introspect_export_path(msg.path).tostring()
send_reply(Message.new_method_return(msg, "s", [introspection])) send_reply(Message.new_method_return(msg, "s", [introspection]))
def _default_ping_handler(self, msg, send_reply): def _default_ping_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
send_reply(Message.new_method_return(msg)) send_reply(Message.new_method_return(msg))
def _default_get_machine_id_handler(self, msg, send_reply): def _default_get_machine_id_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
if self._machine_id: if self._machine_id:
send_reply(Message.new_method_return(msg, "s", self._machine_id)) send_reply(Message.new_method_return(msg, "s", self._machine_id))
return return
@@ -898,7 +908,9 @@ class BaseMessageBus:
check_callback=False, check_callback=False,
) )
def _default_get_managed_objects_handler(self, msg, send_reply): def _default_get_managed_objects_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
result = {} result = {}
result_signature = "a{oa{sa{sv}}}" result_signature = "a{oa{sa{sv}}}"
error_handled = False error_handled = False
@@ -948,7 +960,9 @@ class BaseMessageBus:
interface, get_all_properties_callback, node interface, get_all_properties_callback, node
) )
def _default_properties_handler(self, msg, send_reply): def _default_properties_handler(
self, msg: Message, send_reply: Callable[[Message], None]
) -> None:
methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"} methods = {"Get": "ss", "Set": "ssv", "GetAll": "s"}
if msg.member not in methods or methods[msg.member] != msg.signature: if msg.member not in methods or methods[msg.member] != msg.signature:
raise DBusError( raise DBusError(
@@ -1090,7 +1104,7 @@ class BaseMessageBus:
else: else:
assert False assert False
def _init_high_level_client(self): def _init_high_level_client(self) -> None:
"""The high level client is initialized when the first proxy object is """The high level client is initialized when the first proxy object is
gotten. Currently just sets up the match rules for the name owner cache gotten. Currently just sets up the match rules for the name owner cache
so signals can be routed to the right objects.""" so signals can be routed to the right objects."""
@@ -1134,7 +1148,7 @@ class BaseMessageBus:
self._match_rules[match_rule] = 1 self._match_rules[match_rule] = 1
def add_match_notify(msg, err): def add_match_notify(msg: Message, err: Optional[Exception]) -> None:
if err: if err:
logging.error(f'add match request failed. match="{match_rule}", {err}') logging.error(f'add match request failed. match="{match_rule}", {err}')
if msg.message_type == MessageType.ERROR: if msg.message_type == MessageType.ERROR: