chore: upgrade typing on private modules (#402)

* chore: upgrade typing on private modules

* chore: typing fixes
This commit is contained in:
J. Nick Koston 2025-03-05 13:05:46 -10:00 committed by GitHub
parent dc3d8e7609
commit 640e1f8d87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 73 additions and 61 deletions

View File

@ -0,0 +1 @@
from __future__ import annotations

View File

@ -1,5 +1,7 @@
"""Stub for when Cython is not available."""
from __future__ import annotations
class FakeCython:
"""Stub for when Cython is not available."""

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import os
import re
from typing import Optional
from urllib.parse import unquote
from ..constants import BusType
@ -85,7 +86,7 @@ def get_session_bus_address() -> str:
machine_id = f.read().rstrip()
dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}"
dbus_info: Optional[str] = None
dbus_info: str | None = None
try:
with open(dbus_info_file_name) as f:
dbus_info = f.read().rstrip()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from enum import Enum
PROTOCOL_VERSION = 1

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from struct import Struct, error
from typing import Any, Callable, Optional, Union
from typing import Any, Callable
from ..signature import SignatureType, Variant, get_signature_tree
@ -85,21 +86,21 @@ class Marshaller:
signature = variant.signature
signature_bytes = signature.encode()
written = self._write_signature(signature_bytes)
written += self._write_single(variant.type, variant.value) # type: ignore[has-type]
written += self._write_single(variant.type, variant.value)
return written
def write_array(
self, array: Union[list[Any], dict[Any, Any]], type_: SignatureType
self, array: bytes | list[Any] | dict[Any, Any], type_: SignatureType
) -> int:
return self._write_array(array, type_)
def _write_array(
self, array: Union[list[Any], dict[Any, Any]], type_: SignatureType
self, array: bytes | list[Any] | dict[Any, Any], type_: SignatureType
) -> int:
# TODO max array size is 64MiB (67108864 bytes)
written = self._align(4)
# length placeholder
buf = self._buf
buf: bytearray = self._buf
offset = len(buf)
written += self._align(4) + 4
buf += PACKED_UINT32_ZERO
@ -116,7 +117,7 @@ class Marshaller:
array_len += self.write_dict_entry([key, value], child_type)
elif token == "y":
array_len = len(array)
buf += array
buf += array # type: ignore[arg-type]
elif token == "(":
for value in array:
array_len += self._write_struct(value, child_type)
@ -136,14 +137,10 @@ class Marshaller:
return written + array_len
def write_struct(
self, array: Union[tuple[Any], list[Any]], type_: SignatureType
) -> int:
def write_struct(self, array: tuple[Any] | list[Any], type_: SignatureType) -> int:
return self._write_struct(array, type_)
def _write_struct(
self, array: Union[tuple[Any], list[Any]], type_: SignatureType
) -> int:
def _write_struct(self, array: tuple[Any] | list[Any], type_: SignatureType) -> int:
written = self._align(8)
for i, value in enumerate(array):
written += self._write_single(type_.children[i], value)
@ -204,8 +201,8 @@ class Marshaller:
_writers: dict[
str,
tuple[
Optional[Callable[[Any, Any, SignatureType], int]],
Optional[Callable[[Any], bytes]],
Callable[[Any, Any, SignatureType], int] | None,
Callable[[Any], bytes] | None,
int,
],
] = {

View File

@ -5,6 +5,7 @@ import cython
from ..message cimport Message
from ..signature cimport SignatureTree, SignatureType, Variant
cdef bint TYPE_CHECKING
cdef object MAX_UNIX_FDS_SIZE
cdef object ARRAY

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import array
import errno
import io
@ -5,7 +7,7 @@ import socket
import sys
from collections.abc import Iterable
from struct import Struct
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, TYPE_CHECKING
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP, MessageFlag
from ..errors import InvalidMessageError
@ -144,7 +146,7 @@ HEADER_SENDER_IDX = HEADER_IDX_TO_ARG_NAME.index("sender")
HEADER_SIGNATURE_IDX = HEADER_IDX_TO_ARG_NAME.index("signature")
HEADER_UNIX_FDS_IDX = HEADER_IDX_TO_ARG_NAME.index("unix_fds")
_EMPTY_HEADERS = [None] * len(HEADER_IDX_TO_ARG_NAME)
_EMPTY_HEADERS: list[Any | None] = [None] * len(HEADER_IDX_TO_ARG_NAME)
_SignatureType = SignatureType
_int = int
@ -159,7 +161,7 @@ DEFAULT_BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE
def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE:
"""Build a parser that unpacks the bytes using the given unpack_from function."""
def _unpack_from_parser(self: "Unmarshaller", signature: SignatureType) -> Any:
def _unpack_from_parser(self: Unmarshaller, signature: SignatureType) -> Any:
self._pos += size + (-self._pos & (size - 1)) # align
return unpack_from(self._buf, self._pos - size)[0]
@ -168,7 +170,7 @@ def unpack_parser_factory(unpack_from: Callable, size: int) -> READER_TYPE:
def build_simple_parsers(
endian: int,
) -> dict[str, Callable[["Unmarshaller", SignatureType], Any]]:
) -> dict[str, Callable[[Unmarshaller, SignatureType], Any]]:
"""Build a dict of parsers for simple types."""
parsers: dict[str, READER_TYPE] = {}
for dbus_type, ctype_size in DBUS_TO_CTYPE.items():
@ -238,15 +240,15 @@ class Unmarshaller:
def __init__(
self,
stream: Optional[io.BufferedRWPair] = None,
sock: Optional[socket.socket] = None,
stream: io.BufferedRWPair | None = None,
sock: socket.socket | None = None,
negotiate_unix_fd: bool = True,
) -> None:
self._unix_fds: list[int] = []
self._buf = bytearray.__new__(bytearray) # Actual buffer
self._buf: bytearray = bytearray.__new__(bytearray) # Actual buffer
self._stream = stream
self._sock = sock
self._message: Optional[Message] = None
self._message: Message | None = None
self._readers: dict[str, READER_TYPE] = {}
self._pos = 0
self._body_len = 0
@ -256,20 +258,24 @@ class Unmarshaller:
self._flag = 0
self._msg_len = 0
self._is_native = 0
self._uint32_unpack: Optional[Callable] = None
self._int16_unpack: Optional[Callable] = None
self._uint16_unpack: Optional[Callable] = None
self._stream_reader: Optional[Callable] = None
self._uint32_unpack: Callable | None = None
self._int16_unpack: Callable | None = None
self._uint16_unpack: Callable | None = None
self._stream_reader: Callable | None = None
self._negotiate_unix_fd = negotiate_unix_fd
self._read_complete = False
if stream:
if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"):
self._stream_reader = stream.reader.read # type: ignore[attr-defined]
self._stream_reader = stream.reader.read
self._stream_reader = stream.read
elif self._negotiate_unix_fd:
if TYPE_CHECKING:
assert self._sock is not None
self._sock_reader = self._sock.recvmsg
else:
self._sock_reader = self._sock.recv
if TYPE_CHECKING:
assert self._sock is not None
self._sock_reader = self._sock.recv # type: ignore[assignment]
self._endian = 0
def _next_message(self) -> None:
@ -289,7 +295,7 @@ class Unmarshaller:
# every time a new message is processed.
@property
def message(self) -> Optional[Message]:
def message(self) -> Message | None:
"""Return the message that has been unmarshalled."""
if self._read_complete:
return self._message
@ -309,7 +315,7 @@ class Unmarshaller:
# This will raise BlockingIOError if there is no data to read
# which we store in the MARSHALL_STREAM_END_ERROR object
try:
recv = self._sock_reader(missing_bytes, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
recv = self._sock_reader(missing_bytes, UNIX_FDS_CMSG_LENGTH)
except OSError as e:
errno = e.errno
if errno == EAGAIN or errno == EWOULDBLOCK:
@ -340,7 +346,7 @@ class Unmarshaller:
# which we store in the MARSHALL_STREAM_END_ERROR object
while True:
try:
data = self._sock_reader(DEFAULT_BUFFER_SIZE) # type: ignore[union-attr]
data = self._sock_reader(DEFAULT_BUFFER_SIZE)
except OSError as e:
errno = e.errno
if errno == EAGAIN or errno == EWOULDBLOCK:
@ -348,11 +354,11 @@ class Unmarshaller:
raise
if not data:
raise EOFError()
self._buf += data
self._buf += data # type: ignore[arg-type]
if len(self._buf) >= pos:
return
def _read_stream(self, pos: _int, missing_bytes: _int) -> bytes:
def _read_stream(self, pos: _int, missing_bytes: _int) -> None:
"""Read from the stream."""
data = self._stream_reader(missing_bytes) # type: ignore[misc]
if data is None:
@ -577,7 +583,7 @@ class Unmarshaller:
) and child_1_token_as_int == TOKEN_V_AS_INT:
while self._pos - beginning_pos < array_length:
self._pos += -self._pos & 7 # align 8
key: Union[str, int] = self._read_string_unpack()
key: str | int = self._read_string_unpack()
result_dict[key] = self._read_variant()
elif (
child_0_token_as_int == TOKEN_Q_AS_INT
@ -785,7 +791,7 @@ class Unmarshaller:
self._message = message
self._read_complete = True
def unmarshall(self) -> Optional[Message]:
def unmarshall(self) -> Message | None:
"""Unmarshall the message.
The underlying read function will raise BlockingIOError if the
@ -794,7 +800,7 @@ class Unmarshaller:
"""
return self._unmarshall()
def _unmarshall(self) -> Optional[Message]:
def _unmarshall(self) -> Message | None:
"""Unmarshall the message.
The underlying read function will raise BlockingIOError if the
@ -811,9 +817,7 @@ class Unmarshaller:
return None
return self._message
_complex_parsers_unpack: dict[
str, Callable[["Unmarshaller", SignatureType], Any]
] = {
_complex_parsers_unpack: dict[str, Callable[[Unmarshaller, SignatureType], Any]] = {
"b": read_boolean,
"o": read_string_unpack,
"s": read_string_unpack,

View File

@ -1,22 +1,22 @@
from __future__ import annotations
import ast
import inspect
from typing import Any, Union
from typing import Any, Callable
from ..signature import SignatureTree, Variant, get_signature_tree
from ..signature import SignatureTree, Variant, get_signature_tree, SignatureType
def signature_contains_type(
signature: Union[str, SignatureTree], body: list[Any], token: str
signature: str | SignatureTree, body: list[Any], token: str
) -> bool:
"""For a given signature and body, check to see if it contains any members
with the given token"""
if type(signature) is str:
signature = get_signature_tree(signature)
queue = []
queue = list(signature.types) # type: ignore[union-attr]
contains_variants = False
for st in signature.types:
queue.append(st)
while True:
if not queue:
@ -49,7 +49,7 @@ def signature_contains_type(
def replace_fds_with_idx(
signature: Union[str, SignatureTree], body: list[Any]
signature: str | SignatureTree, body: list[Any]
) -> tuple[list[Any], list[int]]:
"""Take the high level body format and convert it into the low level body
format. Type 'h' refers directly to the fd in the body. Replace that with
@ -61,22 +61,22 @@ def replace_fds_with_idx(
if not signature_contains_type(signature, body, "h"):
return body, []
unix_fds = []
unix_fds: list[Any] = []
def _replace(fd):
def _replace(fd: Any) -> int:
try:
return unix_fds.index(fd)
except ValueError:
unix_fds.append(fd)
return len(unix_fds) - 1
_replace_fds(body, signature.types, _replace)
_replace_fds(body, signature.types, _replace) # type: ignore[union-attr]
return body, unix_fds
def replace_idx_with_fds(
signature: Union[str, SignatureTree], body: list[Any], unix_fds: list[int]
signature: str | SignatureTree, body: list[Any], unix_fds: list[Any]
) -> list[Any]:
"""Take the low level body format and return the high level body format.
Type 'h' refers to an index in the unix_fds array. Replace those with the
@ -87,13 +87,13 @@ def replace_idx_with_fds(
if not signature_contains_type(signature, body, "h"):
return body
def _replace(idx):
def _replace(idx: int) -> Any:
try:
return unix_fds[idx]
except IndexError:
return None
_replace_fds(body, signature.types, _replace)
_replace_fds(body, signature.types, _replace) # type: ignore[union-attr]
return body
@ -107,7 +107,7 @@ def parse_annotation(annotation: str) -> str:
constant.
"""
def raise_value_error():
def raise_value_error() -> None:
raise ValueError(
f"service annotations must be a string constant (got {annotation})"
)
@ -118,17 +118,21 @@ def parse_annotation(annotation: str) -> str:
raise_value_error()
try:
body = ast.parse(annotation).body
if len(body) == 1 and type(body[0].value) is ast.Constant:
if type(body[0].value.value) is not str:
if len(body) == 1 and type(body[0].value) is ast.Constant: # type: ignore[attr-defined]
if type(body[0].value.value) is not str: # type: ignore[attr-defined]
raise_value_error()
return body[0].value.value
return body[0].value.value # type: ignore[attr-defined]
except SyntaxError:
pass
return annotation
def _replace_fds(body_obj: list[Any], children, replace_fn):
def _replace_fds(
body_obj: dict[Any, Any] | list[Any],
children: list[SignatureType],
replace_fn: Callable[[Any], Any],
) -> None:
"""Replace any type 'h' with the value returned by replace_fn() given the
value of the fd field. This is used by the high level interfaces which
allow type 'h' to be the fd directly instead of an index in an external
@ -150,7 +154,7 @@ def _replace_fds(body_obj: list[Any], children, replace_fn):
elif st.token in "(":
_replace_fds(body_obj[index], st.children, replace_fn)
elif st.token in "{":
for key, value in list(body_obj.items()):
for key, value in list(body_obj.items()): # type: ignore[union-attr]
body_obj.pop(key)
if st.children[0].signature == "h":
key = replace_fn(key)

View File

@ -20,7 +20,7 @@ class MessageType(Enum):
SIGNAL = 4 #: A broadcast signal to subscribed connections
@cached_property
def value(self) -> str:
def value(self) -> int:
"""Return the value."""
return self._value_