feat: add more typing to unmarshaller (#102)

This commit is contained in:
J. Nick Koston
2022-10-10 14:24:08 -10:00
committed by GitHub
parent 561bef2c18
commit e7048fa38b
3 changed files with 21 additions and 21 deletions

View File

@@ -3,7 +3,7 @@ import io
import socket import socket
import sys import sys
from struct import Struct from struct import Struct
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP from ..constants import MESSAGE_FLAG_MAP, MESSAGE_TYPE_MAP
from ..errors import InvalidMessageError from ..errors import InvalidMessageError
@@ -89,7 +89,7 @@ def build_simple_parsers(
endian: int, can_cast: bool endian: int, can_cast: bool
) -> Dict[str, Callable[["Unmarshaller", SignatureType], Any]]: ) -> Dict[str, Callable[["Unmarshaller", SignatureType], Any]]:
"""Build a dict of parsers for simple types.""" """Build a dict of parsers for simple types."""
parsers: Dict[str, Callable[["Unmarshaller", SignatureType], Any]] = {} parsers: Dict[str, READER_TYPE] = {}
for dbus_type, ctype_size in DBUS_TO_CTYPE.items(): for dbus_type, ctype_size in DBUS_TO_CTYPE.items():
ctype, size = ctype_size ctype, size = ctype_size
size = ctype_size[1] size = ctype_size[1]
@@ -150,13 +150,13 @@ class Unmarshaller:
"_uint32_unpack", "_uint32_unpack",
) )
def __init__(self, stream: io.BufferedRWPair, sock=None): def __init__(self, stream: io.BufferedRWPair, sock: Optional[socket.socket] = None):
self._unix_fds: List[int] = [] self._unix_fds: List[int] = []
self._buf = bytearray() # Actual buffer self._buf = bytearray() # Actual buffer
self._view = None # Memory view of the buffer self._view: Optional[memoryview] = None # Memory view of the buffer
self._stream = stream self._stream = stream
self._sock = sock self._sock = sock
self._message: Message | None = None self._message: Optional[Message] = None
self._readers: Dict[str, READER_TYPE] = {} self._readers: Dict[str, READER_TYPE] = {}
self._pos = 0 self._pos = 0
self._body_len = 0 self._body_len = 0
@@ -173,7 +173,7 @@ class Unmarshaller:
Call this before processing a new message. Call this before processing a new message.
""" """
self._unix_fds: List[int] = [] self._unix_fds = []
self._view = None self._view = None
self._buf.clear() self._buf.clear()
self._message = None self._message = None
@@ -213,7 +213,7 @@ class Unmarshaller:
return msg return msg
def read_to_pos(self, pos) -> None: def read_to_pos(self, pos: int) -> None:
""" """
Read from underlying socket into buffer. Read from underlying socket into buffer.
@@ -310,13 +310,13 @@ class Unmarshaller:
readers[child_type.token](self, child_type) for child_type in type_.children readers[child_type.token](self, child_type) for child_type in type_.children
] ]
def read_dict_entry(self, type_: SignatureType) -> Dict[Any, Any]: def read_dict_entry(self, type_: SignatureType) -> Tuple[Any, Any]:
self._pos += -self._pos & 7 # align 8 self._pos += -self._pos & 7 # align 8
return self._readers[type_.children[0].token]( return self._readers[type_.children[0].token](
self, type_.children[0] self, type_.children[0]
), self._readers[type_.children[1].token](self, type_.children[1]) ), self._readers[type_.children[1].token](self, type_.children[1])
def read_array(self, type_: SignatureType) -> List[Any]: def read_array(self, type_: SignatureType) -> Iterable[Any]:
self._pos += -self._pos & 3 # align 4 for the array self._pos += -self._pos & 3 # align 4 for the array
self._pos += ( self._pos += (
-self._pos & (UINT32_SIZE - 1) -self._pos & (UINT32_SIZE - 1)
@@ -378,7 +378,7 @@ class Unmarshaller:
result_list.append(reader(self, child_type)) result_list.append(reader(self, child_type))
return result_list return result_list
def header_fields(self, header_length) -> Dict[str, Any]: def header_fields(self, header_length: int) -> Dict[str, Any]:
"""Header fields are always a(yv).""" """Header fields are always a(yv)."""
beginning_pos = self._pos beginning_pos = self._pos
headers = {} headers = {}
@@ -447,7 +447,7 @@ class Unmarshaller:
if not can_cast: if not can_cast:
self._uint32_unpack = UINT32_UNPACK_BY_ENDIAN[endian] self._uint32_unpack = UINT32_UNPACK_BY_ENDIAN[endian]
def _read_body(self): def _read_body(self) -> None:
"""Read the body of the message.""" """Read the body of the message."""
self.read_to_pos(HEADER_SIGNATURE_SIZE + self._msg_len) self.read_to_pos(HEADER_SIGNATURE_SIZE + self._msg_len)
self._view = memoryview(self._buf) self._view = memoryview(self._buf)
@@ -513,19 +513,17 @@ class Unmarshaller:
INT16_DBUS_TYPE: read_int16_cast, INT16_DBUS_TYPE: read_int16_cast,
} }
_ctype_by_endian: Dict[ _ctype_by_endian: Dict[Tuple[int, bool], Dict[str, READER_TYPE]] = {
Tuple[int, bool], Dict[str, Tuple[None, str, int, Callable]]
] = {
endian_can_cast: build_simple_parsers(*endian_can_cast) endian_can_cast: build_simple_parsers(*endian_can_cast)
for endian_can_cast in [ for endian_can_cast in (
(LITTLE_ENDIAN, True), (LITTLE_ENDIAN, True),
(LITTLE_ENDIAN, False), (LITTLE_ENDIAN, False),
(BIG_ENDIAN, True), (BIG_ENDIAN, True),
(BIG_ENDIAN, False), (BIG_ENDIAN, False),
] )
} }
_readers_by_type: Dict[Tuple[int, bool], READER_TYPE] = { _readers_by_type: Dict[Tuple[int, bool], Dict[str, READER_TYPE]] = {
(LITTLE_ENDIAN, True): { (LITTLE_ENDIAN, True): {
**_ctype_by_endian[(LITTLE_ENDIAN, True)], **_ctype_by_endian[(LITTLE_ENDIAN, True)],
**_complex_parsers_cast, **_complex_parsers_cast,

View File

@@ -423,9 +423,10 @@ class MessageBus(BaseMessageBus):
unmarshaller = self._unmarshaller unmarshaller = self._unmarshaller
try: try:
while True: while True:
if unmarshaller.unmarshall(): message = unmarshaller.unmarshall()
if message:
try: try:
self._process_message(unmarshaller.message) self._process_message(message)
except Exception as e: except Exception as e:
logging.error( logging.error(
f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}" f"got unexpected error processing a message: {e}.\n{traceback.format_exc()}"

View File

@@ -49,8 +49,9 @@ class _MessageSource(_GLibSource):
if not self.unmarshaller: if not self.unmarshaller:
self.unmarshaller = Unmarshaller(self.bus._stream) self.unmarshaller = Unmarshaller(self.bus._stream)
if self.unmarshaller.unmarshall(): message = self.unmarshaller.unmarshall()
callback(self.unmarshaller.message) if message:
callback(message)
self.unmarshaller = None self.unmarshaller = None
else: else:
break break