reworked agent registration to be more robust

This commit is contained in:
Ezri Brimhall 2025-10-14 11:40:26 -06:00
parent 110ce50a95
commit 92576b044d
Signed by: ezri
GPG Key ID: 058A78E5680C6F24
7 changed files with 314 additions and 84 deletions

View File

@ -8,7 +8,7 @@
<allow own="dev.ezri.vpn1"/>
<allow send_destination="dev.ezri.vpn1"/>
<allow receive_sender="dev.ezri.vpn1"/>
<allow send_path="/dev/ezri/vpn1"/>
<allow send_interface="dev.ezri.vpn1.Agent"/>
</policy>
<policy context="default">

View File

@ -6,6 +6,10 @@ from dbus_fast import Variant, PropertyAccess, BusType, DBusError
from ..common.implementations import AUTH_HANDLERS, AGENT_LISTENERS
from ..common.errors import AuthFlowUnsupported, UnknownTarget
from ..common import introspection, async_signal
from ..common.introspection import (
manager as manager_introspection_data,
dbus as dbus_introspection_data,
)
from ..utils import aobject
from ..common.dbus_types import *
from systemd import daemon
@ -16,12 +20,6 @@ from .notifications import NotificationClient
import atexit
import asyncio
manager_introspection_file = (
import_resources.files(introspection) / "dev.ezri.vpn1.Manager.xml"
)
with manager_introspection_file.open("rt") as f:
manager_introspection_data = f.read()
class AgentListenerInterface(ServiceInterface):
"""Session bus agent interface to receive out-of-band authentication data."""
@ -61,51 +59,46 @@ class AgentInterface(ServiceInterface, aobject):
await session_bus.request_name("dev.ezri.vpn1.Agent")
loop = get_running_loop()
self._auth_lock = asyncio.Lock()
self._registration_task = get_running_loop().create_task(
self.manager_register_watcher()
)
async def register_with_manager(self):
"""Connect to the manager and register the agent."""
proxy_object = self._bus.get_proxy_object(
self._manager_connected = False
self.manager = bus.get_proxy_object(
"dev.ezri.vpn1", "/dev/ezri/vpn1", manager_introspection_data
)
self.manager = proxy_object.get_interface("dev.ezri.vpn1.Manager")
await self._register_agent()
).get_interface("dev.ezri.vpn1.Manager")
dbus = bus.get_proxy_object(
"org.freedesktop.DBus", "/org/freedesktop/DBus", dbus_introspection_data
).get_interface("org.freedesktop.DBus")
dbus.on_name_owner_changed(self._on_name_change)
# Try to register ourselves with the manager immediately
try:
await self._register_agent()
except:
self.logger.warning(
"VPN manager not available, will register when it comes online."
)
async def manager_register_watcher(self):
"""Async loop to keep agent registered with VPN manager."""
while True:
async def _on_name_change(self, bus_name: str, old_name: str, new_name: str):
if bus_name != "dev.ezri.vpn1":
# we only care about the vpn manager so ignore any other names
return
if old_name == "":
# new manager, register ourselves with it
try:
await self.register_with_manager()
except DBusError as e:
self.logger.debug(f"Got error when registering agent: {e}")
self.logger.warn("Registration failed, trying again in 5 seconds")
daemon.notify("STATUS=Waiting for manager...")
await self._register_agent()
except:
self.logger.error("Failed to register with manager!")
else:
self.logger.info("Registered agent with manager")
daemon.notify("STATUS=Ready")
await async_signal.for_signal(self.manager, "Shutdown")
self.logger.info("Manager shutting down, will try to reconnect")
await sleep(5)
self._manager_connected = True
if new_name == "":
# manager shutting down, mark us as disconnected
self.logger.info("Manager shutting down.")
async def _register_agent(self):
"""Register the agent with the service."""
self.logger.info("Registering agent with manager")
await self.manager.call_register_agent(self._bus.unique_name, "/dev/ezri/vpn1")
async def _unregister_agent(self):
"""Unregister the agent with the service."""
self.logger.info("Unregistering agent from manager")
await self.manager.call_unregister_agent(
self._bus.unique_name, "/dev/ezri/vpn1"
)
await self.manager.call_register("/dev/ezri/vpn1")
async def shutdown(self):
"""Shutdown the agent."""
self.logger.debug("Shutting down agent")
await self._unregister_agent()
@dbus_property(access=PropertyAccess.READ)
def SupportedAuthFlows(self) -> "as":

View File

@ -38,6 +38,7 @@ class SignalIterator:
def _on_signal(self, *args):
self._future.set_result(args)
self._gen_new_future()
def __aiter__(self):
"""Async iterator."""

View File

@ -5,8 +5,8 @@ from typing import IO
from dbus_fast.introspection import Node
def _intro(interface: str) -> Node:
with (resources.files() / f"dev.ezri.vpn1.{interface}.xml").open() as f:
def _intro(interface: str, namespace: str = "dev.ezri.vpn1") -> Node:
with (resources.files() / f"{base}.{interface}.xml").open() as f:
return Node.parse(f.read())
@ -15,5 +15,6 @@ agent_listener = _intro("AgentListener")
connection = _intro("Connection")
manager = _intro("Manager")
protocol = _intro("Protocol")
dbus = _intro("DBus", namespace="org.freedesktop")
__all__ = ("agent", "agent_listener", "connection", "manager", "protocol")
__all__ = ("agent", "agent_listener", "connection", "manager", "protocol", "dbus")

View File

@ -34,6 +34,11 @@ node PUBLIC "-//freedesktop//DTD D-BUS Object Introspection 1.0//EN"
<arg type="as" name="invalidated_properties" />
</signal>
</interface>
<interface name="dev.ezri.vpn1.AgentManager">
<method name="Register">
<arg name="path" type="o" direction="in" />
</method>
</interface>
<interface name="dev.ezri.vpn1.Manager">
<property name="Connected" type="ao" access="read" />
<property name="AvailableConnections" type="ao" access="read" />

View File

@ -0,0 +1,144 @@
<!DOCTYPE node PUBLIC "-//freedesktop//DTD D-BUS Object Introspection 1.0//EN"
"http://www.freedesktop.org/standards/dbus/1.0/introspect.dtd">
<node>
<interface name="org.freedesktop.DBus">
<method name="Hello">
<arg direction="out" type="s"/>
</method>
<method name="RequestName">
<arg direction="in" type="s"/>
<arg direction="in" type="u"/>
<arg direction="out" type="u"/>
</method>
<method name="ReleaseName">
<arg direction="in" type="s"/>
<arg direction="out" type="u"/>
</method>
<method name="StartServiceByName">
<arg direction="in" type="s"/>
<arg direction="in" type="u"/>
<arg direction="out" type="u"/>
</method>
<method name="UpdateActivationEnvironment">
<arg direction="in" type="a{ss}"/>
</method>
<method name="NameHasOwner">
<arg direction="in" type="s"/>
<arg direction="out" type="b"/>
</method>
<method name="ListNames">
<arg direction="out" type="as"/>
</method>
<method name="ListActivatableNames">
<arg direction="out" type="as"/>
</method>
<method name="AddMatch">
<arg direction="in" type="s"/>
</method>
<method name="RemoveMatch">
<arg direction="in" type="s"/>
</method>
<method name="GetNameOwner">
<arg direction="in" type="s"/>
<arg direction="out" type="s"/>
</method>
<method name="ListQueuedOwners">
<arg direction="in" type="s"/>
<arg direction="out" type="as"/>
</method>
<method name="GetConnectionUnixUser">
<arg direction="in" type="s"/>
<arg direction="out" type="u"/>
</method>
<method name="GetConnectionUnixProcessID">
<arg direction="in" type="s"/>
<arg direction="out" type="u"/>
</method>
<method name="GetAdtAuditSessionData">
<arg direction="in" type="s"/>
<arg direction="out" type="ay"/>
</method>
<method name="GetConnectionSELinuxSecurityContext">
<arg direction="in" type="s"/>
<arg direction="out" type="ay"/>
</method>
<method name="ReloadConfig">
</method>
<method name="GetId">
<arg direction="out" type="s"/>
</method>
<method name="GetConnectionCredentials">
<arg direction="in" type="s"/>
<arg direction="out" type="a{sv}"/>
</method>
<property name="Features" type="as" access="read">
<annotation name="org.freedesktop.DBus.Property.EmitsChangedSignal" value="const"/>
</property>
<property name="Interfaces" type="as" access="read">
<annotation name="org.freedesktop.DBus.Property.EmitsChangedSignal" value="const"/>
</property>
<signal name="NameOwnerChanged">
<arg type="s"/>
<arg type="s"/>
<arg type="s"/>
</signal>
<signal name="NameLost">
<arg type="s"/>
</signal>
<signal name="NameAcquired">
<arg type="s"/>
</signal>
</interface>
<interface name="org.freedesktop.DBus.Properties">
<method name="Get">
<arg direction="in" type="s"/>
<arg direction="in" type="s"/>
<arg direction="out" type="v"/>
</method>
<method name="GetAll">
<arg direction="in" type="s"/>
<arg direction="out" type="a{sv}"/>
</method>
<method name="Set">
<arg direction="in" type="s"/>
<arg direction="in" type="s"/>
<arg direction="in" type="v"/>
</method>
<signal name="PropertiesChanged">
<arg type="s" name="interface_name"/>
<arg type="a{sv}" name="changed_properties"/>
<arg type="as" name="invalidated_properties"/>
</signal>
</interface>
<interface name="org.freedesktop.DBus.Introspectable">
<method name="Introspect">
<arg direction="out" type="s"/>
</method>
</interface>
<interface name="org.freedesktop.DBus.Monitoring">
<method name="BecomeMonitor">
<arg direction="in" type="as"/>
<arg direction="in" type="u"/>
</method>
</interface>
<interface name="org.freedesktop.DBus.Peer">
<method name="GetMachineId">
<arg direction="out" type="s"/>
</method>
<method name="Ping">
</method>
</interface>
<interface name="org.freedesktop.DBus.Debug.Stats">
<method name="GetStats">
<arg direction="out" type="a{sv}"/>
</method>
<method name="GetConnectionStats">
<arg direction="in" type="s"/>
<arg direction="out" type="a{sv}"/>
</method>
<method name="GetAllMatchRules">
<arg direction="out" type="a{sas}"/>
</method>
</interface>
</node>

View File

@ -18,11 +18,116 @@ from vpn_manager.common.implementations import VPN_BACKENDS, export_protocols
from vpn_manager.utils import encode_object_path_segment, load_introspection
from ..common.context import allow_interactive_authorization, sender
from ..common.dbus_types import *
from asyncio import timeout, TaskGroup
from ..common.async_signal import SignalIterator
from ..common.introspection import (
agent as agent_introspection_data,
dbus as dbus_introspection_data,
)
from operator import attrgetter
from asyncio import timeout, TaskGroup, create_task
from .connection import ConnectionInterface
import logging
agent_introspection_data = load_introspection("dev.ezri.vpn1.Agent")
class AgentManagerInterface(ServiceInterface):
"""Agent manager sub-interface for VPN service."""
class _Agent:
"""Wrapper around the agent."""
def __init__(
self, bus: MessageBus, agent_interface: ProxyInterface, priority: int
):
self._bus = bus
self._iface = agent_interface
self._uid = -1
self.priority = priority
async def populate(self):
"""Populate the object with relevant values."""
dbus_proxy = self._bus.get_proxy_object(
"org.freedesktop.DBus", "/org/freedesktop/DBus", dbus_introspection_data
).get_interface("org.freedesktop.DBus")
self._dbus_proxy.on_name_owner_changed(self._on_name_owner_changed)
self._uid = await dbus_proxy.call_get_connection_unix_user(
self._iface.bus_name
)
@property
def user_id(self) -> int:
if self._uid == -1:
raise RuntimeError("Must populate agent data before getting user ID.")
return self._uid
async def authenticate(self, auth_flow: str, options: Variant) -> Variant:
"""Make an authentication request to the agent."""
flags = (
MessageFlag.ALLOW_INTERACTIVE_AUTHORIZATION
if allow_interactive_authorization.get()
else MessageFlag.NONE
)
return await self._iface.call_authenticate(auth_flow, options, flags=flags)
logger = logging.getLogger(f"{__name__}.AgentManagerInterface")
def __init__(self, bus: MessageBus, manager: "ManagerInterface"):
super().__init__("dev.ezri.vpn1.AgentManager")
self._agents: dict[str, self._Agent] = {}
self._bus = bus
self._dbus_proxy = bus.get_proxy_object(
"org.freedesktop.DBus", "/org/freedesktop/DBus", dbus_introspection_data
).get_interface("org.freedesktop.DBus")
self._dbus_proxy.on_name_owner_changed(self._on_name_owner_changed)
def _on_name_owner_changed(self, bus_name: str, old_name: str, new_name: str):
if bus_name in self._agents and bus_name == old_name and new_name == "":
# One of our agents is disconnecting from the bus. Unregister it.
self.logger.info(f"Agent at {bus_name} disconnecting, unregistering...")
del self._agents[bus_name]
@method()
async def Register(self, path: object_path, priority: int32) -> empty:
"""Register an agent with the manager."""
from_addr = sender.get()
agent = self._Agent(
self._bus,
self._bus.get_proxy_object(
from_addr, path, agent_introspection_data
).get_interface("dev.ezri.vpn1.Agent"),
priority,
)
await agent.populate()
self._agents[from_addr] = agent
self.logger.info(f"New agent registered for user {agent.user_id}")
async def authenticate(
self, auth_flow: str, options: Variant, allow_interactive: bool
) -> Variant:
"""
Select an agent and perform an authentication request.
This must be called in a method call context (i.e. a user has requested connection to a VPN), as it requires
a sender bus name to get a user ID that will allow the manager to match to the calling user's agent.
This does mean that calling these methods as root with e.g. sudo will likely not work, unless the root user
is also running an agent, which defeats the entire point of this project; at that point, just connect manually.
"""
bus_name = sender.get()
user_id = await self._dbus_proxy.call_get_connection_unix_user(bus_name)
matching_agents = [
agent for agent in self._agents.values() if agent.user_id == user_id
]
# Sort by priority (so a client could register itself as an agent to provide a fallback if no agent is available).
matching_agents.sort(key=attrgetter("priority"), reverse=True)
for agent in matching_agents:
try:
return await agent.authenticate(auth_flow, options)
except DBusError as e:
# Absorb dbus errors, we're asking forgiveness, not permission
continue
# If we get here, none of the agents for the user are available, or the user has no agents.
raise NoAvailableAgent()
class ManagerInterface(ServiceInterface):
@ -40,6 +145,8 @@ class ManagerInterface(ServiceInterface):
self.load_stored_connections()
self._polkit_proxy: ProxyInterface | None = None
self._logind_proxy: ProxyInterface | None = None
self._agent_manager = AgentManagerInterface(bus, self)
bus.export(self._agent_manager)
export_protocols(bus)
def load_stored_connections(self):
@ -75,47 +182,9 @@ class ManagerInterface(ServiceInterface):
f"Attempting to load file {file.name} failed: {e.text}"
)
async def _find_usable_agent(self, auth_flow: str) -> ProxyInterface:
"""Find the first agent that is available and supports the requested auth flow."""
to_purge = set()
for agent in self._agents:
try:
async with timeout(0.5):
if not await agent.get_can_handle_requests(): # type: ignore[attr-defined]
self.logger.debug(
f"Agent at {agent.bus_name} not considered because it cannot handle requests"
)
continue
if auth_flow in await agent.get_supported_auth_flows(): # type: ignore[attr-defined]
self.logger.debug(f"Found agent {agent.bus_name}")
return agent
else:
self.logger.debug(
f"Agent at {agent.bus_name} does not support auth_flow {auth_flow}"
)
except TimeoutError:
self.logger.warning(
f"Agent at {agent.bus_name} timed out, purging from known agents"
)
to_purge.add(agent)
except DBusError as e:
self.logger.warning(
f"DBus error {e} encounted at {agent.bus_name}, purging from known agents"
)
to_purge.add(agent)
self._agents = list(set(self._agents).difference(to_purge))
self.logger.error(f"No available agent can handle {auth_flow} requests")
raise NoAvailableAgent(f"No available agent can handle {auth_flow} requests")
async def request_credentials(self, auth_flow: str, options: Variant) -> Variant:
"""Request credentials from a user agent."""
agent = await self._find_usable_agent(auth_flow)
flags = (
MessageFlag.ALLOW_INTERACTIVE_AUTHORIZATION
if allow_interactive_authorization.get()
else MessageFlag.NONE
)
return await agent.call_authenticate(auth_flow, options, flags=flags) # type: ignore[attr-defined]
return await self._agent_manager.authenticate(auth_flow, options)
def notify_status(self):
"""Update systemd status with current connection state."""
@ -163,6 +232,23 @@ class ManagerInterface(ServiceInterface):
self._logind_proxy = obj.get_interface("org.freedesktop.login1.Manager")
return self._logind_proxy
async def on_resume_from_sleep(
self, func: Callable[[], None]
) -> Callable[[], None]:
"""
Register a function to be called when the system resumes from sleep.
Returns a function to unregister the callback.
"""
def on_signal(start: bool):
if not start:
func()
proxy = await self._get_logind_proxy()
proxy.on_prepare_for_sleep(on_signal)
return lambda: proxy.off_prepare_for_sleep(on_signal)
async def verify_polkit_auth(self, action_id: str, **kwargs: str):
"""Get polkit authorization for the given action."""
subject = ("system-bus-name", {"name": Variant("s", sender.get())})