From 640e1f8d87a753d6721dae77ee94ff8702a2f508 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 5 Mar 2025 13:05:46 -1000 Subject: [PATCH] chore: upgrade typing on private modules (#402) * chore: upgrade typing on private modules * chore: typing fixes --- src/dbus_fast/_private/__init__.py | 1 + src/dbus_fast/_private/_cython_compat.py | 2 + src/dbus_fast/_private/address.py | 5 ++- src/dbus_fast/_private/constants.py | 2 + src/dbus_fast/_private/marshaller.py | 25 +++++------ src/dbus_fast/_private/unmarshaller.pxd | 1 + src/dbus_fast/_private/unmarshaller.py | 54 +++++++++++++----------- src/dbus_fast/_private/util.py | 42 +++++++++--------- src/dbus_fast/constants.py | 2 +- 9 files changed, 73 insertions(+), 61 deletions(-) diff --git a/src/dbus_fast/_private/__init__.py b/src/dbus_fast/_private/__init__.py index e69de29..9d48db4 100644 --- a/src/dbus_fast/_private/__init__.py +++ b/src/dbus_fast/_private/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/dbus_fast/_private/_cython_compat.py b/src/dbus_fast/_private/_cython_compat.py index 27c9626..071a1c8 100644 --- a/src/dbus_fast/_private/_cython_compat.py +++ b/src/dbus_fast/_private/_cython_compat.py @@ -1,5 +1,7 @@ """Stub for when Cython is not available.""" +from __future__ import annotations + class FakeCython: """Stub for when Cython is not available.""" diff --git a/src/dbus_fast/_private/address.py b/src/dbus_fast/_private/address.py index ee150f2..9c3eb63 100644 --- a/src/dbus_fast/_private/address.py +++ b/src/dbus_fast/_private/address.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import os import re -from typing import Optional from urllib.parse import unquote from ..constants import BusType @@ -85,7 +86,7 @@ def get_session_bus_address() -> str: machine_id = f.read().rstrip() dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}" - dbus_info: Optional[str] = None + dbus_info: str | None = None try: with open(dbus_info_file_name) as f: dbus_info = f.read().rstrip() diff --git a/src/dbus_fast/_private/constants.py b/src/dbus_fast/_private/constants.py index 605c3cf..5d12e5f 100644 --- a/src/dbus_fast/_private/constants.py +++ b/src/dbus_fast/_private/constants.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum PROTOCOL_VERSION = 1 diff --git a/src/dbus_fast/_private/marshaller.py b/src/dbus_fast/_private/marshaller.py index 18e0232..7e3c727 100644 --- a/src/dbus_fast/_private/marshaller.py +++ b/src/dbus_fast/_private/marshaller.py @@ -1,5 +1,6 @@ +from __future__ import annotations from struct import Struct, error -from typing import Any, Callable, Optional, Union +from typing import Any, Callable from ..signature import SignatureType, Variant, get_signature_tree @@ -85,21 +86,21 @@ class Marshaller: signature = variant.signature signature_bytes = signature.encode() written = self._write_signature(signature_bytes) - written += self._write_single(variant.type, variant.value) # type: ignore[has-type] + written += self._write_single(variant.type, variant.value) return written def write_array( - self, array: Union[list[Any], dict[Any, Any]], type_: SignatureType + self, array: bytes | list[Any] | dict[Any, Any], type_: SignatureType ) -> int: return self._write_array(array, type_) def _write_array( - self, array: Union[list[Any], dict[Any, Any]], type_: SignatureType + self, array: bytes | list[Any] | dict[Any, Any], type_: SignatureType ) -> int: # TODO max array size is 64MiB (67108864 bytes) written = self._align(4) # length placeholder - buf = self._buf + buf: bytearray = self._buf offset = len(buf) written += self._align(4) + 4 buf += PACKED_UINT32_ZERO @@ -116,7 +117,7 @@ class Marshaller: array_len += self.write_dict_entry([key, value], child_type) elif token == "y": array_len = len(array) - buf += array + buf += array # type: ignore[arg-type] elif token == "(": for value in array: array_len += self._write_struct(value, child_type) @@ -136,14 +137,10 @@ class Marshaller: return written + array_len - def write_struct( - self, array: Union[tuple[Any], list[Any]], type_: SignatureType - ) -> int: + def write_struct(self, array: tuple[Any] | list[Any], type_: SignatureType) -> int: return self._write_struct(array, type_) - def _write_struct( - self, array: Union[tuple[Any], list[Any]], type_: SignatureType - ) -> int: + def _write_struct(self, array: tuple[Any] | list[Any], type_: SignatureType) -> int: written = self._align(8) for i, value in enumerate(array): written += self._write_single(type_.children[i], value) @@ -204,8 +201,8 @@ class Marshaller: _writers: dict[ str, tuple[ - Optional[Callable[[Any, Any, SignatureType], int]], - Optional[Callable[[Any], bytes]], + Callable[[Any, Any, SignatureType], int] | None, + Callable[[Any], bytes] | None, int, ], ] = { diff --git a/src/dbus_fast/_private/unmarshaller.pxd b/src/dbus_fast/_private/unmarshaller.pxd index 98363d9..89fd7c3 100644 --- a/src/dbus_fast/_private/unmarshaller.pxd +++ b/src/dbus_fast/_private/unmarshaller.pxd @@ -5,6 +5,7 @@ import cython from ..message cimport Message from ..signature cimport SignatureTree, SignatureType, Variant +cdef bint TYPE_CHECKING cdef object MAX_UNIX_FDS_SIZE cdef object ARRAY diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 85ca17f..f329bf2 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import array import errno import io @@ -5,7 +7,7 @@ import socket import sys from collections.abc import Iterable from struct import Struct -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, TYPE_CHECKING from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP, MessageFlag from ..errors import InvalidMessageError @@ -144,7 +146,7 @@ HEADER_SENDER_IDX = HEADER_IDX_TO_ARG_NAME.index("sender") HEADER_SIGNATURE_IDX = HEADER_IDX_TO_ARG_NAME.index("signature") HEADER_UNIX_FDS_IDX = HEADER_IDX_TO_ARG_NAME.index("unix_fds") -_EMPTY_HEADERS = [None] * len(HEADER_IDX_TO_ARG_NAME) +_EMPTY_HEADERS: list[Any | None] = [None] * len(HEADER_IDX_TO_ARG_NAME) _SignatureType = SignatureType _int = int @@ -159,7 +161,7 @@ DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE: """Build a parser that unpacks the bytes using the given unpack_from function.""" - def _unpack_from_parser(self: "Unmarshaller", signature: SignatureType) -> Any: + def _unpack_from_parser(self: Unmarshaller, signature: SignatureType) -> Any: self._pos += size + (-self._pos & (size - 1)) # align return unpack_from(self._buf, self._pos - size)[0] @@ -168,7 +170,7 @@ def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE: def build_simple_parsers( endian: int, -) -> dict[str, Callable[["Unmarshaller", SignatureType], Any]]: +) -> dict[str, Callable[[Unmarshaller, SignatureType], Any]]: """Build a dict of parsers for simple types.""" parsers: dict[str, READER_TYPE] = {} for dbus_type, ctype_size in DBUS_TO_CTYPE.items(): @@ -238,15 +240,15 @@ class Unmarshaller: def __init__( self, - stream: Optional[io.BufferedRWPair] = None, - sock: Optional[socket.socket] = None, + stream: io.BufferedRWPair | None = None, + sock: socket.socket | None = None, negotiate_unix_fd: bool = True, ) -> None: self._unix_fds: list[int] = [] - self._buf = bytearray.__new__(bytearray) # Actual buffer + self._buf: bytearray = bytearray.__new__(bytearray) # Actual buffer self._stream = stream self._sock = sock - self._message: Optional[Message] = None + self._message: Message | None = None self._readers: dict[str, READER_TYPE] = {} self._pos = 0 self._body_len = 0 @@ -256,20 +258,24 @@ class Unmarshaller: self._flag = 0 self._msg_len = 0 self._is_native = 0 - self._uint32_unpack: Optional[Callable] = None - self._int16_unpack: Optional[Callable] = None - self._uint16_unpack: Optional[Callable] = None - self._stream_reader: Optional[Callable] = None + self._uint32_unpack: Callable | None = None + self._int16_unpack: Callable | None = None + self._uint16_unpack: Callable | None = None + self._stream_reader: Callable | None = None self._negotiate_unix_fd = negotiate_unix_fd self._read_complete = False if stream: if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"): - self._stream_reader = stream.reader.read # type: ignore[attr-defined] + self._stream_reader = stream.reader.read self._stream_reader = stream.read elif self._negotiate_unix_fd: + if TYPE_CHECKING: + assert self._sock is not None self._sock_reader = self._sock.recvmsg else: - self._sock_reader = self._sock.recv + if TYPE_CHECKING: + assert self._sock is not None + self._sock_reader = self._sock.recv # type: ignore[assignment] self._endian = 0 def _next_message(self) -> None: @@ -289,7 +295,7 @@ class Unmarshaller: # every time a new message is processed. @property - def message(self) -> Optional[Message]: + def message(self) -> Message | None: """Return the message that has been unmarshalled.""" if self._read_complete: return self._message @@ -309,7 +315,7 @@ class Unmarshaller: # This will raise BlockingIOError if there is no data to read # which we store in the MARSHALL_STREAM_END_ERROR object try: - recv = self._sock_reader(missing_bytes, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr] + recv = self._sock_reader(missing_bytes, UNIX_FDS_CMSG_LENGTH) except OSError as e: errno = e.errno if errno == EAGAIN or errno == EWOULDBLOCK: @@ -340,7 +346,7 @@ class Unmarshaller: # which we store in the MARSHALL_STREAM_END_ERROR object while True: try: - data = self._sock_reader(DEFAULT_BUFFER_SIZE) # type: ignore[union-attr] + data = self._sock_reader(DEFAULT_BUFFER_SIZE) except OSError as e: errno = e.errno if errno == EAGAIN or errno == EWOULDBLOCK: @@ -348,11 +354,11 @@ class Unmarshaller: raise if not data: raise EOFError() - self._buf += data + self._buf += data # type: ignore[arg-type] if len(self._buf) >= pos: return - def _read_stream(self, pos: _int, missing_bytes: _int) -> bytes: + def _read_stream(self, pos: _int, missing_bytes: _int) -> None: """Read from the stream.""" data = self._stream_reader(missing_bytes) # type: ignore[misc] if data is None: @@ -577,7 +583,7 @@ class Unmarshaller: ) and child_1_token_as_int == TOKEN_V_AS_INT: while self._pos - beginning_pos < array_length: self._pos += -self._pos & 7 # align 8 - key: Union[str, int] = self._read_string_unpack() + key: str | int = self._read_string_unpack() result_dict[key] = self._read_variant() elif ( child_0_token_as_int == TOKEN_Q_AS_INT @@ -785,7 +791,7 @@ class Unmarshaller: self._message = message self._read_complete = True - def unmarshall(self) -> Optional[Message]: + def unmarshall(self) -> Message | None: """Unmarshall the message. The underlying read function will raise BlockingIOError if the @@ -794,7 +800,7 @@ class Unmarshaller: """ return self._unmarshall() - def _unmarshall(self) -> Optional[Message]: + def _unmarshall(self) -> Message | None: """Unmarshall the message. The underlying read function will raise BlockingIOError if the @@ -811,9 +817,7 @@ class Unmarshaller: return None return self._message - _complex_parsers_unpack: dict[ - str, Callable[["Unmarshaller", SignatureType], Any] - ] = { + _complex_parsers_unpack: dict[str, Callable[[Unmarshaller, SignatureType], Any]] = { "b": read_boolean, "o": read_string_unpack, "s": read_string_unpack, diff --git a/src/dbus_fast/_private/util.py b/src/dbus_fast/_private/util.py index b7c2366..528b1cb 100644 --- a/src/dbus_fast/_private/util.py +++ b/src/dbus_fast/_private/util.py @@ -1,22 +1,22 @@ +from __future__ import annotations + import ast import inspect -from typing import Any, Union +from typing import Any, Callable -from ..signature import SignatureTree, Variant, get_signature_tree +from ..signature import SignatureTree, Variant, get_signature_tree, SignatureType def signature_contains_type( - signature: Union[str, SignatureTree], body: list[Any], token: str + signature: str | SignatureTree, body: list[Any], token: str ) -> bool: """For a given signature and body, check to see if it contains any members with the given token""" if type(signature) is str: signature = get_signature_tree(signature) - queue = [] + queue = list(signature.types) # type: ignore[union-attr] contains_variants = False - for st in signature.types: - queue.append(st) while True: if not queue: @@ -49,7 +49,7 @@ def signature_contains_type( def replace_fds_with_idx( - signature: Union[str, SignatureTree], body: list[Any] + signature: str | SignatureTree, body: list[Any] ) -> tuple[list[Any], list[int]]: """Take the high level body format and convert it into the low level body format. Type 'h' refers directly to the fd in the body. Replace that with @@ -61,22 +61,22 @@ def replace_fds_with_idx( if not signature_contains_type(signature, body, "h"): return body, [] - unix_fds = [] + unix_fds: list[Any] = [] - def _replace(fd): + def _replace(fd: Any) -> int: try: return unix_fds.index(fd) except ValueError: unix_fds.append(fd) return len(unix_fds) - 1 - _replace_fds(body, signature.types, _replace) + _replace_fds(body, signature.types, _replace) # type: ignore[union-attr] return body, unix_fds def replace_idx_with_fds( - signature: Union[str, SignatureTree], body: list[Any], unix_fds: list[int] + signature: str | SignatureTree, body: list[Any], unix_fds: list[Any] ) -> list[Any]: """Take the low level body format and return the high level body format. Type 'h' refers to an index in the unix_fds array. Replace those with the @@ -87,13 +87,13 @@ def replace_idx_with_fds( if not signature_contains_type(signature, body, "h"): return body - def _replace(idx): + def _replace(idx: int) -> Any: try: return unix_fds[idx] except IndexError: return None - _replace_fds(body, signature.types, _replace) + _replace_fds(body, signature.types, _replace) # type: ignore[union-attr] return body @@ -107,7 +107,7 @@ def parse_annotation(annotation: str) -> str: constant. """ - def raise_value_error(): + def raise_value_error() -> None: raise ValueError( f"service annotations must be a string constant (got {annotation})" ) @@ -118,17 +118,21 @@ def parse_annotation(annotation: str) -> str: raise_value_error() try: body = ast.parse(annotation).body - if len(body) == 1 and type(body[0].value) is ast.Constant: - if type(body[0].value.value) is not str: + if len(body) == 1 and type(body[0].value) is ast.Constant: # type: ignore[attr-defined] + if type(body[0].value.value) is not str: # type: ignore[attr-defined] raise_value_error() - return body[0].value.value + return body[0].value.value # type: ignore[attr-defined] except SyntaxError: pass return annotation -def _replace_fds(body_obj: list[Any], children, replace_fn): +def _replace_fds( + body_obj: dict[Any, Any] | list[Any], + children: list[SignatureType], + replace_fn: Callable[[Any], Any], +) -> None: """Replace any type 'h' with the value returned by replace_fn() given the value of the fd field. This is used by the high level interfaces which allow type 'h' to be the fd directly instead of an index in an external @@ -150,7 +154,7 @@ def _replace_fds(body_obj: list[Any], children, replace_fn): elif st.token in "(": _replace_fds(body_obj[index], st.children, replace_fn) elif st.token in "{": - for key, value in list(body_obj.items()): + for key, value in list(body_obj.items()): # type: ignore[union-attr] body_obj.pop(key) if st.children[0].signature == "h": key = replace_fn(key) diff --git a/src/dbus_fast/constants.py b/src/dbus_fast/constants.py index bc23461..4078403 100644 --- a/src/dbus_fast/constants.py +++ b/src/dbus_fast/constants.py @@ -20,7 +20,7 @@ class MessageType(Enum): SIGNAL = 4 #: A broadcast signal to subscribed connections @cached_property - def value(self) -> str: + def value(self) -> int: """Return the value.""" return self._value_