feat: add cython def for unmarshaller read_sock for fd passing (#143)

This commit is contained in:
J. Nick Koston 2022-11-03 22:13:26 +01:00 committed by GitHub
parent a4a77cb13b
commit f438c369bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 8 deletions

View File

@ -5,6 +5,12 @@ import cython
from ..signature import SignatureType
cdef object MAX_UNIX_FDS_SIZE
cdef object ARRAY
cdef object UNIX_FDS_CMSG_LENGTH
cdef object SOL_SOCKET
cdef object SCM_RIGHTS
cdef unsigned int UINT32_SIZE
cdef unsigned int INT16_SIZE
cdef unsigned int UINT16_SIZE
@ -106,7 +112,10 @@ cdef class Unmarshaller:
cpdef reset(self)
cdef bytes _read_sock(self, unsigned long length)
@cython.locals(
msg=cython.bytes,
)
cdef bytes _read_sock(self, object length)
@cython.locals(
start_len=cython.ulong,

View File

@ -12,6 +12,8 @@ from ..signature import SignatureType, Variant, get_signature_tree
from .constants import BIG_ENDIAN, LITTLE_ENDIAN, PROTOCOL_VERSION
MAX_UNIX_FDS = 16
MAX_UNIX_FDS_SIZE = array.array("i").itemsize
UNIX_FDS_CMSG_LENGTH = socket.CMSG_LEN(MAX_UNIX_FDS_SIZE)
UNPACK_SYMBOL = {LITTLE_ENDIAN: "<", BIG_ENDIAN: ">"}
@ -88,6 +90,10 @@ TOKEN_O_AS_INT = ord("o")
TOKEN_S_AS_INT = ord("s")
TOKEN_G_AS_INT = ord("g")
ARRAY = array.array
SOL_SOCKET = socket.SOL_SOCKET
SCM_RIGHTS = socket.SCM_RIGHTS
HEADER_MESSAGE_ARG_NAME = {
1: "path",
2: "interface",
@ -229,22 +235,20 @@ class Unmarshaller:
def _read_sock(self, length: int) -> bytes:
"""reads from the socket, storing any fds sent and handling errors
from the read itself"""
unix_fd_list = array.array("i")
try:
msg, ancdata, *_ = self._sock.recvmsg(
length, socket.CMSG_LEN(MAX_UNIX_FDS * unix_fd_list.itemsize)
msg, ancdata, _flags, _addr = self._sock.recvmsg(
length, UNIX_FDS_CMSG_LENGTH
)
except BlockingIOError:
raise MarshallerStreamEndError()
for level, type_, data in ancdata:
if not (level == socket.SOL_SOCKET and type_ == socket.SCM_RIGHTS):
if not (level == SOL_SOCKET and type_ == SCM_RIGHTS):
continue
unix_fd_list.frombytes(
data[: len(data) - (len(data) % unix_fd_list.itemsize)]
self._unix_fds.extend(
ARRAY("i", data[: len(data) - (len(data) % MAX_UNIX_FDS_SIZE)])
)
self._unix_fds.extend(list(unix_fd_list))
return msg