fix: cleanup typing in marshaller and unmarshaller (#190)

This commit is contained in:
J. Nick Koston
2022-12-24 09:42:10 -10:00
committed by GitHub
parent 1f42d28c47
commit 830183e188
2 changed files with 53 additions and 39 deletions

View File

@@ -10,6 +10,10 @@ PACKED_UINT32_ZERO = PACK_UINT32(0)
PACKED_BOOL_FALSE = PACK_UINT32(int(0)) PACKED_BOOL_FALSE = PACK_UINT32(int(0))
PACKED_BOOL_TRUE = PACK_UINT32(int(1)) PACKED_BOOL_TRUE = PACK_UINT32(int(1))
_int = int
_bytes = bytes
_str = str
class Marshaller: class Marshaller:
"""Marshall data for Dbus.""" """Marshall data for Dbus."""
@@ -29,10 +33,10 @@ class Marshaller:
def _buffer(self) -> bytearray: def _buffer(self) -> bytearray:
return self._buf return self._buf
def align(self, n): def align(self, n: _int) -> int:
return self._align(n) return self._align(n)
def _align(self, n): def _align(self, n: _int) -> _int:
offset = n - len(self._buf) % n offset = n - len(self._buf) % n
if offset == 0 or offset == n: if offset == 0 or offset == n:
return 0 return 0
@@ -51,7 +55,7 @@ class Marshaller:
def write_signature(self, signature: str, type_: SignatureType) -> int: def write_signature(self, signature: str, type_: SignatureType) -> int:
return self._write_signature(signature.encode()) return self._write_signature(signature.encode())
def _write_signature(self, signature_bytes) -> int: def _write_signature(self, signature_bytes: _bytes) -> int:
signature_len = len(signature_bytes) signature_len = len(signature_bytes)
buf = self._buf buf = self._buf
buf.append(signature_len) buf.append(signature_len)
@@ -59,10 +63,10 @@ class Marshaller:
buf.append(0) buf.append(0)
return signature_len + 2 return signature_len + 2
def write_string(self, value, type_: SignatureType) -> int: def write_string(self, value: _str, type_: SignatureType) -> int:
return self._write_string(value) return self._write_string(value)
def _write_string(self, value) -> int: def _write_string(self, value: _str) -> int:
value_bytes = value.encode() value_bytes = value.encode()
value_len = len(value) value_len = len(value)
written = self._align(4) + 4 written = self._align(4) + 4
@@ -81,7 +85,7 @@ class Marshaller:
signature = variant.signature signature = variant.signature
signature_bytes = signature.encode() signature_bytes = signature.encode()
written = self._write_signature(signature_bytes) written = self._write_signature(signature_bytes)
written += self._write_single(variant.type, variant.value) written += self._write_single(variant.type, variant.value) # type: ignore[has-type]
return written return written
def write_array( def write_array(
@@ -184,6 +188,7 @@ class Marshaller:
raise NotImplementedError(f'type is not implemented yet: "{ex.args}"') raise NotImplementedError(f'type is not implemented yet: "{ex.args}"')
except error: except error:
self.signature_tree.verify(self.body) self.signature_tree.verify(self.body)
raise RuntimeError("should not reach here")
def _construct_buffer(self) -> bytearray: def _construct_buffer(self) -> bytearray:
self._buf.clear() self._buf.clear()

View File

@@ -3,7 +3,7 @@ import io
import socket import socket
import sys import sys
from struct import Struct from struct import Struct
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP
from ..errors import InvalidMessageError from ..errors import InvalidMessageError
@@ -114,6 +114,7 @@ HEADER_MESSAGE_ARG_NAME = {
9: "unix_fds", 9: "unix_fds",
} }
_SignatureType = SignatureType
READER_TYPE = Callable[["Unmarshaller", SignatureType], Any] READER_TYPE = Callable[["Unmarshaller", SignatureType], Any]
@@ -210,7 +211,7 @@ class Unmarshaller:
self._stream_reader: Optional[Callable] = None self._stream_reader: Optional[Callable] = None
if self._sock is None: if self._sock is None:
if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"): if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"):
self._stream_reader = stream.reader.read self._stream_reader = stream.reader.read # type: ignore[attr-defined]
self._stream_reader = stream.read self._stream_reader = stream.read
def reset(self) -> None: def reset(self) -> None:
@@ -240,7 +241,7 @@ class Unmarshaller:
# every time a new message is processed. # every time a new message is processed.
@property @property
def message(self) -> Message: def message(self) -> Optional[Message]:
"""Return the message that has been unmarshalled.""" """Return the message that has been unmarshalled."""
return self._message return self._message
@@ -249,7 +250,7 @@ class Unmarshaller:
from the read itself""" from the read itself"""
# This will raise BlockingIOError if there is no data to read # This will raise BlockingIOError if there is no data to read
# which we store in the MARSHALL_STREAM_END_ERROR object # which we store in the MARSHALL_STREAM_END_ERROR object
msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
for level, type_, data in ancdata: for level, type_, data in ancdata:
if not (level == SOL_SOCKET and type_ == SCM_RIGHTS): if not (level == SOL_SOCKET and type_ == SCM_RIGHTS):
continue continue
@@ -275,7 +276,7 @@ class Unmarshaller:
start_len = len(self._buf) start_len = len(self._buf)
missing_bytes = pos - (start_len - self._pos) missing_bytes = pos - (start_len - self._pos)
if self._sock is None: if self._sock is None:
data = self._stream_reader(missing_bytes) data = self._stream_reader(missing_bytes) # type: ignore[misc]
else: else:
data = self._read_sock(missing_bytes) data = self._read_sock(missing_bytes)
if data == b"": if data == b"":
@@ -286,46 +287,46 @@ class Unmarshaller:
if len(data) + start_len != pos: if len(data) + start_len != pos:
raise MARSHALL_STREAM_END_ERROR raise MARSHALL_STREAM_END_ERROR
def read_uint32_unpack(self, type_) -> int: def read_uint32_unpack(self, type_: _SignatureType) -> int:
return self._read_uint32_unpack() return self._read_uint32_unpack()
def _read_uint32_unpack(self) -> int: def _read_uint32_unpack(self) -> int:
self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align
if self._is_native and cython.compiled: if self._is_native and cython.compiled:
return _cast_uint32_native( # pragma: no cover return _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - UINT32_SIZE self._buf, self._pos - UINT32_SIZE
) )
return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc]
def read_uint16_unpack(self, type_) -> int: def read_uint16_unpack(self, type_: _SignatureType) -> int:
return self._read_uint16_unpack() return self._read_uint16_unpack()
def _read_uint16_unpack(self) -> int: def _read_uint16_unpack(self) -> int:
self._pos += UINT16_SIZE + (-self._pos & (UINT16_SIZE - 1)) # align self._pos += UINT16_SIZE + (-self._pos & (UINT16_SIZE - 1)) # align
if self._is_native and cython.compiled: if self._is_native and cython.compiled:
return _cast_uint16_native( # pragma: no cover return _cast_uint16_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - UINT16_SIZE self._buf, self._pos - UINT16_SIZE
) )
return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0] return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0] # type: ignore[misc]
def read_int16_unpack(self, type_) -> int: def read_int16_unpack(self, type_: _SignatureType) -> int:
return self._read_int16_unpack() return self._read_int16_unpack()
def _read_int16_unpack(self) -> int: def _read_int16_unpack(self) -> int:
self._pos += INT16_SIZE + (-self._pos & (INT16_SIZE - 1)) # align self._pos += INT16_SIZE + (-self._pos & (INT16_SIZE - 1)) # align
if self._is_native and cython.compiled: if self._is_native and cython.compiled:
return _cast_int16_native( # pragma: no cover return _cast_int16_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - INT16_SIZE self._buf, self._pos - INT16_SIZE
) )
return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0] return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0] # type: ignore[misc]
def read_boolean(self, type_) -> bool: def read_boolean(self, type_: _SignatureType) -> bool:
return self._read_boolean() return self._read_boolean()
def _read_boolean(self) -> bool: def _read_boolean(self) -> bool:
return bool(self._read_uint32_unpack()) return bool(self._read_uint32_unpack())
def read_string_unpack(self, type_) -> str: def read_string_unpack(self, type_: _SignatureType) -> str:
return self._read_string_unpack() return self._read_string_unpack()
def _read_string_unpack(self) -> str: def _read_string_unpack(self) -> str:
@@ -335,13 +336,13 @@ class Unmarshaller:
# read terminating '\0' byte as well (str_length + 1) # read terminating '\0' byte as well (str_length + 1)
if self._is_native and cython.compiled: if self._is_native and cython.compiled:
self._pos += ( # pragma: no cover self._pos += ( # pragma: no cover
_cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1 _cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1 # type: ignore[name-defined]
) )
else: else:
self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1 self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1 # type: ignore[misc]
return self._buf[str_start : self._pos - 1].decode() return self._buf[str_start : self._pos - 1].decode()
def read_signature(self, type_) -> str: def read_signature(self, type_: _SignatureType) -> str:
return self._read_signature() return self._read_signature()
def _read_signature(self) -> str: def _read_signature(self) -> str:
@@ -351,7 +352,7 @@ class Unmarshaller:
self._pos = o + signature_len + 1 self._pos = o + signature_len + 1
return self._buf[o : o + signature_len].decode() return self._buf[o : o + signature_len].decode()
def read_variant(self, type_) -> Variant: def read_variant(self, type_: _SignatureType) -> Variant:
return self._read_variant() return self._read_variant()
def _read_variant(self) -> Variant: def _read_variant(self) -> Variant:
@@ -402,33 +403,35 @@ class Unmarshaller:
False, False,
) )
def read_struct(self, type_) -> List[Any]: def read_struct(self, type_: _SignatureType) -> List[Any]:
self._pos += -self._pos & 7 # align 8 self._pos += -self._pos & 7 # align 8
readers = self._readers readers = self._readers
return [ return [
readers[child_type.token](self, child_type) for child_type in type_.children readers[child_type.token](self, child_type) for child_type in type_.children
] ]
def read_dict_entry(self, type_) -> Tuple[Any, Any]: def read_dict_entry(self, type_: _SignatureType) -> Tuple[Any, Any]:
self._pos += -self._pos & 7 # align 8 self._pos += -self._pos & 7 # align 8
return self._readers[type_.children[0].token]( return self._readers[type_.children[0].token](
self, type_.children[0] self, type_.children[0]
), self._readers[type_.children[1].token](self, type_.children[1]) ), self._readers[type_.children[1].token](self, type_.children[1])
def read_array(self, type_) -> Iterable[Any]: def read_array(self, type_: _SignatureType) -> Iterable[Any]:
return self._read_array(type_) return self._read_array(type_)
def _read_array(self, type_) -> Iterable[Any]: def _read_array(self, type_: _SignatureType) -> Iterable[Any]:
self._pos += -self._pos & 3 # align 4 for the array self._pos += -self._pos & 3 # align 4 for the array
self._pos += ( self._pos += (
-self._pos & (UINT32_SIZE - 1) -self._pos & (UINT32_SIZE - 1)
) + UINT32_SIZE # align for the uint32 ) + UINT32_SIZE # align for the uint32
if self._is_native and cython.compiled: if self._is_native and cython.compiled:
array_length = _cast_uint32_native( # pragma: no cover array_length = (
self._buf, self._pos - UINT32_SIZE _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, self._pos - UINT32_SIZE
)
) )
else: else:
array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc]
child_type = type_.children[0] child_type = type_.children[0]
token = child_type.token token = child_type.token
@@ -442,7 +445,7 @@ class Unmarshaller:
return self._buf[self._pos - array_length : self._pos] return self._buf[self._pos - array_length : self._pos]
if token == "{": if token == "{":
result_dict = {} result_dict: Dict[Any, Any] = {}
beginning_pos = self._pos beginning_pos = self._pos
children = child_type.children children = child_type.children
child_0 = children[0] child_0 = children[0]
@@ -455,7 +458,7 @@ class Unmarshaller:
if child_0_token in "os" and child_1_token == "v": if child_0_token in "os" and child_1_token == "v":
while self._pos - beginning_pos < array_length: while self._pos - beginning_pos < array_length:
self._pos += -self._pos & 7 # align 8 self._pos += -self._pos & 7 # align 8
key = self._read_string_unpack() key: Union[str, int] = self._read_string_unpack()
result_dict[key] = self._read_variant() result_dict[key] = self._read_variant()
elif child_0_token == "q" and child_1_token == "v": elif child_0_token == "q" and child_1_token == "v":
while self._pos - beginning_pos < array_length: while self._pos - beginning_pos < array_length:
@@ -544,9 +547,15 @@ class Unmarshaller:
or (endian == BIG_ENDIAN and SYS_IS_BIG_ENDIAN) or (endian == BIG_ENDIAN and SYS_IS_BIG_ENDIAN)
): ):
self._is_native = 1 # pragma: no cover self._is_native = 1 # pragma: no cover
self._body_len = _cast_uint32_native(self._buf, 4) # pragma: no cover self._body_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._serial = _cast_uint32_native(self._buf, 8) # pragma: no cover self._buf, 4
self._header_len = _cast_uint32_native(self._buf, 12) # pragma: no cover )
self._serial = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, 8
)
self._header_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
self._buf, 12
)
elif endian == LITTLE_ENDIAN: elif endian == LITTLE_ENDIAN:
( (
self._body_len, self._body_len,
@@ -583,7 +592,7 @@ class Unmarshaller:
signature = header_fields.pop("signature", "") signature = header_fields.pop("signature", "")
if not self._body_len: if not self._body_len:
tree = SIGNATURE_TREE_EMPTY tree = SIGNATURE_TREE_EMPTY
body = [] body: List[Any] = []
elif signature == "s": elif signature == "s":
tree = SIGNATURE_TREE_S tree = SIGNATURE_TREE_S
body = [self._read_string_unpack()] body = [self._read_string_unpack()]