feat: improve Marshaller performance (#15)
This commit is contained in:
@@ -1,88 +1,29 @@
|
||||
from struct import pack
|
||||
from struct import Struct, error, pack
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from ..signature import SignatureTree
|
||||
from ..signature import SignatureTree, SignatureType, Variant
|
||||
|
||||
PACK_UINT32 = Struct("<I").pack
|
||||
|
||||
|
||||
class Marshaller:
|
||||
def __init__(self, signature, body):
|
||||
def __init__(self, signature: str, body: Any) -> None:
|
||||
self.signature_tree = SignatureTree._get(signature)
|
||||
self.signature_tree.verify(body)
|
||||
self.buffer = bytearray()
|
||||
self.body = body
|
||||
|
||||
self.writers = {
|
||||
"y": self.write_byte,
|
||||
"b": self.write_boolean,
|
||||
"n": self.write_int16,
|
||||
"q": self.write_uint16,
|
||||
"i": self.write_int32,
|
||||
"u": self.write_uint32,
|
||||
"x": self.write_int64,
|
||||
"t": self.write_uint64,
|
||||
"d": self.write_double,
|
||||
"h": self.write_uint32,
|
||||
"o": self.write_string,
|
||||
"s": self.write_string,
|
||||
"g": self.write_signature,
|
||||
"a": self.write_array,
|
||||
"(": self.write_struct,
|
||||
"{": self.write_dict_entry,
|
||||
"v": self.write_variant,
|
||||
}
|
||||
|
||||
def align(self, n):
|
||||
def align(self, n) -> int:
|
||||
offset = n - len(self.buffer) % n
|
||||
if offset == 0 or offset == n:
|
||||
return 0
|
||||
self.buffer.extend(bytes(offset))
|
||||
return offset
|
||||
|
||||
def write_byte(self, byte, _=None):
|
||||
self.buffer.append(byte)
|
||||
return 1
|
||||
def write_boolean(self, boolean: bool, _=None) -> int:
|
||||
self.buffer.extend(PACK_UINT32(int(boolean)))
|
||||
return self.align(4) + 4
|
||||
|
||||
def write_boolean(self, boolean, _=None):
|
||||
if boolean:
|
||||
return self.write_uint32(1)
|
||||
else:
|
||||
return self.write_uint32(0)
|
||||
|
||||
def write_int16(self, int16, _=None):
|
||||
written = self.align(2)
|
||||
self.buffer.extend(pack("<h", int16))
|
||||
return written + 2
|
||||
|
||||
def write_uint16(self, uint16, _=None):
|
||||
written = self.align(2)
|
||||
self.buffer.extend(pack("<H", uint16))
|
||||
return written + 2
|
||||
|
||||
def write_int32(self, int32, _):
|
||||
written = self.align(4)
|
||||
self.buffer.extend(pack("<i", int32))
|
||||
return written + 4
|
||||
|
||||
def write_uint32(self, uint32, _=None):
|
||||
written = self.align(4)
|
||||
self.buffer.extend(pack("<I", uint32))
|
||||
return written + 4
|
||||
|
||||
def write_int64(self, int64, _=None):
|
||||
written = self.align(8)
|
||||
self.buffer.extend(pack("<q", int64))
|
||||
return written + 8
|
||||
|
||||
def write_uint64(self, uint64, _=None):
|
||||
written = self.align(8)
|
||||
self.buffer.extend(pack("<Q", uint64))
|
||||
return written + 8
|
||||
|
||||
def write_double(self, double, _=None):
|
||||
written = self.align(8)
|
||||
self.buffer.extend(pack("<d", double))
|
||||
return written + 8
|
||||
|
||||
def write_signature(self, signature, _=None):
|
||||
def write_signature(self, signature: str, _=None) -> int:
|
||||
signature = signature.encode()
|
||||
signature_len = len(signature)
|
||||
self.buffer.append(signature_len)
|
||||
@@ -90,27 +31,29 @@ class Marshaller:
|
||||
self.buffer.append(0)
|
||||
return signature_len + 2
|
||||
|
||||
def write_string(self, value, _=None):
|
||||
def write_string(self, value: str, _=None) -> int:
|
||||
value = value.encode()
|
||||
value_len = len(value)
|
||||
written = self.write_uint32(value_len)
|
||||
written = self.align(4) + 4
|
||||
self.buffer.extend(PACK_UINT32(value_len))
|
||||
self.buffer.extend(value)
|
||||
written += value_len
|
||||
self.buffer.append(0)
|
||||
written += 1
|
||||
return written
|
||||
|
||||
def write_variant(self, variant, _=None):
|
||||
def write_variant(self, variant: Variant, _=None) -> int:
|
||||
written = self.write_signature(variant.signature)
|
||||
written += self.write_single(variant.type, variant.value)
|
||||
return written
|
||||
|
||||
def write_array(self, array, type_):
|
||||
def write_array(self, array: Any, type_: SignatureType) -> int:
|
||||
# TODO max array size is 64MiB (67108864 bytes)
|
||||
written = self.align(4)
|
||||
# length placeholder
|
||||
offset = len(self.buffer)
|
||||
written += self.write_uint32(0)
|
||||
written += self.align(4) + 4
|
||||
self.buffer.extend(PACK_UINT32(0))
|
||||
child_type = type_.children[0]
|
||||
|
||||
if child_type.token in "xtd{(":
|
||||
@@ -128,34 +71,87 @@ class Marshaller:
|
||||
for value in array:
|
||||
array_len += self.write_single(child_type, value)
|
||||
|
||||
array_len_packed = pack("<I", array_len)
|
||||
array_len_packed = PACK_UINT32(array_len)
|
||||
for i in range(offset, offset + 4):
|
||||
self.buffer[i] = array_len_packed[i - offset]
|
||||
|
||||
return written + array_len
|
||||
|
||||
def write_struct(self, array, type_):
|
||||
def write_struct(self, array: List[Any], type_: SignatureType) -> int:
|
||||
written = self.align(8)
|
||||
for i, value in enumerate(array):
|
||||
written += self.write_single(type_.children[i], value)
|
||||
return written
|
||||
|
||||
def write_dict_entry(self, dict_entry, type_):
|
||||
def write_dict_entry(self, dict_entry: List[Any], type_: SignatureType) -> int:
|
||||
written = self.align(8)
|
||||
written += self.write_single(type_.children[0], dict_entry[0])
|
||||
written += self.write_single(type_.children[1], dict_entry[1])
|
||||
return written
|
||||
|
||||
def write_single(self, type_, body):
|
||||
def write_single(self, type_: SignatureType, body: Any) -> int:
|
||||
t = type_.token
|
||||
|
||||
if t not in self.writers:
|
||||
raise NotImplementedError(f'type isnt implemented yet: "{t}"')
|
||||
if t not in self._writers:
|
||||
raise NotImplementedError(f'type is not implemented yet: "{t}"')
|
||||
|
||||
return self.writers[t](body, type_)
|
||||
writer, packer, size = self._writers[t]
|
||||
if packer and size:
|
||||
written = self.align(size)
|
||||
self.buffer.extend(packer(body))
|
||||
return written + size
|
||||
return writer(self, body, type_)
|
||||
|
||||
def marshall(self):
|
||||
"""Marshalls the body into a byte array"""
|
||||
try:
|
||||
self._construct_buffer()
|
||||
except error:
|
||||
self.signature_tree.verify(self.body)
|
||||
return self.buffer
|
||||
|
||||
def _construct_buffer(self):
|
||||
self.buffer.clear()
|
||||
for i, type_ in enumerate(self.signature_tree.types):
|
||||
self.write_single(type_, self.body[i])
|
||||
return self.buffer
|
||||
t = type_.token
|
||||
if t not in self._writers:
|
||||
raise NotImplementedError(f'type is not implemented yet: "{t}"')
|
||||
|
||||
writer, packer, size = self._writers[t]
|
||||
if packer and size:
|
||||
|
||||
# In-line align
|
||||
offset = size - len(self.buffer) % size
|
||||
if offset != 0 and offset != size:
|
||||
self.buffer.extend(bytes(offset))
|
||||
|
||||
self.buffer.extend(packer(self.body[i]))
|
||||
else:
|
||||
writer(self, self.body[i], type_)
|
||||
|
||||
_writers: Dict[
|
||||
str,
|
||||
Tuple[
|
||||
Optional[Callable[[Any, Any], int]],
|
||||
Optional[Callable[[Any], bytes]],
|
||||
Optional[int],
|
||||
],
|
||||
] = {
|
||||
"y": (None, Struct("<B").pack, 1),
|
||||
"b": (write_boolean, None, None),
|
||||
"n": (None, Struct("<h").pack, 2),
|
||||
"q": (None, Struct("<H").pack, 2),
|
||||
"i": (None, Struct("<i").pack, 4),
|
||||
"u": (None, PACK_UINT32, 4),
|
||||
"x": (None, Struct("<q").pack, 8),
|
||||
"t": (None, Struct("<Q").pack, 8),
|
||||
"d": (None, Struct("<d").pack, 8),
|
||||
"h": (None, Struct("<I").pack, 4),
|
||||
"o": (write_string, None, None),
|
||||
"s": (write_string, None, None),
|
||||
"g": (write_signature, None, None),
|
||||
"a": (write_array, None, None),
|
||||
"(": (write_struct, None, None),
|
||||
"{": (write_dict_entry, None, None),
|
||||
"v": (write_variant, None, None),
|
||||
}
|
||||
|
||||
@@ -19,6 +19,15 @@ REQUIRED_FIELDS = {
|
||||
MessageType.METHOD_RETURN: ("reply_serial",),
|
||||
}
|
||||
|
||||
HEADER_PATH = HeaderField.PATH.value
|
||||
HEADER_INTERFACE = HeaderField.INTERFACE.value
|
||||
HEADER_MEMBER = HeaderField.MEMBER.value
|
||||
HEADER_ERROR_NAME = HeaderField.ERROR_NAME.value
|
||||
HEADER_REPLY_SERIAL = HeaderField.REPLY_SERIAL.value
|
||||
HEADER_DESTINATION = HeaderField.DESTINATION.value
|
||||
HEADER_SIGNATURE = HeaderField.SIGNATURE.value
|
||||
HEADER_UNIX_FDS = HeaderField.UNIX_FDS.value
|
||||
|
||||
|
||||
class Message:
|
||||
"""A class for sending and receiving messages through the
|
||||
@@ -242,27 +251,36 @@ class Message:
|
||||
|
||||
fields = []
|
||||
|
||||
# No verify here since the marshaller will raise an exception if the
|
||||
# Variant is invalid.
|
||||
|
||||
if self.path:
|
||||
fields.append([HeaderField.PATH.value, Variant("o", self.path)])
|
||||
fields.append([HEADER_PATH, Variant("o", self.path, verify=False)])
|
||||
if self.interface:
|
||||
fields.append([HeaderField.INTERFACE.value, Variant("s", self.interface)])
|
||||
fields.append(
|
||||
[HEADER_INTERFACE, Variant("s", self.interface, verify=False)]
|
||||
)
|
||||
if self.member:
|
||||
fields.append([HeaderField.MEMBER.value, Variant("s", self.member)])
|
||||
fields.append([HEADER_MEMBER, Variant("s", self.member, verify=False)])
|
||||
if self.error_name:
|
||||
fields.append([HeaderField.ERROR_NAME.value, Variant("s", self.error_name)])
|
||||
fields.append(
|
||||
[HEADER_ERROR_NAME, Variant("s", self.error_name, verify=False)]
|
||||
)
|
||||
if self.reply_serial:
|
||||
fields.append(
|
||||
[HeaderField.REPLY_SERIAL.value, Variant("u", self.reply_serial)]
|
||||
[HEADER_REPLY_SERIAL, Variant("u", self.reply_serial, verify=False)]
|
||||
)
|
||||
if self.destination:
|
||||
fields.append(
|
||||
[HeaderField.DESTINATION.value, Variant("s", self.destination)]
|
||||
[HEADER_DESTINATION, Variant("s", self.destination, verify=False)]
|
||||
)
|
||||
if self.signature:
|
||||
fields.append([HeaderField.SIGNATURE.value, Variant("g", self.signature)])
|
||||
fields.append(
|
||||
[HEADER_SIGNATURE, Variant("g", self.signature, verify=False)]
|
||||
)
|
||||
if self.unix_fds and negotiate_unix_fd:
|
||||
fields.append(
|
||||
[HeaderField.UNIX_FDS.value, Variant("u", len(self.unix_fds))]
|
||||
[HEADER_UNIX_FDS, Variant("u", len(self.unix_fds), verify=False)]
|
||||
)
|
||||
|
||||
header_body = [
|
||||
|
||||
Reference in New Issue
Block a user