fix: clean up address parsing and tests (#244)

This commit is contained in:
J. Nick Koston 2023-09-08 18:33:53 -05:00 committed by GitHub
parent 21f8544d4a
commit 370791da86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 23 deletions

View File

@ -1,5 +1,6 @@
import os import os
import re import re
from typing import Dict, List, Optional, Tuple
from urllib.parse import unquote from urllib.parse import unquote
from ..constants import BusType from ..constants import BusType
@ -8,25 +9,29 @@ from ..errors import InvalidAddressError
invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]") invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]")
def parse_address(address_str): def parse_address(address_str: str) -> List[Tuple[str, Dict[str, str]]]:
addresses = [] """Parse a dbus address string into a list of addresses."""
addresses: List[Tuple[str, Dict[str, str]]] = []
for address in filter(lambda a: a, address_str.split(";")): for address in address_str.split(";"):
if not address:
continue
if address.find(":") == -1: if address.find(":") == -1:
raise InvalidAddressError("address did not contain a transport") raise InvalidAddressError("address did not contain a transport")
transport, opt_string = address.split(":", 1) transport, opt_string = address.split(":", 1)
options = {} options: Dict[str, str] = {}
for kv in filter(lambda s: s, opt_string.split(",")): for kv in opt_string.split(","):
if not kv:
continue
if kv.find("=") == -1: if kv.find("=") == -1:
raise InvalidAddressError("address option did not contain a value") raise InvalidAddressError("address option did not contain a value")
k, v = kv.split("=", 1) k, v = kv.split("=", 1)
if invalid_address_chars_re.search(v): if invalid_address_chars_re.search(v):
raise InvalidAddressError("address contains invalid characters") raise InvalidAddressError("address contains invalid characters")
# XXX the actual unquote rules are simpler than this # XXX the actual unquote rules are simpler than this
v = unquote(v) options[k] = unquote(v)
options[k] = v
addresses.append((transport, options)) addresses.append((transport, options))
@ -38,20 +43,23 @@ def parse_address(address_str):
return addresses return addresses
def get_system_bus_address(): def get_system_bus_address() -> str:
if "DBUS_SYSTEM_BUS_ADDRESS" in os.environ: """Get the system bus address from the environment or return the default."""
return os.environ["DBUS_SYSTEM_BUS_ADDRESS"] return (
else: os.environ.get("DBUS_SYSTEM_BUS_ADDRESS")
return "unix:path=/var/run/dbus/system_bus_socket" or "unix:path=/var/run/dbus/system_bus_socket"
)
display_re = re.compile(r".*:([0-9]+)\.?.*") display_re = re.compile(r".*:([0-9]+)\.?.*")
remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""") remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""")
def get_session_bus_address(): def get_session_bus_address() -> str:
if "DBUS_SESSION_BUS_ADDRESS" in os.environ: """Get the session bus address from the environment or return the default."""
return os.environ["DBUS_SESSION_BUS_ADDRESS"] dbus_session_bus_address = os.environ.get("DBUS_SESSION_BUS_ADDRESS")
if dbus_session_bus_address:
return dbus_session_bus_address
home = os.environ["HOME"] home = os.environ["HOME"]
if "DISPLAY" not in os.environ: if "DISPLAY" not in os.environ:
@ -75,7 +83,7 @@ def get_session_bus_address():
machine_id = f.read().rstrip() machine_id = f.read().rstrip()
dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}" dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}"
dbus_info = None dbus_info: Optional[str] = None
try: try:
with open(dbus_info_file_name) as f: with open(dbus_info_file_name) as f:
dbus_info = f.read().rstrip() dbus_info = f.read().rstrip()
@ -97,10 +105,10 @@ def get_session_bus_address():
raise InvalidAddressError("could not find dbus session bus address") raise InvalidAddressError("could not find dbus session bus address")
def get_bus_address(bus_type): def get_bus_address(bus_type: BusType) -> str:
"""Get the address of the bus specified by the bus type."""
if bus_type == BusType.SESSION: if bus_type == BusType.SESSION:
return get_session_bus_address() return get_session_bus_address()
elif bus_type == BusType.SYSTEM: if bus_type == BusType.SYSTEM:
return get_system_bus_address() return get_system_bus_address()
else: raise Exception(f"got unknown bus type: {bus_type}")
raise Exception("got unknown bus type: {bus_type}")

View File

@ -1,4 +1,16 @@
from dbus_fast._private.address import parse_address import os
from unittest.mock import patch
import pytest
from dbus_fast._private.address import (
get_bus_address,
get_session_bus_address,
get_system_bus_address,
parse_address,
)
from dbus_fast.constants import BusType
from dbus_fast.errors import InvalidAddressError
def test_valid_addresses(): def test_valid_addresses():
@ -21,7 +33,42 @@ def test_valid_addresses():
"tcp:host=127.0.0.1,port=55556": [ "tcp:host=127.0.0.1,port=55556": [
("tcp", {"host": "127.0.0.1", "port": "55556"}) ("tcp", {"host": "127.0.0.1", "port": "55556"})
], ],
"unix:tmpdir=/tmp,;": [("unix", {"tmpdir": "/tmp"})],
} }
for address, parsed in valid_addresses.items(): for address, parsed in valid_addresses.items():
assert parse_address(address) == parsed assert parse_address(address) == parsed
def test_invalid_addresses():
with pytest.raises(InvalidAddressError):
assert parse_address("")
with pytest.raises(InvalidAddressError):
assert parse_address("unix")
with pytest.raises(InvalidAddressError):
assert parse_address("unix:tmpdir")
with pytest.raises(InvalidAddressError):
assert parse_address("unix:tmpdir=😁")
def test_get_system_bus_address():
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS="unix:path=/dog"):
assert get_system_bus_address() == "unix:path=/dog"
assert get_bus_address(BusType.SYSTEM) == "unix:path=/dog"
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS=""):
assert get_system_bus_address() == "unix:path=/var/run/dbus/system_bus_socket"
def test_get_session_bus_address():
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="unix:path=/dog"):
assert get_session_bus_address() == "unix:path=/dog"
assert get_bus_address(BusType.SESSION) == "unix:path=/dog"
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="", DISPLAY=""), pytest.raises(
InvalidAddressError
):
assert get_session_bus_address()
def test_invalid_bus_address():
with pytest.raises(Exception):
assert get_bus_address(-1)

View File

@ -16,8 +16,13 @@ async def test_tcp_connection_with_forwarding(event_loop):
addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS")) addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS"))
assert addr_info assert addr_info
assert "abstract" in addr_info[0][1]
path = f'\0{addr_info[0][1]["abstract"]}' addr_zero_options = addr_info[0][1]
if "abstract" in addr_zero_options:
path = f'\0{addr_zero_options["abstract"]}'
else:
path = addr_zero_options["path"]
async def handle_connection(tcp_reader, tcp_writer): async def handle_connection(tcp_reader, tcp_writer):
unix_reader, unix_writer = await asyncio.open_unix_connection(path) unix_reader, unix_writer = await asyncio.open_unix_connection(path)