fix: cleanup typing in marshaller and unmarshaller (#190)
This commit is contained in:
parent
1f42d28c47
commit
830183e188
@ -10,6 +10,10 @@ PACKED_UINT32_ZERO = PACK_UINT32(0)
|
||||
PACKED_BOOL_FALSE = PACK_UINT32(int(0))
|
||||
PACKED_BOOL_TRUE = PACK_UINT32(int(1))
|
||||
|
||||
_int = int
|
||||
_bytes = bytes
|
||||
_str = str
|
||||
|
||||
|
||||
class Marshaller:
|
||||
"""Marshall data for Dbus."""
|
||||
@ -29,10 +33,10 @@ class Marshaller:
|
||||
def _buffer(self) -> bytearray:
|
||||
return self._buf
|
||||
|
||||
def align(self, n):
|
||||
def align(self, n: _int) -> int:
|
||||
return self._align(n)
|
||||
|
||||
def _align(self, n):
|
||||
def _align(self, n: _int) -> _int:
|
||||
offset = n - len(self._buf) % n
|
||||
if offset == 0 or offset == n:
|
||||
return 0
|
||||
@ -51,7 +55,7 @@ class Marshaller:
|
||||
def write_signature(self, signature: str, type_: SignatureType) -> int:
|
||||
return self._write_signature(signature.encode())
|
||||
|
||||
def _write_signature(self, signature_bytes) -> int:
|
||||
def _write_signature(self, signature_bytes: _bytes) -> int:
|
||||
signature_len = len(signature_bytes)
|
||||
buf = self._buf
|
||||
buf.append(signature_len)
|
||||
@ -59,10 +63,10 @@ class Marshaller:
|
||||
buf.append(0)
|
||||
return signature_len + 2
|
||||
|
||||
def write_string(self, value, type_: SignatureType) -> int:
|
||||
def write_string(self, value: _str, type_: SignatureType) -> int:
|
||||
return self._write_string(value)
|
||||
|
||||
def _write_string(self, value) -> int:
|
||||
def _write_string(self, value: _str) -> int:
|
||||
value_bytes = value.encode()
|
||||
value_len = len(value)
|
||||
written = self._align(4) + 4
|
||||
@ -81,7 +85,7 @@ class Marshaller:
|
||||
signature = variant.signature
|
||||
signature_bytes = signature.encode()
|
||||
written = self._write_signature(signature_bytes)
|
||||
written += self._write_single(variant.type, variant.value)
|
||||
written += self._write_single(variant.type, variant.value) # type: ignore[has-type]
|
||||
return written
|
||||
|
||||
def write_array(
|
||||
@ -184,6 +188,7 @@ class Marshaller:
|
||||
raise NotImplementedError(f'type is not implemented yet: "{ex.args}"')
|
||||
except error:
|
||||
self.signature_tree.verify(self.body)
|
||||
raise RuntimeError("should not reach here")
|
||||
|
||||
def _construct_buffer(self) -> bytearray:
|
||||
self._buf.clear()
|
||||
|
||||
@ -3,7 +3,7 @@ import io
|
||||
import socket
|
||||
import sys
|
||||
from struct import Struct
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP
|
||||
from ..errors import InvalidMessageError
|
||||
@ -114,6 +114,7 @@ HEADER_MESSAGE_ARG_NAME = {
|
||||
9: "unix_fds",
|
||||
}
|
||||
|
||||
_SignatureType = SignatureType
|
||||
|
||||
READER_TYPE = Callable[["Unmarshaller", SignatureType], Any]
|
||||
|
||||
@ -210,7 +211,7 @@ class Unmarshaller:
|
||||
self._stream_reader: Optional[Callable] = None
|
||||
if self._sock is None:
|
||||
if isinstance(stream, io.BufferedRWPair) and hasattr(stream, "reader"):
|
||||
self._stream_reader = stream.reader.read
|
||||
self._stream_reader = stream.reader.read # type: ignore[attr-defined]
|
||||
self._stream_reader = stream.read
|
||||
|
||||
def reset(self) -> None:
|
||||
@ -240,7 +241,7 @@ class Unmarshaller:
|
||||
# every time a new message is processed.
|
||||
|
||||
@property
|
||||
def message(self) -> Message:
|
||||
def message(self) -> Optional[Message]:
|
||||
"""Return the message that has been unmarshalled."""
|
||||
return self._message
|
||||
|
||||
@ -249,7 +250,7 @@ class Unmarshaller:
|
||||
from the read itself"""
|
||||
# This will raise BlockingIOError if there is no data to read
|
||||
# which we store in the MARSHALL_STREAM_END_ERROR object
|
||||
msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH)
|
||||
msg, ancdata, _flags, _addr = self._sock.recvmsg(length, UNIX_FDS_CMSG_LENGTH) # type: ignore[union-attr]
|
||||
for level, type_, data in ancdata:
|
||||
if not (level == SOL_SOCKET and type_ == SCM_RIGHTS):
|
||||
continue
|
||||
@ -275,7 +276,7 @@ class Unmarshaller:
|
||||
start_len = len(self._buf)
|
||||
missing_bytes = pos - (start_len - self._pos)
|
||||
if self._sock is None:
|
||||
data = self._stream_reader(missing_bytes)
|
||||
data = self._stream_reader(missing_bytes) # type: ignore[misc]
|
||||
else:
|
||||
data = self._read_sock(missing_bytes)
|
||||
if data == b"":
|
||||
@ -286,46 +287,46 @@ class Unmarshaller:
|
||||
if len(data) + start_len != pos:
|
||||
raise MARSHALL_STREAM_END_ERROR
|
||||
|
||||
def read_uint32_unpack(self, type_) -> int:
|
||||
def read_uint32_unpack(self, type_: _SignatureType) -> int:
|
||||
return self._read_uint32_unpack()
|
||||
|
||||
def _read_uint32_unpack(self) -> int:
|
||||
self._pos += UINT32_SIZE + (-self._pos & (UINT32_SIZE - 1)) # align
|
||||
if self._is_native and cython.compiled:
|
||||
return _cast_uint32_native( # pragma: no cover
|
||||
return _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, self._pos - UINT32_SIZE
|
||||
)
|
||||
return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0]
|
||||
return self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc]
|
||||
|
||||
def read_uint16_unpack(self, type_) -> int:
|
||||
def read_uint16_unpack(self, type_: _SignatureType) -> int:
|
||||
return self._read_uint16_unpack()
|
||||
|
||||
def _read_uint16_unpack(self) -> int:
|
||||
self._pos += UINT16_SIZE + (-self._pos & (UINT16_SIZE - 1)) # align
|
||||
if self._is_native and cython.compiled:
|
||||
return _cast_uint16_native( # pragma: no cover
|
||||
return _cast_uint16_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, self._pos - UINT16_SIZE
|
||||
)
|
||||
return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0]
|
||||
return self._uint16_unpack(self._buf, self._pos - UINT16_SIZE)[0] # type: ignore[misc]
|
||||
|
||||
def read_int16_unpack(self, type_) -> int:
|
||||
def read_int16_unpack(self, type_: _SignatureType) -> int:
|
||||
return self._read_int16_unpack()
|
||||
|
||||
def _read_int16_unpack(self) -> int:
|
||||
self._pos += INT16_SIZE + (-self._pos & (INT16_SIZE - 1)) # align
|
||||
if self._is_native and cython.compiled:
|
||||
return _cast_int16_native( # pragma: no cover
|
||||
return _cast_int16_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, self._pos - INT16_SIZE
|
||||
)
|
||||
return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0]
|
||||
return self._int16_unpack(self._buf, self._pos - INT16_SIZE)[0] # type: ignore[misc]
|
||||
|
||||
def read_boolean(self, type_) -> bool:
|
||||
def read_boolean(self, type_: _SignatureType) -> bool:
|
||||
return self._read_boolean()
|
||||
|
||||
def _read_boolean(self) -> bool:
|
||||
return bool(self._read_uint32_unpack())
|
||||
|
||||
def read_string_unpack(self, type_) -> str:
|
||||
def read_string_unpack(self, type_: _SignatureType) -> str:
|
||||
return self._read_string_unpack()
|
||||
|
||||
def _read_string_unpack(self) -> str:
|
||||
@ -335,13 +336,13 @@ class Unmarshaller:
|
||||
# read terminating '\0' byte as well (str_length + 1)
|
||||
if self._is_native and cython.compiled:
|
||||
self._pos += ( # pragma: no cover
|
||||
_cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1
|
||||
_cast_uint32_native(self._buf, str_start - UINT32_SIZE) + 1 # type: ignore[name-defined]
|
||||
)
|
||||
else:
|
||||
self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1
|
||||
self._pos += self._uint32_unpack(self._buf, str_start - UINT32_SIZE)[0] + 1 # type: ignore[misc]
|
||||
return self._buf[str_start : self._pos - 1].decode()
|
||||
|
||||
def read_signature(self, type_) -> str:
|
||||
def read_signature(self, type_: _SignatureType) -> str:
|
||||
return self._read_signature()
|
||||
|
||||
def _read_signature(self) -> str:
|
||||
@ -351,7 +352,7 @@ class Unmarshaller:
|
||||
self._pos = o + signature_len + 1
|
||||
return self._buf[o : o + signature_len].decode()
|
||||
|
||||
def read_variant(self, type_) -> Variant:
|
||||
def read_variant(self, type_: _SignatureType) -> Variant:
|
||||
return self._read_variant()
|
||||
|
||||
def _read_variant(self) -> Variant:
|
||||
@ -402,33 +403,35 @@ class Unmarshaller:
|
||||
False,
|
||||
)
|
||||
|
||||
def read_struct(self, type_) -> List[Any]:
|
||||
def read_struct(self, type_: _SignatureType) -> List[Any]:
|
||||
self._pos += -self._pos & 7 # align 8
|
||||
readers = self._readers
|
||||
return [
|
||||
readers[child_type.token](self, child_type) for child_type in type_.children
|
||||
]
|
||||
|
||||
def read_dict_entry(self, type_) -> Tuple[Any, Any]:
|
||||
def read_dict_entry(self, type_: _SignatureType) -> Tuple[Any, Any]:
|
||||
self._pos += -self._pos & 7 # align 8
|
||||
return self._readers[type_.children[0].token](
|
||||
self, type_.children[0]
|
||||
), self._readers[type_.children[1].token](self, type_.children[1])
|
||||
|
||||
def read_array(self, type_) -> Iterable[Any]:
|
||||
def read_array(self, type_: _SignatureType) -> Iterable[Any]:
|
||||
return self._read_array(type_)
|
||||
|
||||
def _read_array(self, type_) -> Iterable[Any]:
|
||||
def _read_array(self, type_: _SignatureType) -> Iterable[Any]:
|
||||
self._pos += -self._pos & 3 # align 4 for the array
|
||||
self._pos += (
|
||||
-self._pos & (UINT32_SIZE - 1)
|
||||
) + UINT32_SIZE # align for the uint32
|
||||
if self._is_native and cython.compiled:
|
||||
array_length = _cast_uint32_native( # pragma: no cover
|
||||
self._buf, self._pos - UINT32_SIZE
|
||||
array_length = (
|
||||
_cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, self._pos - UINT32_SIZE
|
||||
)
|
||||
)
|
||||
else:
|
||||
array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0]
|
||||
array_length = self._uint32_unpack(self._buf, self._pos - UINT32_SIZE)[0] # type: ignore[misc]
|
||||
|
||||
child_type = type_.children[0]
|
||||
token = child_type.token
|
||||
@ -442,7 +445,7 @@ class Unmarshaller:
|
||||
return self._buf[self._pos - array_length : self._pos]
|
||||
|
||||
if token == "{":
|
||||
result_dict = {}
|
||||
result_dict: Dict[Any, Any] = {}
|
||||
beginning_pos = self._pos
|
||||
children = child_type.children
|
||||
child_0 = children[0]
|
||||
@ -455,7 +458,7 @@ class Unmarshaller:
|
||||
if child_0_token in "os" and child_1_token == "v":
|
||||
while self._pos - beginning_pos < array_length:
|
||||
self._pos += -self._pos & 7 # align 8
|
||||
key = self._read_string_unpack()
|
||||
key: Union[str, int] = self._read_string_unpack()
|
||||
result_dict[key] = self._read_variant()
|
||||
elif child_0_token == "q" and child_1_token == "v":
|
||||
while self._pos - beginning_pos < array_length:
|
||||
@ -544,9 +547,15 @@ class Unmarshaller:
|
||||
or (endian == BIG_ENDIAN and SYS_IS_BIG_ENDIAN)
|
||||
):
|
||||
self._is_native = 1 # pragma: no cover
|
||||
self._body_len = _cast_uint32_native(self._buf, 4) # pragma: no cover
|
||||
self._serial = _cast_uint32_native(self._buf, 8) # pragma: no cover
|
||||
self._header_len = _cast_uint32_native(self._buf, 12) # pragma: no cover
|
||||
self._body_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, 4
|
||||
)
|
||||
self._serial = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, 8
|
||||
)
|
||||
self._header_len = _cast_uint32_native( # type: ignore[name-defined] # pragma: no cover
|
||||
self._buf, 12
|
||||
)
|
||||
elif endian == LITTLE_ENDIAN:
|
||||
(
|
||||
self._body_len,
|
||||
@ -583,7 +592,7 @@ class Unmarshaller:
|
||||
signature = header_fields.pop("signature", "")
|
||||
if not self._body_len:
|
||||
tree = SIGNATURE_TREE_EMPTY
|
||||
body = []
|
||||
body: List[Any] = []
|
||||
elif signature == "s":
|
||||
tree = SIGNATURE_TREE_S
|
||||
body = [self._read_string_unpack()]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user