fix: cleanup typing in marshaller and unmarshaller (#190)

This commit is contained in:
J. Nick Koston 2022-12-24 09:42:10 -10:00 committed by GitHub
parent 1f42d28c47
commit 830183e188
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 39 deletions

View File

@ -10,6 +10,10 @@ PACKED_UINT32_ZERO = PACK_UINT32(0)
PACKED_BOOL_FALSE = PACK_UINT32(int(0))
PACKED_BOOL_TRUE = PACK_UINT32(int(1))
_int = int
_bytes = bytes
_str = str
class Marshaller:
"""Marshall data for Dbus."""
@ -29,10 +33,10 @@ class Marshaller:
def _buffer(self) -> bytearray:
return self._buf
def align(self, n):
def align(self, n: _int) -> int:
return self._align(n)
def _align(self, n):
def _align(self, n: _int) -> _int:
offset = n - len(self._buf) % n
if offset == 0 or offset == n:
return 0
@ -51,7 +55,7 @@ class Marshaller:
def write_signature(self, signature: str, type_: SignatureType) -> int:
return self._write_signature(signature.encode())
def _write_signature(self, signature_bytes) -> int:
def _write_signature(self, signature_bytes: _bytes) -> int:
signature_len = len(signature_bytes)
buf = self._buf
buf.append(signature_len)
@ -59,10 +63,10 @@ class Marshaller:
buf.append(0)
return signature_len + 2
def write_string(self, value, type_: SignatureType) -> int:
def write_string(self, value: _str, type_: SignatureType) -> int:
return self._write_string(value)
def _write_string(self, value) -> int:
def _write_string(self, value: _str) -> int:
value_bytes = value.encode()
value_len = len(value)
written = self._align(4) + 4
@ -81,7 +85,7 @@ class Marshaller:
signature = variant.signature
signature_bytes = signature.encode()
written = self._write_signature(signature_bytes)
written += self._write_single(variant.type, variant.value)
written += self._write_single(variant.type, variant.value) # type: ignore[has-type]
return written
def write_array(
@ -184,6 +188,7 @@ class Marshaller:
raise NotImplementedError(f'type is not implemented yet: "{ex.args}"')
except error:
self.signature_tree.verify(self.body)
raise RuntimeError("should not reach here")
def _construct_buffer(self) -> bytearray:
self._buf.clear()

View File

@ -3,7 +3,7 @@ import io
import socket
import sys
from struct import Struct
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP
from ..errors import InvalidMessageError
@ -114,6 +114,7 @@ HEADER_MESSAGE_ARG_NAME = {
9: "unix_fds",
}
_SignatureType = SignatureType
READER_TYPE = Callable[["Unmarshaller", SignatureType], Any]
@ -210,7 +211,7 @@ class Unmarshaller:
self._stream_reader: Optional[Callable] = None
if self._sock is None:
if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"):
self._stream_reader = stream.reader.read
self._stream_reader = stream.reader.read # type: ignore[attr-defined]
self._stream_reader = stream.read
def reset(self) -> None:
@ -240,7 +241,7 @@ class Unmarshaller:
# every time a new message is processed.
@property
def message(self) -> Message:
def message(self) -> Optional[Message]:
"""Return the message that has been unmarshalled."""
return self._message
@ -249,7 +250,7 @@ class Unmarshaller:
from the read itself"""
# This will raise BlockingIOError if there is no data to read
# which we store in the MARSHALL_STREAM_END_ERROR object
msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH)
msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
for level, type_, data in ancdata:
if not (level == SOL_SOCKET and type_ == SCM_RIGHTS):
continue
@ -275,7 +276,7 @@ class Unmarshaller:
start_len = len(self._buf)
missing_bytes = pos - (start_len - self._pos)
if self._sock is None:
data = self._stream_reader(missing_bytes)
data = self._stream_reader(missing_bytes) # type: ignore[misc]
else:
data = self._read_sock(missing_bytes)
if data == b"":
@ -286,46 +287,46 @@ class Unmarshaller:
if len(data) + start_len != pos:
raise MARSHALL_STREAM_END_ERROR
def read_uint32_unpack(self, type_) -> int:
def read_uint32_unpack(self, type_: _SignatureType) -> int:
return self._read_uint32_unpack()
def _read_uint32_unpack(self) -> int:
self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align
if self._is_native and cython.compiled:
return _cast_uint32_native( # pragma: no cover
return _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - UINT32_SIZE
)
return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0]
return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc]
def read_uint16_unpack(self, type_) -> int:
def read_uint16_unpack(self, type_: _SignatureType) -> int:
return self._read_uint16_unpack()
def _read_uint16_unpack(self) -> int:
self._pos += UINT16_SIZE + (-self._pos & (UINT16_SIZE - 1)) # align
if self._is_native and cython.compiled:
return _cast_uint16_native( # pragma: no cover
return _cast_uint16_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - UINT16_SIZE
)
return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0]
return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0] # type: ignore[misc]
def read_int16_unpack(self, type_) -> int:
def read_int16_unpack(self, type_: _SignatureType) -> int:
return self._read_int16_unpack()
def _read_int16_unpack(self) -> int:
self._pos += INT16_SIZE + (-self._pos & (INT16_SIZE - 1)) # align
if self._is_native and cython.compiled:
return _cast_int16_native( # pragma: no cover
return _cast_int16_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - INT16_SIZE
)
return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0]
return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0] # type: ignore[misc]
def read_boolean(self, type_) -> bool:
def read_boolean(self, type_: _SignatureType) -> bool:
return self._read_boolean()
def _read_boolean(self) -> bool:
return bool(self._read_uint32_unpack())
def read_string_unpack(self, type_) -> str:
def read_string_unpack(self, type_: _SignatureType) -> str:
return self._read_string_unpack()
def _read_string_unpack(self) -> str:
@ -335,13 +336,13 @@ class Unmarshaller:
# read terminating '\0' byte as well (str_length + 1)
if self._is_native and cython.compiled:
self._pos += ( # pragma: no cover
_cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1
_cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1 # type: ignore[name-defined]
)
else:
self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1
self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1 # type: ignore[misc]
return self._buf[str_start : self._pos - 1].decode()
def read_signature(self, type_) -> str:
def read_signature(self, type_: _SignatureType) -> str:
return self._read_signature()
def _read_signature(self) -> str:
@ -351,7 +352,7 @@ class Unmarshaller:
self._pos = o + signature_len + 1
return self._buf[o : o + signature_len].decode()
def read_variant(self, type_) -> Variant:
def read_variant(self, type_: _SignatureType) -> Variant:
return self._read_variant()
def _read_variant(self) -> Variant:
@ -402,33 +403,35 @@ class Unmarshaller:
False,
)
def read_struct(self, type_) -> List[Any]:
def read_struct(self, type_: _SignatureType) -> List[Any]:
self._pos += -self._pos & 7 # align 8
readers = self._readers
return [
readers[child_type.token](self, child_type) for child_type in type_.children
]
def read_dict_entry(self, type_) -> Tuple[Any, Any]:
def read_dict_entry(self, type_: _SignatureType) -> Tuple[Any, Any]:
self._pos += -self._pos & 7 # align 8
return self._readers[type_.children[0].token](
self, type_.children[0]
), self._readers[type_.children[1].token](self, type_.children[1])
def read_array(self, type_) -> Iterable[Any]:
def read_array(self, type_: _SignatureType) -> Iterable[Any]:
return self._read_array(type_)
def _read_array(self, type_) -> Iterable[Any]:
def _read_array(self, type_: _SignatureType) -> Iterable[Any]:
self._pos += -self._pos & 3 # align 4 for the array
self._pos += (
-self._pos & (UINT32_SIZE - 1)
) + UINT32_SIZE # align for the uint32
if self._is_native and cython.compiled:
array_length = _cast_uint32_native( # pragma: no cover
self._buf, self._pos - UINT32_SIZE
array_length = (
_cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - UINT32_SIZE
)
)
else:
array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0]
array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc]
child_type = type_.children[0]
token = child_type.token
@ -442,7 +445,7 @@ class Unmarshaller:
return self._buf[self._pos - array_length : self._pos]
if token == "{":
result_dict = {}
result_dict: Dict[Any, Any] = {}
beginning_pos = self._pos
children = child_type.children
child_0 = children[0]
@ -455,7 +458,7 @@ class Unmarshaller:
if child_0_token in "os" and child_1_token == "v":
while self._pos - beginning_pos < array_length:
self._pos += -self._pos & 7 # align 8
key = self._read_string_unpack()
key: Union[str, int] = self._read_string_unpack()
result_dict[key] = self._read_variant()
elif child_0_token == "q" and child_1_token == "v":
while self._pos - beginning_pos < array_length:
@ -544,9 +547,15 @@ class Unmarshaller:
or (endian == BIG_ENDIAN and SYS_IS_BIG_ENDIAN)
):
self._is_native = 1 # pragma: no cover
self._body_len = _cast_uint32_native(self._buf, 4) # pragma: no cover
self._serial = _cast_uint32_native(self._buf, 8) # pragma: no cover
self._header_len = _cast_uint32_native(self._buf, 12) # pragma: no cover
self._body_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, 4
)
self._serial = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, 8
)
self._header_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, 12
)
elif endian == LITTLE_ENDIAN:
(
self._body_len,
@ -583,7 +592,7 @@ class Unmarshaller:
signature = header_fields.pop("signature", "")
if not self._body_len:
tree = SIGNATURE_TREE_EMPTY
body = []
body: List[Any] = []
elif signature == "s":
tree = SIGNATURE_TREE_S
body = [self._read_string_unpack()]