fix: clean up address parsing and tests (#244)
This commit is contained in:
parent
21f8544d4a
commit
370791da86
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from urllib.parse import unquote
|
||||
|
||||
from ..constants import BusType
|
||||
@ -8,25 +9,29 @@ from ..errors import InvalidAddressError
|
||||
invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]")
|
||||
|
||||
|
||||
def parse_address(address_str):
|
||||
addresses = []
|
||||
def parse_address(address_str: str) -> List[Tuple[str, Dict[str, str]]]:
|
||||
"""Parse a dbus address string into a list of addresses."""
|
||||
addresses: List[Tuple[str, Dict[str, str]]] = []
|
||||
|
||||
for address in filter(lambda a: a, address_str.split(";")):
|
||||
for address in address_str.split(";"):
|
||||
if not address:
|
||||
continue
|
||||
if address.find(":") == -1:
|
||||
raise InvalidAddressError("address did not contain a transport")
|
||||
|
||||
transport, opt_string = address.split(":", 1)
|
||||
options = {}
|
||||
options: Dict[str, str] = {}
|
||||
|
||||
for kv in filter(lambda s: s, opt_string.split(",")):
|
||||
for kv in opt_string.split(","):
|
||||
if not kv:
|
||||
continue
|
||||
if kv.find("=") == -1:
|
||||
raise InvalidAddressError("address option did not contain a value")
|
||||
k, v = kv.split("=", 1)
|
||||
if invalid_address_chars_re.search(v):
|
||||
raise InvalidAddressError("address contains invalid characters")
|
||||
# XXX the actual unquote rules are simpler than this
|
||||
v = unquote(v)
|
||||
options[k] = v
|
||||
options[k] = unquote(v)
|
||||
|
||||
addresses.append((transport, options))
|
||||
|
||||
@ -38,20 +43,23 @@ def parse_address(address_str):
|
||||
return addresses
|
||||
|
||||
|
||||
def get_system_bus_address():
|
||||
if "DBUS_SYSTEM_BUS_ADDRESS" in os.environ:
|
||||
return os.environ["DBUS_SYSTEM_BUS_ADDRESS"]
|
||||
else:
|
||||
return "unix:path=/var/run/dbus/system_bus_socket"
|
||||
def get_system_bus_address() -> str:
|
||||
"""Get the system bus address from the environment or return the default."""
|
||||
return (
|
||||
os.environ.get("DBUS_SYSTEM_BUS_ADDRESS")
|
||||
or "unix:path=/var/run/dbus/system_bus_socket"
|
||||
)
|
||||
|
||||
|
||||
display_re = re.compile(r".*:([0-9]+)\.?.*")
|
||||
remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""")
|
||||
|
||||
|
||||
def get_session_bus_address():
|
||||
if "DBUS_SESSION_BUS_ADDRESS" in os.environ:
|
||||
return os.environ["DBUS_SESSION_BUS_ADDRESS"]
|
||||
def get_session_bus_address() -> str:
|
||||
"""Get the session bus address from the environment or return the default."""
|
||||
dbus_session_bus_address = os.environ.get("DBUS_SESSION_BUS_ADDRESS")
|
||||
if dbus_session_bus_address:
|
||||
return dbus_session_bus_address
|
||||
|
||||
home = os.environ["HOME"]
|
||||
if "DISPLAY" not in os.environ:
|
||||
@ -75,7 +83,7 @@ def get_session_bus_address():
|
||||
machine_id = f.read().rstrip()
|
||||
|
||||
dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}"
|
||||
dbus_info = None
|
||||
dbus_info: Optional[str] = None
|
||||
try:
|
||||
with open(dbus_info_file_name) as f:
|
||||
dbus_info = f.read().rstrip()
|
||||
@ -97,10 +105,10 @@ def get_session_bus_address():
|
||||
raise InvalidAddressError("could not find dbus session bus address")
|
||||
|
||||
|
||||
def get_bus_address(bus_type):
|
||||
def get_bus_address(bus_type: BusType) -> str:
|
||||
"""Get the address of the bus specified by the bus type."""
|
||||
if bus_type == BusType.SESSION:
|
||||
return get_session_bus_address()
|
||||
elif bus_type == BusType.SYSTEM:
|
||||
if bus_type == BusType.SYSTEM:
|
||||
return get_system_bus_address()
|
||||
else:
|
||||
raise Exception("got unknown bus type: {bus_type}")
|
||||
raise Exception(f"got unknown bus type: {bus_type}")
|
||||
|
||||
@ -1,4 +1,16 @@
|
||||
from dbus_fast._private.address import parse_address
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dbus_fast._private.address import (
|
||||
get_bus_address,
|
||||
get_session_bus_address,
|
||||
get_system_bus_address,
|
||||
parse_address,
|
||||
)
|
||||
from dbus_fast.constants import BusType
|
||||
from dbus_fast.errors import InvalidAddressError
|
||||
|
||||
|
||||
def test_valid_addresses():
|
||||
@ -21,7 +33,42 @@ def test_valid_addresses():
|
||||
"tcp:host=127.0.0.1,port=55556": [
|
||||
("tcp", {"host": "127.0.0.1", "port": "55556"})
|
||||
],
|
||||
"unix:tmpdir=/tmp,;": [("unix", {"tmpdir": "/tmp"})],
|
||||
}
|
||||
|
||||
for address, parsed in valid_addresses.items():
|
||||
assert parse_address(address) == parsed
|
||||
|
||||
|
||||
def test_invalid_addresses():
|
||||
with pytest.raises(InvalidAddressError):
|
||||
assert parse_address("")
|
||||
with pytest.raises(InvalidAddressError):
|
||||
assert parse_address("unix")
|
||||
with pytest.raises(InvalidAddressError):
|
||||
assert parse_address("unix:tmpdir")
|
||||
with pytest.raises(InvalidAddressError):
|
||||
assert parse_address("unix:tmpdir=😁")
|
||||
|
||||
|
||||
def test_get_system_bus_address():
|
||||
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS="unix:path=/dog"):
|
||||
assert get_system_bus_address() == "unix:path=/dog"
|
||||
assert get_bus_address(BusType.SYSTEM) == "unix:path=/dog"
|
||||
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS=""):
|
||||
assert get_system_bus_address() == "unix:path=/var/run/dbus/system_bus_socket"
|
||||
|
||||
|
||||
def test_get_session_bus_address():
|
||||
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="unix:path=/dog"):
|
||||
assert get_session_bus_address() == "unix:path=/dog"
|
||||
assert get_bus_address(BusType.SESSION) == "unix:path=/dog"
|
||||
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="", DISPLAY=""), pytest.raises(
|
||||
InvalidAddressError
|
||||
):
|
||||
assert get_session_bus_address()
|
||||
|
||||
|
||||
def test_invalid_bus_address():
|
||||
with pytest.raises(Exception):
|
||||
assert get_bus_address(-1)
|
||||
|
||||
@ -16,8 +16,13 @@ async def test_tcp_connection_with_forwarding(event_loop):
|
||||
|
||||
addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS"))
|
||||
assert addr_info
|
||||
assert "abstract" in addr_info[0][1]
|
||||
path = f'\0{addr_info[0][1]["abstract"]}'
|
||||
|
||||
addr_zero_options = addr_info[0][1]
|
||||
|
||||
if "abstract" in addr_zero_options:
|
||||
path = f'\0{addr_zero_options["abstract"]}'
|
||||
else:
|
||||
path = addr_zero_options["path"]
|
||||
|
||||
async def handle_connection(tcp_reader, tcp_writer):
|
||||
unix_reader, unix_writer = await asyncio.open_unix_connection(path)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user