fix: incorrect pxd typing for for _marshall (#75)

This commit is contained in:
J. Nick Koston
2022-10-05 14:39:38 -10:00
committed by GitHub
parent 23903c3b9b
commit cf1f0129ba
9 changed files with 46 additions and 27 deletions

View File

@@ -11,7 +11,7 @@ message = Message(
def marhsall_bluez_get_managed_objects_message(): def marhsall_bluez_get_managed_objects_message():
message._marshall() message._marshall(False)
count = 1000000 count = 1000000

17
poetry.lock generated
View File

@@ -357,6 +357,17 @@ pytest = ">=4.6"
[package.extras] [package.extras]
testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"]
[[package]]
name = "pytest-timeout"
version = "2.1.0"
description = "pytest plugin to abort hanging tests"
category = "dev"
optional = false
python-versions = ">=3.6"
[package.dependencies]
pytest = ">=5.0.0"
[[package]] [[package]]
name = "pytz" name = "pytz"
version = "2022.4" version = "2022.4"
@@ -596,7 +607,7 @@ docs = ["myst-parser", "Sphinx", "sphinx-rtd-theme", "sphinxcontrib-asyncio", "s
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.7" python-versions = "^3.7"
content-hash = "96d521d1e66777febd43aad81ec451dbf5a15873e1434052283b6ecdf3095c07" content-hash = "381552380ec2e3115cbc867021d088eeac5b7e7a494070586fa3bc40ad01f6a3"
[metadata.files] [metadata.files]
alabaster = [ alabaster = [
@@ -849,6 +860,10 @@ pytest-cov = [
{file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"},
{file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"},
] ]
pytest-timeout = [
{file = "pytest-timeout-2.1.0.tar.gz", hash = "sha256:c07ca07404c612f8abbe22294b23c368e2e5104b521c1790195561f37e1ac3d9"},
{file = "pytest_timeout-2.1.0-py3-none-any.whl", hash = "sha256:f6f50101443ce70ad325ceb4473c4255e9d74e3c7cd0ef827309dfa4c0d975c6"},
]
pytz = [ pytz = [
{file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"}, {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"},
{file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"}, {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"},

View File

@@ -54,6 +54,7 @@ pycairo = "^1.21.0"
PyGObject = "^3.42.2" PyGObject = "^3.42.2"
Cython = "^0.29.32" Cython = "^0.29.32"
setuptools = "^65.4.1" setuptools = "^65.4.1"
pytest-timeout = "^2.1.0"
[tool.semantic_release] [tool.semantic_release]
branch = "main" branch = "main"

View File

@@ -95,7 +95,7 @@ class _MessageWriter:
def buffer_message(self, msg: Message, future=None) -> None: def buffer_message(self, msg: Message, future=None) -> None:
self.messages.append( self.messages.append(
( (
msg._marshall(negotiate_unix_fd=self.negotiate_unix_fd), msg._marshall(self.negotiate_unix_fd),
copy(msg.unix_fds), copy(msg.unix_fds),
future, future,
) )
@@ -216,7 +216,7 @@ class MessageBus(BaseMessageBus):
) )
self._method_return_handlers[hello_msg.serial] = on_hello self._method_return_handlers[hello_msg.serial] = on_hello
self._stream.write(hello_msg._marshall()) self._stream.write(hello_msg._marshall(False))
self._stream.flush() self._stream.flush()
return await future return await future

View File

@@ -98,7 +98,7 @@ class _MessageWritableSource(_GLibSource):
return GLib.SOURCE_REMOVE return GLib.SOURCE_REMOVE
else: else:
message = self.bus._buffered_messages.pop(0) message = self.bus._buffered_messages.pop(0)
self.message_stream = io.BytesIO(message._marshall()) self.message_stream = io.BytesIO(message._marshall(False))
return GLib.SOURCE_CONTINUE return GLib.SOURCE_CONTINUE
except BlockingIOError: except BlockingIOError:
return GLib.SOURCE_CONTINUE return GLib.SOURCE_CONTINUE
@@ -233,7 +233,7 @@ class MessageBus(BaseMessageBus):
) )
self._method_return_handlers[hello_msg.serial] = on_hello self._method_return_handlers[hello_msg.serial] = on_hello
self._stream.write(hello_msg._marshall()) self._stream.write(hello_msg._marshall(False))
self._stream.flush() self._stream.flush()
self._authenticate(authenticate_notify) self._authenticate(authenticate_notify)

View File

@@ -20,4 +20,4 @@ cdef class Message:
cdef public list body cdef public list body
cdef public unsigned int serial cdef public unsigned int serial
cpdef _marshall(self, negotiate_unix_fd: bint) cpdef _marshall(self, bint negotiate_unix_fd)

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Union from typing import Any, List, Optional, Union
from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField from ._private.constants import LITTLE_ENDIAN, PROTOCOL_VERSION, HeaderField
from ._private.marshaller import Marshaller from ._private.marshaller import Marshaller
@@ -95,17 +95,17 @@ class Message:
def __init__( def __init__(
self, self,
destination: str = None, destination: Optional[str] = None,
path: str = None, path: Optional[str] = None,
interface: str = None, interface: Optional[str] = None,
member: str = None, member: Optional[str] = None,
message_type: MessageType = MessageType.METHOD_CALL, message_type: MessageType = MessageType.METHOD_CALL,
flags: MessageFlag = MessageFlag.NONE, flags: MessageFlag = MessageFlag.NONE,
error_name: str = None, error_name: Optional[Union[str, ErrorType]] = None,
reply_serial: int = None, reply_serial=0,
sender: str = None, sender: Optional[str] = None,
unix_fds: List[int] = [], unix_fds: List[int] = [],
signature: Union[str, SignatureTree] = "", signature: Optional[Union[SignatureTree, str]] = None,
body: List[Any] = [], body: List[Any] = [],
serial: int = 0, serial: int = 0,
validate: bool = True, validate: bool = True,
@@ -119,7 +119,7 @@ class Message:
flags if type(flags) is MessageFlag else MessageFlag(bytes([flags])) flags if type(flags) is MessageFlag else MessageFlag(bytes([flags]))
) )
self.error_name = ( self.error_name = (
error_name if type(error_name) is not ErrorType else error_name.value str(error_name.value) if type(error_name) is ErrorType else error_name
) )
self.reply_serial = reply_serial or 0 self.reply_serial = reply_serial or 0
self.sender = sender self.sender = sender
@@ -128,8 +128,8 @@ class Message:
self.signature = signature.signature self.signature = signature.signature
self.signature_tree = signature self.signature_tree = signature
else: else:
self.signature = signature self.signature = signature or ""
self.signature_tree = get_signature_tree(signature) self.signature_tree = get_signature_tree(signature or "")
self.body = body self.body = body
self.serial = serial or 0 self.serial = serial or 0
@@ -154,7 +154,9 @@ class Message:
raise InvalidMessageError(f"missing required field: {field}") raise InvalidMessageError(f"missing required field: {field}")
@staticmethod @staticmethod
def new_error(msg: "Message", error_name: str, error_text: str) -> "Message": def new_error(
msg: "Message", error_name: Union[str, ErrorType], error_text: str
) -> "Message":
"""A convenience constructor to create an error message in reply to the given message. """A convenience constructor to create an error message in reply to the given message.
:param msg: The message this error is in reply to. :param msg: The message this error is in reply to.
@@ -255,7 +257,8 @@ class Message:
unix_fds=unix_fds, unix_fds=unix_fds,
) )
def _marshall(self, negotiate_unix_fd=False): def _marshall(self, negotiate_unix_fd: bool) -> bytearray:
"""Marshall this message into a byte array."""
# TODO maximum message size is 134217728 (128 MiB) # TODO maximum message size is 134217728 (128 MiB)
body_block = Marshaller(self.signature, self.body) body_block = Marshaller(self.signature, self.body)
body_block.marshall() body_block.marshall()

View File

@@ -135,7 +135,7 @@ def is_member_name_valid(member: str) -> bool:
return True return True
def assert_bus_name_valid(name: str): def assert_bus_name_valid(name: str) -> None:
"""Raise an error if this is not a valid bus name. """Raise an error if this is not a valid bus name.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-bus .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-bus
@@ -150,7 +150,7 @@ def assert_bus_name_valid(name: str):
raise InvalidBusNameError(name) raise InvalidBusNameError(name)
def assert_object_path_valid(path: str): def assert_object_path_valid(path: str) -> None:
"""Raise an error if this is not a valid object path. """Raise an error if this is not a valid object path.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-marshaling-object-path .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-marshaling-object-path
@@ -165,7 +165,7 @@ def assert_object_path_valid(path: str):
raise InvalidObjectPathError(path) raise InvalidObjectPathError(path)
def assert_interface_name_valid(name: str): def assert_interface_name_valid(name: str) -> None:
"""Raise an error if this is not a valid interface name. """Raise an error if this is not a valid interface name.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-interface .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-interface
@@ -180,7 +180,7 @@ def assert_interface_name_valid(name: str):
raise InvalidInterfaceNameError(name) raise InvalidInterfaceNameError(name)
def assert_member_name_valid(member): def assert_member_name_valid(member) -> None:
"""Raise an error if this is not a valid member name. """Raise an error if this is not a valid member name.
.. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-member .. seealso:: https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-names-member

View File

@@ -78,7 +78,7 @@ def test_marshalling_with_table():
body.append(replace_variants(type_, message.body[i])) body.append(replace_variants(type_, message.body[i]))
message.body = body message.body = body
buf = message._marshall() buf = message._marshall(False)
data = bytes.fromhex(item["data"]) data = bytes.fromhex(item["data"])
if buf != data: if buf != data:
@@ -173,6 +173,6 @@ def test_unmarshall_can_resume():
def test_ay_buffer(): def test_ay_buffer():
body = [bytes(10000)] body = [bytes(10000)]
msg = Message(path="/test", member="test", signature="ay", body=body) msg = Message(path="/test", member="test", signature="ay", body=body)
marshalled = msg._marshall() marshalled = msg._marshall(False)
unmarshalled_msg = Unmarshaller(io.BytesIO(marshalled)).unmarshall() unmarshalled_msg = Unmarshaller(io.BytesIO(marshalled)).unmarshall()
assert unmarshalled_msg.body[0] == body[0] assert unmarshalled_msg.body[0] == body[0]