From 830183e1887a7abb876813098f17e22550453569 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 24 Dec 2022 09:42:10 -1000 Subject: [PATCH] fix: cleanup typing in marshaller and unmarshaller (#190) --- src/dbus_fast/_private/marshaller.py | 17 +++--- src/dbus_fast/_private/unmarshaller.py | 75 ++++++++++++++------------ 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/src/dbus_fast/_private/marshaller.py b/src/dbus_fast/_private/marshaller.py index 9a9bff8..d249a28 100644 --- a/src/dbus_fast/_private/marshaller.py +++ b/src/dbus_fast/_private/marshaller.py @@ -10,6 +10,10 @@ PACKED_UINT32_ZERO = PACK_UINT32(0) PACKED_BOOL_FALSE = PACK_UINT32(int(0)) PACKED_BOOL_TRUE = PACK_UINT32(int(1)) +_int = int +_bytes = bytes +_str = str + class Marshaller: """Marshall data for Dbus.""" @@ -29,10 +33,10 @@ class Marshaller: def _buffer(self) -> bytearray: return self._buf - def align(self, n): + def align(self, n: _int) -> int: return self._align(n) - def _align(self, n): + def _align(self, n: _int) -> _int: offset = n - len(self._buf) % n if offset == 0 or offset == n: return 0 @@ -51,7 +55,7 @@ class Marshaller: def write_signature(self, signature: str, type_: SignatureType) -> int: return self._write_signature(signature.encode()) - def _write_signature(self, signature_bytes) -> int: + def _write_signature(self, signature_bytes: _bytes) -> int: signature_len = len(signature_bytes) buf = self._buf buf.append(signature_len) @@ -59,10 +63,10 @@ class Marshaller: buf.append(0) return signature_len + 2 - def write_string(self, value, type_: SignatureType) -> int: + def write_string(self, value: _str, type_: SignatureType) -> int: return self._write_string(value) - def _write_string(self, value) -> int: + def _write_string(self, value: _str) -> int: value_bytes = value.encode() value_len = len(value) written = self._align(4) + 4 @@ -81,7 +85,7 @@ class Marshaller: signature = variant.signature signature_bytes = signature.encode() written = self._write_signature(signature_bytes) - written += self._write_single(variant.type, variant.value) + written += self._write_single(variant.type, variant.value) # type: ignore[has-type] return written def write_array( @@ -184,6 +188,7 @@ class Marshaller: raise NotImplementedError(f'type is not implemented yet: "{ex.args}"') except error: self.signature_tree.verify(self.body) + raise RuntimeError("should not reach here") def _construct_buffer(self) -> bytearray: self._buf.clear() diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 00a114e..b4afbe0 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -3,7 +3,7 @@ import io import socket import sys from struct import Struct -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP from ..errors import InvalidMessageError @@ -114,6 +114,7 @@ HEADER_MESSAGE_ARG_NAME = { 9: "unix_fds", } +_SignatureType = SignatureType READER_TYPE = Callable[["Unmarshaller", SignatureType], Any] @@ -210,7 +211,7 @@ class Unmarshaller: self._stream_reader: Optional[Callable] = None if self._sock is None: if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"): - self._stream_reader = stream.reader.read + self._stream_reader = stream.reader.read # type: ignore[attr-defined] self._stream_reader = stream.read def reset(self) -> None: @@ -240,7 +241,7 @@ class Unmarshaller: # every time a new message is processed. @property - def message(self) -> Message: + def message(self) -> Optional[Message]: """Return the message that has been unmarshalled.""" return self._message @@ -249,7 +250,7 @@ class Unmarshaller: from the read itself""" # This will raise BlockingIOError if there is no data to read # which we store in the MARSHALL_STREAM_END_ERROR object - msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) + msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr] for level, type_, data in ancdata: if not (level == SOL_SOCKET and type_ == SCM_RIGHTS): continue @@ -275,7 +276,7 @@ class Unmarshaller: start_len = len(self._buf) missing_bytes = pos - (start_len - self._pos) if self._sock is None: - data = self._stream_reader(missing_bytes) + data = self._stream_reader(missing_bytes) # type: ignore[misc] else: data = self._read_sock(missing_bytes) if data == b"": @@ -286,46 +287,46 @@ class Unmarshaller: if len(data) + start_len != pos: raise MARSHALL_STREAM_END_ERROR - def read_uint32_unpack(self, type_) -> int: + def read_uint32_unpack(self, type_: _SignatureType) -> int: return self._read_uint32_unpack() def _read_uint32_unpack(self) -> int: self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align if self._is_native and cython.compiled: - return _cast_uint32_native( # pragma: no cover + return _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover self._buf, self._pos - UINT32_SIZE ) - return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] + return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc] - def read_uint16_unpack(self, type_) -> int: + def read_uint16_unpack(self, type_: _SignatureType) -> int: return self._read_uint16_unpack() def _read_uint16_unpack(self) -> int: self._pos += UINT16_SIZE + (-self._pos & (UINT16_SIZE - 1)) # align if self._is_native and cython.compiled: - return _cast_uint16_native( # pragma: no cover + return _cast_uint16_native( # type: ignore[name-defined] # pragma: no cover self._buf, self._pos - UINT16_SIZE ) - return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0] + return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0] # type: ignore[misc] - def read_int16_unpack(self, type_) -> int: + def read_int16_unpack(self, type_: _SignatureType) -> int: return self._read_int16_unpack() def _read_int16_unpack(self) -> int: self._pos += INT16_SIZE + (-self._pos & (INT16_SIZE - 1)) # align if self._is_native and cython.compiled: - return _cast_int16_native( # pragma: no cover + return _cast_int16_native( # type: ignore[name-defined] # pragma: no cover self._buf, self._pos - INT16_SIZE ) - return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0] + return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0] # type: ignore[misc] - def read_boolean(self, type_) -> bool: + def read_boolean(self, type_: _SignatureType) -> bool: return self._read_boolean() def _read_boolean(self) -> bool: return bool(self._read_uint32_unpack()) - def read_string_unpack(self, type_) -> str: + def read_string_unpack(self, type_: _SignatureType) -> str: return self._read_string_unpack() def _read_string_unpack(self) -> str: @@ -335,13 +336,13 @@ class Unmarshaller: # read terminating '\0' byte as well (str_length + 1) if self._is_native and cython.compiled: self._pos += ( # pragma: no cover - _cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1 + _cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1 # type: ignore[name-defined] ) else: - self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1 + self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1 # type: ignore[misc] return self._buf[str_start : self._pos - 1].decode() - def read_signature(self, type_) -> str: + def read_signature(self, type_: _SignatureType) -> str: return self._read_signature() def _read_signature(self) -> str: @@ -351,7 +352,7 @@ class Unmarshaller: self._pos = o + signature_len + 1 return self._buf[o : o + signature_len].decode() - def read_variant(self, type_) -> Variant: + def read_variant(self, type_: _SignatureType) -> Variant: return self._read_variant() def _read_variant(self) -> Variant: @@ -402,33 +403,35 @@ class Unmarshaller: False, ) - def read_struct(self, type_) -> List[Any]: + def read_struct(self, type_: _SignatureType) -> List[Any]: self._pos += -self._pos & 7 # align 8 readers = self._readers return [ readers[child_type.token](self, child_type) for child_type in type_.children ] - def read_dict_entry(self, type_) -> Tuple[Any, Any]: + def read_dict_entry(self, type_: _SignatureType) -> Tuple[Any, Any]: self._pos += -self._pos & 7 # align 8 return self._readers[type_.children[0].token]( self, type_.children[0] ), self._readers[type_.children[1].token](self, type_.children[1]) - def read_array(self, type_) -> Iterable[Any]: + def read_array(self, type_: _SignatureType) -> Iterable[Any]: return self._read_array(type_) - def _read_array(self, type_) -> Iterable[Any]: + def _read_array(self, type_: _SignatureType) -> Iterable[Any]: self._pos += -self._pos & 3 # align 4 for the array self._pos += ( -self._pos & (UINT32_SIZE - 1) ) + UINT32_SIZE # align for the uint32 if self._is_native and cython.compiled: - array_length = _cast_uint32_native( # pragma: no cover - self._buf, self._pos - UINT32_SIZE + array_length = ( + _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover + self._buf, self._pos - UINT32_SIZE + ) ) else: - array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] + array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc] child_type = type_.children[0] token = child_type.token @@ -442,7 +445,7 @@ class Unmarshaller: return self._buf[self._pos - array_length : self._pos] if token == "{": - result_dict = {} + result_dict: Dict[Any, Any] = {} beginning_pos = self._pos children = child_type.children child_0 = children[0] @@ -455,7 +458,7 @@ class Unmarshaller: if child_0_token in "os" and child_1_token == "v": while self._pos - beginning_pos < array_length: self._pos += -self._pos & 7 # align 8 - key = self._read_string_unpack() + key: Union[str, int] = self._read_string_unpack() result_dict[key] = self._read_variant() elif child_0_token == "q" and child_1_token == "v": while self._pos - beginning_pos < array_length: @@ -544,9 +547,15 @@ class Unmarshaller: or (endian == BIG_ENDIAN and SYS_IS_BIG_ENDIAN) ): self._is_native = 1 # pragma: no cover - self._body_len = _cast_uint32_native(self._buf, 4) # pragma: no cover - self._serial = _cast_uint32_native(self._buf, 8) # pragma: no cover - self._header_len = _cast_uint32_native(self._buf, 12) # pragma: no cover + self._body_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover + self._buf, 4 + ) + self._serial = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover + self._buf, 8 + ) + self._header_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover + self._buf, 12 + ) elif endian == LITTLE_ENDIAN: ( self._body_len, @@ -583,7 +592,7 @@ class Unmarshaller: signature = header_fields.pop("signature", "") if not self._body_len: tree = SIGNATURE_TREE_EMPTY - body = [] + body: List[Any] = [] elif signature == "s": tree = SIGNATURE_TREE_S body = [self._read_string_unpack()]