fix: clean up address parsing and tests (#244)

This commit is contained in:
J. Nick Koston 2023-09-08 18:33:53 -05:00 committed by GitHub
parent 21f8544d4a
commit 370791da86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 23 deletions

View File

@ -1,5 +1,6 @@
import os
import re
from typing import Dict, List, Optional, Tuple
from urllib.parse import unquote
from ..constants import BusType
@ -8,25 +9,29 @@ from ..errors import InvalidAddressError
invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]")
def parse_address(address_str):
addresses = []
def parse_address(address_str: str) -> List[Tuple[str, Dict[str, str]]]:
"""Parse a dbus address string into a list of addresses."""
addresses: List[Tuple[str, Dict[str, str]]] = []
for address in filter(lambda a: a, address_str.split(";")):
for address in address_str.split(";"):
if not address:
continue
if address.find(":") == -1:
raise InvalidAddressError("address did not contain a transport")
transport, opt_string = address.split(":", 1)
options = {}
options: Dict[str, str] = {}
for kv in filter(lambda s: s, opt_string.split(",")):
for kv in opt_string.split(","):
if not kv:
continue
if kv.find("=") == -1:
raise InvalidAddressError("address option did not contain a value")
k, v = kv.split("=", 1)
if invalid_address_chars_re.search(v):
raise InvalidAddressError("address contains invalid characters")
# XXX the actual unquote rules are simpler than this
v = unquote(v)
options[k] = v
options[k] = unquote(v)
addresses.append((transport, options))
@ -38,20 +43,23 @@ def parse_address(address_str):
return addresses
def get_system_bus_address():
if "DBUS_SYSTEM_BUS_ADDRESS" in os.environ:
return os.environ["DBUS_SYSTEM_BUS_ADDRESS"]
else:
return "unix:path=/var/run/dbus/system_bus_socket"
def get_system_bus_address() -> str:
"""Get the system bus address from the environment or return the default."""
return (
os.environ.get("DBUS_SYSTEM_BUS_ADDRESS")
or "unix:path=/var/run/dbus/system_bus_socket"
)
display_re = re.compile(r".*:([0-9]+)\.?.*")
remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""")
def get_session_bus_address():
if "DBUS_SESSION_BUS_ADDRESS" in os.environ:
return os.environ["DBUS_SESSION_BUS_ADDRESS"]
def get_session_bus_address() -> str:
"""Get the session bus address from the environment or return the default."""
dbus_session_bus_address = os.environ.get("DBUS_SESSION_BUS_ADDRESS")
if dbus_session_bus_address:
return dbus_session_bus_address
home = os.environ["HOME"]
if "DISPLAY" not in os.environ:
@ -75,7 +83,7 @@ def get_session_bus_address():
machine_id = f.read().rstrip()
dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}"
dbus_info = None
dbus_info: Optional[str] = None
try:
with open(dbus_info_file_name) as f:
dbus_info = f.read().rstrip()
@ -97,10 +105,10 @@ def get_session_bus_address():
raise InvalidAddressError("could not find dbus session bus address")
def get_bus_address(bus_type):
def get_bus_address(bus_type: BusType) -> str:
"""Get the address of the bus specified by the bus type."""
if bus_type == BusType.SESSION:
return get_session_bus_address()
elif bus_type == BusType.SYSTEM:
if bus_type == BusType.SYSTEM:
return get_system_bus_address()
else:
raise Exception("got unknown bus type: {bus_type}")
raise Exception(f"got unknown bus type: {bus_type}")

View File

@ -1,4 +1,16 @@
from dbus_fast._private.address import parse_address
import os
from unittest.mock import patch
import pytest
from dbus_fast._private.address import (
get_bus_address,
get_session_bus_address,
get_system_bus_address,
parse_address,
)
from dbus_fast.constants import BusType
from dbus_fast.errors import InvalidAddressError
def test_valid_addresses():
@ -21,7 +33,42 @@ def test_valid_addresses():
"tcp:host=127.0.0.1,port=55556": [
("tcp", {"host": "127.0.0.1", "port": "55556"})
],
"unix:tmpdir=/tmp,;": [("unix", {"tmpdir": "/tmp"})],
}
for address, parsed in valid_addresses.items():
assert parse_address(address) == parsed
def test_invalid_addresses():
with pytest.raises(InvalidAddressError):
assert parse_address("")
with pytest.raises(InvalidAddressError):
assert parse_address("unix")
with pytest.raises(InvalidAddressError):
assert parse_address("unix:tmpdir")
with pytest.raises(InvalidAddressError):
assert parse_address("unix:tmpdir=😁")
def test_get_system_bus_address():
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS="unix:path=/dog"):
assert get_system_bus_address() == "unix:path=/dog"
assert get_bus_address(BusType.SYSTEM) == "unix:path=/dog"
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS=""):
assert get_system_bus_address() == "unix:path=/var/run/dbus/system_bus_socket"
def test_get_session_bus_address():
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="unix:path=/dog"):
assert get_session_bus_address() == "unix:path=/dog"
assert get_bus_address(BusType.SESSION) == "unix:path=/dog"
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="", DISPLAY=""), pytest.raises(
InvalidAddressError
):
assert get_session_bus_address()
def test_invalid_bus_address():
with pytest.raises(Exception):
assert get_bus_address(-1)

View File

@ -16,8 +16,13 @@ async def test_tcp_connection_with_forwarding(event_loop):
addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS"))
assert addr_info
assert "abstract" in addr_info[0][1]
path = f'\0{addr_info[0][1]["abstract"]}'
addr_zero_options = addr_info[0][1]
if "abstract" in addr_zero_options:
path = f'\0{addr_zero_options["abstract"]}'
else:
path = addr_zero_options["path"]
async def handle_connection(tcp_reader, tcp_writer):
unix_reader, unix_writer = await asyncio.open_unix_connection(path)