diff --git a/sway_context_manager/interface.py b/sway_context_manager/interface.py index 4d6ee70..8687d87 100644 --- a/sway_context_manager/interface.py +++ b/sway_context_manager/interface.py @@ -30,7 +30,7 @@ class ContextMngrInterface( """Request a context switch. This will fail if the current monitor configuration is not compatible with the requested context.""" print("requesting context", context) try: - self.workspace_tree.switch_context(context) + await self.workspace_tree.activate_context(self.connection, context) return "OK" except Exception as e: return str(e) diff --git a/sway_context_manager/swayipc.py b/sway_context_manager/swayipc.py index 9b935d3..d0445e8 100644 --- a/sway_context_manager/swayipc.py +++ b/sway_context_manager/swayipc.py @@ -6,6 +6,8 @@ from .utils import async_debounce class SwayConnection: + last_output_set = None + def __init__(self, workspace_tree: WorkspaceTree): self.connection = Connection(auto_reconnect=True) self.workspace_tree = workspace_tree @@ -17,10 +19,17 @@ class SwayConnection: for callback in self.on_change_callbacks: await callback(self.workspace_tree) - @async_debounce(2) + @async_debounce(0.1) async def on_output(self, connection, event): - """On output event, after 2 seconds, update the workspace tree""" - self.workspace_tree.update_context(connection) + """On output event, update the workspace tree""" + print("Output event received", flush=True) + outputs = await connection.get_outputs() + if self.last_output_set: + if len(outputs) == len(self.last_output_set): + # If the number of outputs is the same, we can assume that the outputs are the same + return + self.last_output_set = outputs + await self.workspace_tree.update_context(connection) async def on_mode(self, event): """On mode change event, do something. Not sure what yet.""" @@ -35,6 +44,9 @@ class SwayConnection: self.connection.on(Event.OUTPUT, self.on_output) self.connection.on(Event.MODE, self.on_mode) + # Get initial output set + self.last_output_set = await self.connection.get_outputs() + # Update the workspace tree await self.workspace_tree.update_context(self.connection) await self.on_workspace() diff --git a/sway_context_manager/utils.py b/sway_context_manager/utils.py index 12cd8b6..2a180c0 100644 --- a/sway_context_manager/utils.py +++ b/sway_context_manager/utils.py @@ -1,5 +1,6 @@ import asyncio from functools import wraps +from enum import IntEnum def async_debounce(wait): @@ -22,3 +23,20 @@ def async_debounce(wait): return debounced return decorator + + +class OutputMatch(IntEnum): + """Enum for how an output was matched to a workspace group + + This is used to break ties between contexts that have the same compatibility score. Output name matching + is considered weaker than matching by make, model, and serial number, however, it is generic and can be used + to match, say, projectors in conference rooms. + + NO_MATCH: The output was not matched to the workspace group + NAME_MATCH: The output was matched by name to the workspace group + ID_MATCH: The output was matched by its make, model, and serial number to the workspace group + """ + + NO_MATCH = 0 + NAME_MATCH = 1 + ID_MATCH = 2 diff --git a/sway_context_manager/workspace_tree.py b/sway_context_manager/workspace_tree.py index 57234af..e22ab38 100644 --- a/sway_context_manager/workspace_tree.py +++ b/sway_context_manager/workspace_tree.py @@ -4,6 +4,7 @@ from i3ipc.replies import OutputReply from i3ipc.aio import Connection import asyncio import subprocess +from .utils import OutputMatch DEFAULT_TERMINAL = "alacritty" @@ -76,6 +77,7 @@ class WorkspaceGroup: def __init__(self, output_data: dict[str, str]): self.name = output_data["group"] + self.output_names = output_data.get("names", []) self.make = output_data.get("make", None) self.model = output_data.get("model", None) self.serial = output_data.get("serial", None) @@ -113,7 +115,19 @@ class WorkspaceGroup: """Returns whether the group is active.""" return any(workspace.focused for workspace in self.workspaces) - async def configure(self, i3: Connection): + async def focus(self, i3: Connection): + """Focus the group in Sway.""" + if self.make and self.model and self.serial: + await i3.command(f"focus output {self.make} {self.model} {self.serial}") + elif len(self.output_names) > 0: + for name in self.output_names: + await i3.command(f"focus output {name}") + else: + raise ValueError( + "No output name or make/model/serial provided, cannot focus group" + ) + + async def configure(self, i3: Connection, outputs: list[OutputReply]): """Configure the group output in Sway.""" transform = "" mode = "" @@ -124,21 +138,69 @@ class WorkspaceGroup: mode = f"mode {self.mode}" if self.make and self.model and self.serial: selector = f'"{self.make} {self.model} {self.serial}"' + elif len(self.output_names) > 0: + for name in self.output_names: + if name in [output.name for output in outputs]: + selector = name + break + # First, assign workspaces to the output + for workspace in self.workspaces: + await i3.command(f"workspace {workspace.index} output {selector}") + # Then, configure the output await i3.command( - f"output {selector} position {self.position[0]} {self.position[1]} {mode} {transform}" + f"output {selector} position {self.position[0]} {self.position[1]} {mode} {transform} enable" ) async def get_output_name(self, i3: Connection) -> str: """Get the name of the output in Sway.""" outputs = await i3.get_outputs() - for output in outputs: + # If we have make, model, and serial, search by those first + if self.make and self.model and self.serial: + for output in outputs: + if ( + output.make == self.make + and output.model == self.model + and output.serial == self.serial + ): + print( + f"Found output {output.name} by make, model, and serial for group {self.name}", + flush=True, + ) + return output.name + # If we don't find an exact match for the output, search by name if we have any + if len(self.output_names) > 0: + for output in outputs: + if output.name in self.output_names: + print( + f"Found output {output.name} by name for group {self.name}", + flush=True, + ) + return output.name + return None + + def get_match_level(self, output: OutputReply) -> OutputMatch: + """Get the match level score for the output.""" + if self.make and self.model and self.serial: if ( output.make == self.make and output.model == self.model and output.serial == self.serial ): - return output.name - return None + print( + f"Match level: ID_MATCH for {output.name} on group {self.name}", + flush=True, + ) + return OutputMatch.ID_MATCH + if output.name in self.output_names: + print( + f"Match level: NAME_MATCH for {output.name} on group {self.name}", + flush=True, + ) + return OutputMatch.NAME_MATCH + print( + f"Match level: NO_MATCH for {output.name} on group {self.name}", flush=True + ) + return OutputMatch.NO_MATCH class WorkspaceContext: @@ -157,6 +219,8 @@ class WorkspaceContext: self.name = name for group in self.groups: group_object = data["groups"][group.name] + if group_object.get("reverse", False): + group.reverse = True for workspace in group_object["workspaces"]: workspace_obj = next( (w for w in workspaces if w.index == workspace), @@ -168,31 +232,53 @@ class WorkspaceContext: raise Exception( f"Error: undefined workspace {workspace} referenced in context {name} group {group.name}" ) + primary_group_name = data.get("primary", self.groups[0].name) + self.primary_group = next( + (group for group in self.groups if group.name == primary_group_name), + None, + ) def add_group(self, group: WorkspaceGroup): """Add a group to the context.""" self.groups.append(group) + def compatability_rank(self, outputs: list[OutputReply]) -> int: + """Get the compatability rank of the context with the given outputs.""" + + result = 0 + for group in self.groups: + group_result = max( + [group.get_match_level(output).value for output in outputs] + ) + if group_result == 0: + return 0 + result += group_result + return result + async def activate(self, i3: Connection): """Activate the context in Sway.""" defined_displays = [ f"{group.make} {group.model} {group.serial}" for group in self.groups ] - # First, disable all displays not defined in the context - for output in await i3.get_outputs(): - if f"{output.make} {output.model} {output.serial}" not in defined_displays: - print("Disabling", output.name) - await i3.command(f"output {output.name} disable") - - # Next, configure all displays defined in the context - for group in self.groups: - await group.configure(i3) - - # Then, close all EWW windows + outputs = await i3.get_outputs() + # First, close all EWW windows proc = await asyncio.create_subprocess_exec("eww", "close-all") await proc.wait() - # Then, open all EWW windows defined in the context on the appropriate windows + # Second, if the focused workspace is not in the context, focus the primary output + if not self.active_group: + + return + + # Then, disable all displays, so we can assign workspaces to the correct ones before enabling them + for output in outputs: + await i3.command(f"output {output.name} disable") + + # Next, configure all displays defined in the context + for group in self.groups: + await group.configure(i3, outputs) + + # Finally, open all EWW windows defined in the context on the appropriate windows for group in self.groups: for window in group.eww_windows: proc = await asyncio.create_subprocess_exec( @@ -265,7 +351,7 @@ class WorkspaceTree: return f"WorkspaceTree({self.current_context}, {repr(self.contexts)})" def __json__(self): - return {context.name: context.__json__() for context in self.contexts} + return {self.current_context.name: self.current_context.__json__()} def get_workspace(self, user_index: int) -> Workspace: """Returns a workspace object based on the user index.""" @@ -330,37 +416,22 @@ class WorkspaceTree: """Activates a new context in Sway based on the current display configuration.""" # First, get the current display configuration outputs = await i3.get_outputs() - active_outputs = [ - f'"{output.make} {output.model} {output.serial}"' for output in outputs + print(outputs, flush=True) + + # Next, calculate match scores for each context + scores = [ + (context, context.compatability_rank(outputs)) for context in self.contexts ] + print([f"{context.name}: {score}" for context, score in scores], flush=True) + # Sort the scores by rank + scores.sort(key=lambda x: x[1], reverse=True) - # Next, find the context that matches the current display configuration. - # We first want to find an exact match, and if that fails, we want to find one where - # all of its required outputs are present (so extra outputs are fine, they'll be disabled). - # We will not return a partial match except for the case noted above. + # If the top context is the current context, or the rank is 0, do nothing + if scores[0][0] == self.current_context or scores[0][1] == 0: + return - for context in self.contexts: - # First pass, look for exact matches - context_outputs = [ - f'"{group.make} {group.model} {group.serial}"' - for group in context.groups - ] - if all(output in active_outputs for output in context_outputs) and all( - output in context_outputs for output in active_outputs - ): - self.current_context = context - await context.activate(i3) - return - - for context in self.contexts: - context_outputs = [ - f'"{group.make} {group.model} {group.serial}"' - for group in context.groups - ] - if all(output in active_outputs for output in context_outputs): - self.current_context = context - await context.activate(i3) - return + self.current_context = scores[0][0] + await self.current_context.activate(i3) async def activate_context(self, i3: Connection, name: str): """Activates a context by name. This will fail if the current display configuration is incompatible.""" @@ -370,18 +441,12 @@ class WorkspaceTree: ) outputs = await i3.get_outputs() - active_outputs = [ - f'"{output.make} {output.model} {output.serial}"' for output in outputs - ] - context_outputs = [ - f'"{group.make} {group.model} {group.serial}"' for group in context.groups - ] - - if all(output in active_outputs for output in context_outputs): - self.current_context = context - await context.activate(i3) - else: + score = context.compatability_rank(outputs) + if score == 0: raise ValueError( - f"Context {name} is incompatible with the current display configuration." + "Context is incompatible with current display configuration." ) + print(f"Activating context {context.name} with score {score}.", flush=True) + self.current_context = context + await context.activate(i3)