fix: clean up address parsing and tests (#244)
This commit is contained in:
parent
21f8544d4a
commit
370791da86
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from ..constants import BusType
|
from ..constants import BusType
|
||||||
@ -8,25 +9,29 @@ from ..errors import InvalidAddressError
|
|||||||
invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]")
|
invalid_address_chars_re = re.compile(r"[^-0-9A-Za-z_/.%]")
|
||||||
|
|
||||||
|
|
||||||
def parse_address(address_str):
|
def parse_address(address_str: str) -> List[Tuple[str, Dict[str, str]]]:
|
||||||
addresses = []
|
"""Parse a dbus address string into a list of addresses."""
|
||||||
|
addresses: List[Tuple[str, Dict[str, str]]] = []
|
||||||
|
|
||||||
for address in filter(lambda a: a, address_str.split(";")):
|
for address in address_str.split(";"):
|
||||||
|
if not address:
|
||||||
|
continue
|
||||||
if address.find(":") == -1:
|
if address.find(":") == -1:
|
||||||
raise InvalidAddressError("address did not contain a transport")
|
raise InvalidAddressError("address did not contain a transport")
|
||||||
|
|
||||||
transport, opt_string = address.split(":", 1)
|
transport, opt_string = address.split(":", 1)
|
||||||
options = {}
|
options: Dict[str, str] = {}
|
||||||
|
|
||||||
for kv in filter(lambda s: s, opt_string.split(",")):
|
for kv in opt_string.split(","):
|
||||||
|
if not kv:
|
||||||
|
continue
|
||||||
if kv.find("=") == -1:
|
if kv.find("=") == -1:
|
||||||
raise InvalidAddressError("address option did not contain a value")
|
raise InvalidAddressError("address option did not contain a value")
|
||||||
k, v = kv.split("=", 1)
|
k, v = kv.split("=", 1)
|
||||||
if invalid_address_chars_re.search(v):
|
if invalid_address_chars_re.search(v):
|
||||||
raise InvalidAddressError("address contains invalid characters")
|
raise InvalidAddressError("address contains invalid characters")
|
||||||
# XXX the actual unquote rules are simpler than this
|
# XXX the actual unquote rules are simpler than this
|
||||||
v = unquote(v)
|
options[k] = unquote(v)
|
||||||
options[k] = v
|
|
||||||
|
|
||||||
addresses.append((transport, options))
|
addresses.append((transport, options))
|
||||||
|
|
||||||
@ -38,20 +43,23 @@ def parse_address(address_str):
|
|||||||
return addresses
|
return addresses
|
||||||
|
|
||||||
|
|
||||||
def get_system_bus_address():
|
def get_system_bus_address() -> str:
|
||||||
if "DBUS_SYSTEM_BUS_ADDRESS" in os.environ:
|
"""Get the system bus address from the environment or return the default."""
|
||||||
return os.environ["DBUS_SYSTEM_BUS_ADDRESS"]
|
return (
|
||||||
else:
|
os.environ.get("DBUS_SYSTEM_BUS_ADDRESS")
|
||||||
return "unix:path=/var/run/dbus/system_bus_socket"
|
or "unix:path=/var/run/dbus/system_bus_socket"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
display_re = re.compile(r".*:([0-9]+)\.?.*")
|
display_re = re.compile(r".*:([0-9]+)\.?.*")
|
||||||
remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""")
|
remove_quotes_re = re.compile(r"""^['"]?(.*?)['"]?$""")
|
||||||
|
|
||||||
|
|
||||||
def get_session_bus_address():
|
def get_session_bus_address() -> str:
|
||||||
if "DBUS_SESSION_BUS_ADDRESS" in os.environ:
|
"""Get the session bus address from the environment or return the default."""
|
||||||
return os.environ["DBUS_SESSION_BUS_ADDRESS"]
|
dbus_session_bus_address = os.environ.get("DBUS_SESSION_BUS_ADDRESS")
|
||||||
|
if dbus_session_bus_address:
|
||||||
|
return dbus_session_bus_address
|
||||||
|
|
||||||
home = os.environ["HOME"]
|
home = os.environ["HOME"]
|
||||||
if "DISPLAY" not in os.environ:
|
if "DISPLAY" not in os.environ:
|
||||||
@ -75,7 +83,7 @@ def get_session_bus_address():
|
|||||||
machine_id = f.read().rstrip()
|
machine_id = f.read().rstrip()
|
||||||
|
|
||||||
dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}"
|
dbus_info_file_name = f"{home}/.dbus/session-bus/{machine_id}-{display}"
|
||||||
dbus_info = None
|
dbus_info: Optional[str] = None
|
||||||
try:
|
try:
|
||||||
with open(dbus_info_file_name) as f:
|
with open(dbus_info_file_name) as f:
|
||||||
dbus_info = f.read().rstrip()
|
dbus_info = f.read().rstrip()
|
||||||
@ -97,10 +105,10 @@ def get_session_bus_address():
|
|||||||
raise InvalidAddressError("could not find dbus session bus address")
|
raise InvalidAddressError("could not find dbus session bus address")
|
||||||
|
|
||||||
|
|
||||||
def get_bus_address(bus_type):
|
def get_bus_address(bus_type: BusType) -> str:
|
||||||
|
"""Get the address of the bus specified by the bus type."""
|
||||||
if bus_type == BusType.SESSION:
|
if bus_type == BusType.SESSION:
|
||||||
return get_session_bus_address()
|
return get_session_bus_address()
|
||||||
elif bus_type == BusType.SYSTEM:
|
if bus_type == BusType.SYSTEM:
|
||||||
return get_system_bus_address()
|
return get_system_bus_address()
|
||||||
else:
|
raise Exception(f"got unknown bus type: {bus_type}")
|
||||||
raise Exception("got unknown bus type: {bus_type}")
|
|
||||||
|
|||||||
@ -1,4 +1,16 @@
|
|||||||
from dbus_fast._private.address import parse_address
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from dbus_fast._private.address import (
|
||||||
|
get_bus_address,
|
||||||
|
get_session_bus_address,
|
||||||
|
get_system_bus_address,
|
||||||
|
parse_address,
|
||||||
|
)
|
||||||
|
from dbus_fast.constants import BusType
|
||||||
|
from dbus_fast.errors import InvalidAddressError
|
||||||
|
|
||||||
|
|
||||||
def test_valid_addresses():
|
def test_valid_addresses():
|
||||||
@ -21,7 +33,42 @@ def test_valid_addresses():
|
|||||||
"tcp:host=127.0.0.1,port=55556": [
|
"tcp:host=127.0.0.1,port=55556": [
|
||||||
("tcp", {"host": "127.0.0.1", "port": "55556"})
|
("tcp", {"host": "127.0.0.1", "port": "55556"})
|
||||||
],
|
],
|
||||||
|
"unix:tmpdir=/tmp,;": [("unix", {"tmpdir": "/tmp"})],
|
||||||
}
|
}
|
||||||
|
|
||||||
for address, parsed in valid_addresses.items():
|
for address, parsed in valid_addresses.items():
|
||||||
assert parse_address(address) == parsed
|
assert parse_address(address) == parsed
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_addresses():
|
||||||
|
with pytest.raises(InvalidAddressError):
|
||||||
|
assert parse_address("")
|
||||||
|
with pytest.raises(InvalidAddressError):
|
||||||
|
assert parse_address("unix")
|
||||||
|
with pytest.raises(InvalidAddressError):
|
||||||
|
assert parse_address("unix:tmpdir")
|
||||||
|
with pytest.raises(InvalidAddressError):
|
||||||
|
assert parse_address("unix:tmpdir=😁")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_system_bus_address():
|
||||||
|
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS="unix:path=/dog"):
|
||||||
|
assert get_system_bus_address() == "unix:path=/dog"
|
||||||
|
assert get_bus_address(BusType.SYSTEM) == "unix:path=/dog"
|
||||||
|
with patch.dict(os.environ, DBUS_SYSTEM_BUS_ADDRESS=""):
|
||||||
|
assert get_system_bus_address() == "unix:path=/var/run/dbus/system_bus_socket"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_bus_address():
|
||||||
|
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="unix:path=/dog"):
|
||||||
|
assert get_session_bus_address() == "unix:path=/dog"
|
||||||
|
assert get_bus_address(BusType.SESSION) == "unix:path=/dog"
|
||||||
|
with patch.dict(os.environ, DBUS_SESSION_BUS_ADDRESS="", DISPLAY=""), pytest.raises(
|
||||||
|
InvalidAddressError
|
||||||
|
):
|
||||||
|
assert get_session_bus_address()
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_bus_address():
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
assert get_bus_address(-1)
|
||||||
|
|||||||
@ -16,8 +16,13 @@ async def test_tcp_connection_with_forwarding(event_loop):
|
|||||||
|
|
||||||
addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS"))
|
addr_info = parse_address(os.environ.get("DBUS_SESSION_BUS_ADDRESS"))
|
||||||
assert addr_info
|
assert addr_info
|
||||||
assert "abstract" in addr_info[0][1]
|
|
||||||
path = f'\0{addr_info[0][1]["abstract"]}'
|
addr_zero_options = addr_info[0][1]
|
||||||
|
|
||||||
|
if "abstract" in addr_zero_options:
|
||||||
|
path = f'\0{addr_zero_options["abstract"]}'
|
||||||
|
else:
|
||||||
|
path = addr_zero_options["path"]
|
||||||
|
|
||||||
async def handle_connection(tcp_reader, tcp_writer):
|
async def handle_connection(tcp_reader, tcp_writer):
|
||||||
unix_reader, unix_writer = await asyncio.open_unix_connection(path)
|
unix_reader, unix_writer = await asyncio.open_unix_connection(path)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user