diff --git a/src/dbus_fast/_private/unmarshaller.pxd b/src/dbus_fast/_private/unmarshaller.pxd index d3ac5a4..ee96541 100644 --- a/src/dbus_fast/_private/unmarshaller.pxd +++ b/src/dbus_fast/_private/unmarshaller.pxd @@ -5,6 +5,12 @@ import cython from ..signature import SignatureType +cdef object MAX_UNIX_FDS_SIZE +cdef object ARRAY +cdef object UNIX_FDS_CMSG_LENGTH +cdef object SOL_SOCKET +cdef object SCM_RIGHTS + cdef unsigned int UINT32_SIZE cdef unsigned int INT16_SIZE cdef unsigned int UINT16_SIZE @@ -106,7 +112,10 @@ cdef class Unmarshaller: cpdef reset(self) - cdef bytes _read_sock(self, unsigned long length) + @cython.locals( + msg=cython.bytes, + ) + cdef bytes _read_sock(self, object length) @cython.locals( start_len=cython.ulong, diff --git a/src/dbus_fast/_private/unmarshaller.py b/src/dbus_fast/_private/unmarshaller.py index 4fcaa27..39074bf 100644 --- a/src/dbus_fast/_private/unmarshaller.py +++ b/src/dbus_fast/_private/unmarshaller.py @@ -12,6 +12,8 @@ from ..signature import SignatureType, Variant, get_signature_tree from .constants import BIG_ENDIAN, LITTLE_ENDIAN, PROTOCOL_VERSION MAX_UNIX_FDS = 16 +MAX_UNIX_FDS_SIZE = array.array("i").itemsize +UNIX_FDS_CMSG_LENGTH = socket.CMSG_LEN(MAX_UNIX_FDS_SIZE) UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"} @@ -88,6 +90,10 @@ TOKEN_O_AS_INT = ord("o") TOKEN_S_AS_INT = ord("s") TOKEN_G_AS_INT = ord("g") +ARRAY = array.array +SOL_SOCKET = socket.SOL_SOCKET +SCM_RIGHTS = socket.SCM_RIGHTS + HEADER_MESSAGE_ARG_NAME = { 1: "path", 2: "interface", @@ -229,22 +235,20 @@ class Unmarshaller: def _read_sock(self, length: int) -> bytes: """reads from the socket, storing any fds sent and handling errors from the read itself""" - unix_fd_list = array.array("i") try: - msg, ancdata, *_ = self._sock.recvmsg( - length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize) + msg, ancdata, _flags, _addr = self._sock.recvmsg( + length, UNIX_FDS_CMSG_LENGTH ) except BlockingIOError: raise MarshallerStreamEndError() for level, type_, data in ancdata: - if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS): + if not (level == SOL_SOCKET and type_ == SCM_RIGHTS): continue - unix_fd_list.frombytes( - data[: len(data) - (len(data) % unix_fd_list.itemsize)] + self._unix_fds.extend( + ARRAY("i", data[: len(data) - (len(data) % MAX_UNIX_FDS_SIZE)]) ) - self._unix_fds.extend(list(unix_fd_list)) return msg