diff --git a/src/dbus_fast/_private/address.py b/src/dbus_fast/_private/address.py index 4f41e41..f6322c1 100644 --- a/src/dbus_fast/_private/address.py +++ b/src/dbus_fast/_private/address.py @@ -1,5 +1,6 @@ import os import re +from typing import Dict, List, Optional, Tuple from urllib.parse import unquote from ..constants import BusType @@ -8,25 +9,29 @@ from ..errors import InvalidAddressError invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]") -def parse_address(address_str): - addresses = [] +def parse_address(address_str: str) -> List[Tuple[str, Dict[str, str]]]: + """Parse a dbus address string into a list of addresses.""" + addresses: List[Tuple[str, Dict[str, str]]] = [] - for address in filter(lambda a: a, address_str.split(";")): + for address in address_str.split(";"): + if not address: + continue if address.find(":") == -1: raise InvalidAddressError("address did not contain a transport") transport, opt_string = address.split(":", 1) - options = {} + options: Dict[str, str] = {} - for kv in filter(lambda s: s, opt_string.split(",")): + for kv in opt_string.split(","): + if not kv: + continue if kv.find("=") == -1: raise InvalidAddressError("address option did not contain a value") k, v = kv.split("=", 1) if invalid_address_chars_re.search(v): raise InvalidAddressError("address contains invalid characters") # XXX the actual unquote rules are simpler than this - v = unquote(v) - options[k] = v + options[k] = unquote(v) addresses.append((transport, options)) @@ -38,20 +43,23 @@ def parse_address(address_str): return addresses -def get_system_bus_address(): - if "DBUS_SYSTEM_BUS_ADDRESS" in os.environ: - return os.environ["DBUS_SYSTEM_BUS_ADDRESS"] - else: - return "unix:path=/var/run/dbus/system_bus_socket" +def get_system_bus_address() -> str: + """Get the system bus address from the environment or return the default.""" + return ( + os.environ.get("DBUS_SYSTEM_BUS_ADDRESS") + or "unix:path=/var/run/dbus/system_bus_socket" + ) display_re = re.compile(r".*:([0-9]+)\.?.*") remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""") -def get_session_bus_address(): - if "DBUS_SESSION_BUS_ADDRESS" in os.environ: - return os.environ["DBUS_SESSION_BUS_ADDRESS"] +def get_session_bus_address() -> str: + """Get the session bus address from the environment or return the default.""" + dbus_session_bus_address = os.environ.get("DBUS_SESSION_BUS_ADDRESS") + if dbus_session_bus_address: + return dbus_session_bus_address home = os.environ["HOME"] if "DISPLAY" not in os.environ: @@ -75,7 +83,7 @@ def get_session_bus_address(): machine_id = f.read().rstrip() dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}" - dbus_info = None + dbus_info: Optional[str] = None try: with open(dbus_info_file_name) as f: dbus_info = f.read().rstrip() @@ -97,10 +105,10 @@ def get_session_bus_address(): raise InvalidAddressError("could not find dbus session bus address") -def get_bus_address(bus_type): +def get_bus_address(bus_type: BusType) -> str: + """Get the address of the bus specified by the bus type.""" if bus_type == BusType.SESSION: return get_session_bus_address() - elif bus_type == BusType.SYSTEM: + if bus_type == BusType.SYSTEM: return get_system_bus_address() - else: - raise Exception("got unknown bus type: {bus_type}") + raise Exception(f"got unknown bus type: {bus_type}") diff --git a/tests/test_address_parser.py b/tests/test_address_parser.py index 848adc2..bf999c1 100644 --- a/tests/test_address_parser.py +++ b/tests/test_address_parser.py @@ -1,4 +1,16 @@ -from dbus_fast._private.address import parse_address +import os +from unittest.mock import patch + +import pytest + +from dbus_fast._private.address import ( + get_bus_address, + get_session_bus_address, + get_system_bus_address, + parse_address, +) +from dbus_fast.constants import BusType +from dbus_fast.errors import InvalidAddressError def test_valid_addresses(): @@ -21,7 +33,42 @@ def test_valid_addresses(): "tcp:host=127.0.0.1,port=55556": [ ("tcp", {"host": "127.0.0.1", "port": "55556"}) ], + "unix:tmpdir=/tmp,;": [("unix", {"tmpdir": "/tmp"})], } for address, parsed in valid_addresses.items(): assert parse_address(address) == parsed + + +def test_invalid_addresses(): + with pytest.raises(InvalidAddressError): + assert parse_address("") + with pytest.raises(InvalidAddressError): + assert parse_address("unix") + with pytest.raises(InvalidAddressError): + assert parse_address("unix:tmpdir") + with pytest.raises(InvalidAddressError): + assert parse_address("unix:tmpdir=😁") + + +def test_get_system_bus_address(): + with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS="unix:path=/dog"): + assert get_system_bus_address() == "unix:path=/dog" + assert get_bus_address(BusType.SYSTEM) == "unix:path=/dog" + with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS=""): + assert get_system_bus_address() == "unix:path=/var/run/dbus/system_bus_socket" + + +def test_get_session_bus_address(): + with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="unix:path=/dog"): + assert get_session_bus_address() == "unix:path=/dog" + assert get_bus_address(BusType.SESSION) == "unix:path=/dog" + with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="", DISPLAY=""), pytest.raises( + InvalidAddressError + ): + assert get_session_bus_address() + + +def test_invalid_bus_address(): + with pytest.raises(Exception): + assert get_bus_address(-1) diff --git a/tests/test_tcp_address.py b/tests/test_tcp_address.py index b6f4c78..5c20f09 100644 --- a/tests/test_tcp_address.py +++ b/tests/test_tcp_address.py @@ -16,8 +16,13 @@ async def test_tcp_connection_with_forwarding(event_loop): addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS")) assert addr_info - assert "abstract" in addr_info[0][1] - path = f'\0{addr_info[0][1]["abstract"]}' + + addr_zero_options = addr_info[0][1] + + if "abstract" in addr_zero_options: + path = f'\0{addr_zero_options["abstract"]}' + else: + path = addr_zero_options["path"] async def handle_connection(tcp_reader, tcp_writer): unix_reader, unix_writer = await asyncio.open_unix_connection(path)