diff --git a/mypy.ini b/mypy.ini index cfb75eb0..7a8a99d0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -16,6 +16,8 @@ plugins = pydantic.mypy,sqlalchemy.ext.mypy.plugin exclude = (?x)( ^src/askui/models/ui_tars_ep/ui_tars_api\.py$ | ^src/askui/tools/askui/askui_ui_controller_grpc/.*$ + | ^venv/.*$ + | ^\.venv/.*$ ) mypy_path = src:tests explicit_package_bases = true diff --git a/src/askui/__init__.py b/src/askui/__init__.py index c5349466..e1e1b5e9 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -1,6 +1,6 @@ """AskUI Python SDK""" -__version__ = "0.32.1" +__version__ = "0.33.0" import logging import os @@ -45,6 +45,7 @@ from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry from .tools import ModifierKey, PcKey +from .tools.askui import LocalAgentOsServer, RemoteAgentOsServer from .utils.image_utils import ImageSource from .utils.source_utils import InputSource @@ -69,6 +70,8 @@ logging.getLogger(__name__).addHandler(logging.NullHandler()) __all__ = [ + "RemoteAgentOsServer", + "LocalAgentOsServer", "Agent", "AutomationError", "ComputerAgent", diff --git a/src/askui/computer_agent.py b/src/askui/computer_agent.py index 016607bd..0c0917a4 100644 --- a/src/askui/computer_agent.py +++ b/src/askui/computer_agent.py @@ -17,11 +17,13 @@ create_computer_agent_prompt, ) from askui.tools.computer import ( + ComputerGetActiveAgentOsServerTool, ComputerGetMousePositionTool, ComputerGetSystemInfoTool, ComputerKeyboardPressedTool, ComputerKeyboardReleaseTool, ComputerKeyboardTapTool, + ComputerListAgentOsServersTool, ComputerListDisplaysTool, ComputerMouseClickTool, ComputerMouseHoldDownTool, @@ -31,6 +33,7 @@ ComputerRetrieveActiveDisplayTool, ComputerScreenshotTool, ComputerSetActiveDisplayTool, + ComputerSwitchAgentOsServerTool, ComputerTypeTool, ) from askui.tools.exception_tool import ExceptionTool @@ -38,7 +41,7 @@ from .reporting import CompositeReporter, Reporter from .retry import Retry 
from .tools import AgentToolbox, ComputerAgentOsFacade, ModifierKey, PcKey -from .tools.askui import AskUiControllerClient +from .tools.askui import AgentOsServer, AskUiControllerClient logger = logging.getLogger(__name__) @@ -50,15 +53,36 @@ class ComputerAgent(Agent): This agent can perform various UI interactions like clicking, typing, scrolling, and more. It uses computer vision models to locate UI elements and execute actions on them. + A single `ComputerAgent` can drive **one or more machines** through the + `agent_os_servers` argument. Each entry is an Agent OS server (local + subprocess or remote gRPC endpoint) identified by a stable `computer_id`. + At any moment one server is *active* and receives all explicit calls + (`click`, `type`, `keyboard`, ...). The active server can be changed at + runtime via `agent.tools.os.switch_agent_os_server(computer_id)` or + scoped to a block using `agent.tools.os.temporary_select(computer_id)`. + The `act()` model is also given list/switch/get-active tools so it can + orchestrate work across machines on its own (e.g. read something on one + computer and re-enter it on another). + Args: - display (int, optional): The display number to use for screen interactions. Defaults to `1`. + display (int, optional): The display number to use for screen interactions on the default local server. Ignored when `agent_os_servers` is provided. Defaults to `1`. reporters (list[Reporter] | None, optional): List of reporter instances for logging and reporting. If `None`, an empty list is used. - tools (AgentToolbox | None, optional): Custom toolbox instance. If `None`, a default one will be created with `AskUiControllerClient`. + agent_os_servers (list[AgentOsServer] | None, optional): + Agent OS servers the agent can route actions to. May mix one + `LocalAgentOsServer` (managing a controller subprocess on this + machine) with any number of `RemoteAgentOsServer`s pointing at + controllers already running on other machines. 
Constraints: at + least one server, at most one local, and remote `address`es plus + all `computer_id`s must be unique. The first entry becomes the + initial active server. Defaults to a single local server bound to + `display`. settings (AgentSettings | None, optional): Provider-based model settings. If `None`, uses the default AskUI model stack. retry (Retry, optional): The retry instance to use for retrying failed actions. Defaults to `ConfigurableRetry` with exponential backoff. Currently only supported for `locate()` method. act_tools (list[Tool] | None, optional): Additional tools to make available for the `act()` method. Example: + Single local machine (the default): + ```python from askui import ComputerAgent @@ -67,16 +91,60 @@ class ComputerAgent(Agent): agent.type("Hello World") agent.act("Open settings menu") ``` + + Example: + Research on one machine and write up the findings on another. The + first server in the list is the active one; `temporary_select` + re-routes a block of explicit calls and restores the previous + active server on exit. + + ```python + from askui import ComputerAgent + from askui.tools.askui import LocalAgentOsServer, RemoteAgentOsServer + + with ComputerAgent( + agent_os_servers=[ + LocalAgentOsServer(computer_id="research-box"), + RemoteAgentOsServer( + address="192.168.1.42:26000", + description="Writer box with a text editor open", + computer_id="writer-box", + ), + ], + ) as agent: + agent.act( + "On research-box, open a browser, google 'askui', and read " + "the top results to gather key facts about what AskUI is, " + "what it does, and notable features. Then switch to " + "writer-box and write a Markdown document titled " + "'AskUI Findings' summarizing those facts as a bulleted " + "list in the open text editor." 
+ ) + ``` + + Example: + Register a remote machine at runtime: + + ```python + from askui import ComputerAgent + + with ComputerAgent() as agent: + agent.tools.os.add_remote_agent_os_server( + address="10.0.0.5:26000", + description="Build server", + ) + agent.act("Kick off a release build on the build server") + ``` """ @telemetry.record_call( exclude={ "reporters", - "tools", "settings", "act_tools", "callbacks", "truncation_strategy", + "agent_os_servers", } ) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -84,7 +152,7 @@ def __init__( self, display: Annotated[int, Field(ge=1)] = 1, reporters: list[Reporter] | None = None, - tools: AgentToolbox | None = None, + agent_os_servers: list[AgentOsServer] | None = None, settings: AgentSettings | None = None, retry: Retry | None = None, act_tools: list[Tool] | None = None, @@ -92,10 +160,11 @@ def __init__( truncation_strategy: TruncationStrategy | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) - self.tools = tools or AgentToolbox( + self.tools = AgentToolbox( agent_os=AskUiControllerClient( display=display, reporter=reporter, + agent_os_servers=agent_os_servers, ) ) super().__init__( @@ -519,6 +588,9 @@ def get_default_tools() -> list[Tool]: ComputerListDisplaysTool(), ComputerRetrieveActiveDisplayTool(), ComputerSetActiveDisplayTool(), + ComputerListAgentOsServersTool(), + ComputerSwitchAgentOsServerTool(), + ComputerGetActiveAgentOsServerTool(), ] diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 74912911..93a47d85 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -534,12 +534,23 @@ def reset_tools(self, tools: list[Tool] | None = None) -> None: """Reset the tools in the collection with new tools.""" self._tools = tools or [] - def get_agent_os_by_tags(self, tags: list[str]) -> AgentOs | AndroidAgentOs: - """Get an agent OS by tags.""" + def get_agent_os_by_tags( + self, required_tags: list[str] + 
) -> AgentOs | AndroidAgentOs: + """ + Find the first registered agent OS whose tags are a superset of + `required_tags`. + + Every tag in `required_tags` must appear in the agent OS's tags; the + agent OS may declare additional tags beyond those. + + Raises: + ValueError: when no registered agent OS satisfies the required tags. + """ for agent_os in self._agent_os_list: - if all(tag in agent_os.tags for tag in tags): + if all(required in agent_os.tags for required in required_tags): return agent_os - msg = f"Agent OS with tags [{', '.join(tags)}] not found" + msg = f"No agent OS satisfies required tags [{', '.join(required_tags)}]" raise ValueError(msg) def _initialize_tools(self) -> None: diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index 344a83ab..e10723e9 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -1,12 +1,18 @@ from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Literal from PIL import Image from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags if TYPE_CHECKING: + from askui.tools.askui.agent_os_server import ( + AgentOsServer, + RemoteAgentOsServer, + ) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -676,3 +682,64 @@ def set_window_in_focus(self, process_id: int, window_id: int) -> None: window_id (int): The ID of the window to set as active. """ raise NotImplementedError + + # --- Agent-OS-server management ----------------------------------------------- + # These methods only do something meaningful for backends that talk to multiple + # Agent OS servers (`AskUiControllerClient`). Other `AgentOs` implementations + # (Playwright, Android, ...) inherit the default implementations, which raise + # `NotImplementedError`. 
def add_remote_agent_os_server(
    self,
    address: str,
    description: str,
) -> "RemoteAgentOsServer":
    """
    Register an additional remote Agent OS server.

    This base implementation is a placeholder: it always raises
    `NotImplementedError`. Backends that can talk to several Agent OS servers
    are expected to override it.

    Args:
        address (str): Address of the remote Agent OS server.
        description (str): Description of the remote Agent OS server.

    Returns:
        RemoteAgentOsServer: The registered remote server (in overriding
            implementations).

    Raises:
        NotImplementedError: Always, unless overridden by the backend.
    """
    raise NotImplementedError
+ + Example: + ```python + with agent_os.temporary_select('Remote-Machine') as remote_machine: + img = remote_machine.screenshot() + img.save("remote_machine.png") + # previous active server restored here + ``` + """ + raise NotImplementedError diff --git a/src/askui/tools/android/agent_os.py b/src/askui/tools/android/agent_os.py index 3a5a8285..d7fe7e04 100644 --- a/src/askui/tools/android/agent_os.py +++ b/src/askui/tools/android/agent_os.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from typing import List, Literal from PIL import Image +from typing_extensions import Self from askui.tools.android.uiautomator_hierarchy import UIElementCollection @@ -502,3 +504,26 @@ def get_ui_elements(self) -> UIElementCollection: Gets the UI elements. """ raise NotImplementedError + + def temporary_select(self, device_sn: str) -> AbstractContextManager[Self]: + """ + Temporarily switch the active device for the duration of a `with` block, + then restore the previously-active device on exit (even if the block + raises). + + Args: + device_sn (str): Serial number of the device to activate inside the + block. + + Returns: + AbstractContextManager[Self]: Context manager that yields this + `AndroidAgentOs` with `device_sn` active. 
+ + Example: + ```python + with android_agent_os.temporary_select('table_phone') as table_phone: + table_phone.tap(100, 200) + # previous active device restored here + ``` + """ + raise NotImplementedError diff --git a/src/askui/tools/android/agent_os_facade.py b/src/askui/tools/android/agent_os_facade.py index f27d0eee..0bc19aea 100644 --- a/src/askui/tools/android/agent_os_facade.py +++ b/src/askui/tools/android/agent_os_facade.py @@ -1,6 +1,9 @@ +from collections.abc import Iterator +from contextlib import contextmanager from typing import List, Optional, Tuple from PIL import Image +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags from askui.tools.android.agent_os import ANDROID_KEY, AndroidAgentOs, AndroidDisplay @@ -112,6 +115,15 @@ def set_device_by_serial_number(self, device_sn: str) -> None: self._agent_os.set_device_by_serial_number(device_sn) self._real_screen_resolution = None + @contextmanager + def temporary_select(self, device_sn: str) -> Iterator[Self]: + with self._agent_os.temporary_select(device_sn): + self._real_screen_resolution = None + try: + yield self + finally: + self._real_screen_resolution = None + def get_connected_devices_serial_numbers(self) -> list[str]: return self._agent_os.get_connected_devices_serial_numbers() diff --git a/src/askui/tools/android/ppadb_agent_os.py b/src/askui/tools/android/ppadb_agent_os.py index 9ffa7452..517ed4e1 100644 --- a/src/askui/tools/android/ppadb_agent_os.py +++ b/src/askui/tools/android/ppadb_agent_os.py @@ -2,12 +2,15 @@ import re import shlex import string +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path from typing import List, Optional, get_args from PIL import Image from ppadb.client import Client as AdbClient from ppadb.device import Device as AndroidDevice +from typing_extensions import Self from askui.reporting import NULL_REPORTER, Reporter from askui.tools.android.agent_os import ( @@ -202,6 +205,24 @@ 
def set_device_by_serial_number(self, device_sn: str) -> None: msg = f"Device name {device_sn} not found" raise AndroidAgentOsError(msg) + @contextmanager + def temporary_select(self, device_sn: str) -> Iterator[Self]: + previous_sn = self._device.serial if self._device is not None else None + self._reporter.add_message( + self._REPORTER_ROLE_NAME, + f"temporary_select({device_sn!r}) [previous={previous_sn!r}]", + ) + self.set_device_by_serial_number(device_sn) + try: + yield self + finally: + if previous_sn is not None and previous_sn != device_sn: + self.set_device_by_serial_number(previous_sn) + self._reporter.add_message( + self._REPORTER_ROLE_NAME, + f"temporary_select({device_sn!r}) -> restored", + ) + def _screenshot_without_reporting(self) -> Image.Image: device: AndroidDevice = self._get_selected_device() self._check_if_display_is_selected() diff --git a/src/askui/tools/askui/__init__.py b/src/askui/tools/askui/__init__.py index 5d46a982..7fb5b777 100644 --- a/src/askui/tools/askui/__init__.py +++ b/src/askui/tools/askui/__init__.py @@ -1,6 +1,17 @@ -from .askui_controller import AskUiControllerClient, AskUiControllerServer +from .agent_os_server import ( + AgentOsServer, + LocalAgentOsServer, + RemoteAgentOsServer, +) +from .agent_os_server_manager import ( + AgentOsServerManager, +) +from .askui_controller import AskUiControllerClient __all__ = [ + "AgentOsServer", + "AgentOsServerManager", "AskUiControllerClient", - "AskUiControllerServer", + "LocalAgentOsServer", + "RemoteAgentOsServer", ] diff --git a/src/askui/tools/askui/agent_os_server.py b/src/askui/tools/askui/agent_os_server.py new file mode 100644 index 00000000..d94340c2 --- /dev/null +++ b/src/askui/tools/askui/agent_os_server.py @@ -0,0 +1,320 @@ +import logging +import pathlib +import subprocess +import sys +import time +import uuid +from urllib.parse import urlparse + +from typing_extensions import override + +from askui.tools.askui.askui_controller_settings import AskUiControllerSettings 
def _replace_port(address: str, port: int) -> str:
    """
    Return `address` with its port replaced by `port`, as ``"host:port"``.

    Any URL scheme in `address` is dropped and the host is lower-cased by
    `urlparse`. When no host can be parsed (e.g. an empty address),
    ``"localhost"`` is used.

    IPv6 literals are written back in brackets (``"[::1]:26000"``):
    `urlparse(...).hostname` strips the brackets, and without them the
    resulting ``"::1:26000"`` could not be split into host and port again.

    Args:
        address (str): Address such as ``"localhost:23000"``,
            ``"grpc://10.0.0.5:23000"`` or ``"[::1]:23000"``.
        port (int): Port to put into the returned address.

    Returns:
        str: The address in ``"host:port"`` form using the new port.
    """
    # A leading "//" makes urlparse treat a scheme-less address as a netloc.
    addr = address if "://" in address else "//" + address
    host = urlparse(addr).hostname or "localhost"
    if ":" in host:
        # Only IPv6 literals contain ":" once the port has been split off.
        host = f"[{host}]"
    return f"{host}:{port}"
class LocalAgentOsServer(AgentOsServer):
    """
    Local Agent OS server: manages an AskUI Remote Device Controller subprocess on
    this machine.

    Args:
        description (str, optional): Human-readable description. Defaults to
            ``"Local Agent OS server"``.
        settings (AskUiControllerSettings | None, optional): Process-level settings
            (executable path, args). Defaults to a fresh `AskUiControllerSettings`.
        address (str, optional): gRPC address. Defaults to ``"localhost:23000"``.
        is_service (bool, optional): When `True`, `start()` does not launch the
            controller binary because it is managed externally (e.g. AskUI Core
            Service on Windows). Defaults to `False`.
        discover_service (bool, optional): On Windows, probe for a running
            ``askuicoreservice`` and, if found, switch the address to port
            ``26000`` and set `is_service` to `True`. Defaults to `True`.
        display (int, optional): Display ID selected for this server. Defaults to `1`.
        computer_id (str | None, optional): Identifier for this computer, passed
            on to `AgentOsServer`. Defaults to the server's `session_guid`.
    """

    _ASKUI_CORE_SERVICE_NAME = "AskuiCoreService"
    _ASKUI_CORE_SERVICE_PORT = 26000
    # Upper bound (seconds) `stop()` waits for the controller process to exit.
    _STOP_TIMEOUT_SECONDS = 5.0

    def __init__(
        self,
        description: str = "Local Agent OS server",
        settings: AskUiControllerSettings | None = None,
        address: str = "localhost:23000",
        is_service: bool = False,
        discover_service: bool = True,
        display: int = 1,
        computer_id: str | None = None,
    ) -> None:
        # A running core service owns the controller: talk to its port and
        # never launch a second controller binary from `start()`.
        if discover_service and self._is_askui_core_service_running():
            service_msg = (
                f"Detected running {self._ASKUI_CORE_SERVICE_NAME}; using port "
                f"{self._ASKUI_CORE_SERVICE_PORT} (controller managed by service)"
            )
            logger.info(service_msg)
            address = _replace_port(address, self._ASKUI_CORE_SERVICE_PORT)
            is_service = True
        super().__init__(
            address=address,
            description=description,
            display=display,
            computer_id=computer_id,
        )
        self._is_service = is_service
        self._settings = settings or AskUiControllerSettings()
        self._process: subprocess.Popen[bytes] | None = None

    @property
    @override
    def is_local(self) -> bool:
        return True

    @property
    def is_service(self) -> bool:
        """Whether the server process is managed externally (skip `start()`)."""
        return self._is_service

    @staticmethod
    def _is_askui_core_service_running() -> bool:
        """Return `True` when the `AskuiCoreService` Windows service is RUNNING."""
        if sys.platform == "win32":
            try:
                result = subprocess.run(
                    ["sc", "query", LocalAgentOsServer._ASKUI_CORE_SERVICE_NAME],
                    capture_output=True,
                    text=True,
                    timeout=5,
                    check=False,
                )
            except (OSError, subprocess.SubprocessError):
                error_msg = (
                    "Failed to query "
                    f"{LocalAgentOsServer._ASKUI_CORE_SERVICE_NAME} service"
                )
                logger.debug(error_msg)
                return False
            if result.returncode != 0:
                return False
            return "RUNNING" in result.stdout.upper()
        return False

    def _parse_port(self) -> int:
        """
        Return the port contained in this server's address.

        Raises:
            ValueError: If the address does not contain a port.
        """
        # A leading "//" makes urlparse treat a scheme-less address as a netloc.
        addr = self._address if "://" in self._address else "//" + self._address
        parsed = urlparse(addr)
        if parsed.port is None:
            error_msg = (
                f"Could not parse port from address {self._address!r}. "
                "Expected format 'host:port' (e.g. 'localhost:23000')."
            )
            raise ValueError(error_msg)
        return parsed.port

    def _start_process(
        self,
        path: pathlib.Path,
        args: str | None = None,
    ) -> None:
        """Launch the controller binary and wait for its port."""
        commands = [str(path)]
        if args:
            commands.extend(args.split())
        # Controller output is only passed through when debug logging is on.
        if not logger.isEnabledFor(logging.DEBUG):
            self._process = subprocess.Popen(
                commands, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
        else:
            self._process = subprocess.Popen(commands)
        wait_for_port(self._parse_port())

    @override
    def start(self, clean_up: bool = False) -> None:
        """
        Start the server process unless this server uses a service-managed binary.

        Args:
            clean_up (bool, optional): Whether to clean up existing processes
                (only on Windows) before starting. Defaults to `False`.
        """
        if self._is_service:
            logger.debug(
                "Skipping local Agent OS server start; process is managed by service"
            )
            return
        if (
            sys.platform == "win32"
            and clean_up
            and process_exists("AskuiRemoteDeviceController.exe")
        ):
            self.clean_up()
        logger.debug(
            "Starting AskUI Remote Device Controller",
            extra={"path": str(self._settings.controller_path)},
        )
        self._start_process(
            self._settings.controller_path, self._settings.controller_args
        )
        time.sleep(0.5)

    def clean_up(self) -> None:
        """
        Ask `taskkill.exe` to end all processes whose image name matches ``AskUI*``.

        Does nothing on platforms other than Windows, where `taskkill.exe`
        does not exist.
        """
        if sys.platform != "win32":
            return
        # Argument list instead of a command string: no reliance on how a
        # single string is split into program and arguments.
        subprocess.run(["taskkill.exe", "/IM", "AskUI*"], check=False)
        time.sleep(0.1)

    @override
    def stop(self, force: bool = False) -> None:
        """
        Stop the server process.

        Does nothing when no process was started by this instance. Errors
        while stopping are logged, not raised; the process handle is dropped
        in every case.

        Args:
            force (bool, optional): Whether to forcefully terminate the process.
                Defaults to `False`.
        """
        process = self._process
        if process is None:
            return

        try:
            if force:
                process.kill()
                if sys.platform == "win32":
                    self.clean_up()
            else:
                process.terminate()
            # Reap the child so it does not linger as a zombie after the
            # handle is dropped below.
            try:
                process.wait(timeout=self._STOP_TIMEOUT_SECONDS)
            except subprocess.TimeoutExpired:
                logger.warning(
                    "Agent OS server process did not exit within %s seconds",
                    self._STOP_TIMEOUT_SECONDS,
                )
        except Exception:  # noqa: BLE001 - We want to catch all other exceptions here
            logger.exception("Agent OS server error")
        finally:
            self._process = None
def add(self, server: AgentOsServer) -> AgentOsServer:
    """
    Register `server` after checking it against the servers already present.

    The checks run in this order and the first violation raises: a second
    local server, a duplicate session GUID, a duplicate computer id, and - for
    remote servers - a duplicate address among the registered remote servers.
    If no server is active yet, the new server becomes the active one.

    Args:
        server (AgentOsServer): The server to register.

    Returns:
        AgentOsServer: The registered server (the same object).

    Raises:
        ValueError: If one of the checks above fails.
    """
    registered = self._servers

    if server.is_local:
        current_local = next((s for s in registered if s.is_local), None)
        if current_local is not None:
            error_msg = (
                "Cannot register a second local Agent OS server. At most one local "
                f"server is supported. Existing local server: "
                f"{current_local.description!r} "
                f"(computer_id={current_local.computer_id!r}). "
                "Remove it first via `remove(computer_id)`."
            )
            raise ValueError(error_msg)

    taken_guids = [s.session_guid for s in registered]
    if server.session_guid in taken_guids:
        error_msg = (
            f"An Agent OS server with session_guid={server.session_guid} is "
            "already registered. Each server must have a unique session GUID."
        )
        raise ValueError(error_msg)

    taken_computer_ids = [s.computer_id for s in registered]
    if server.computer_id in taken_computer_ids:
        error_msg = (
            f"An Agent OS server with computer_id={server.computer_id!r} is "
            "already registered. Each server must have a unique computer_id."
        )
        raise ValueError(error_msg)

    if not server.is_local:
        # Only remote servers compete for addresses; a local one is ignored.
        remote_addresses = [s.address for s in registered if not s.is_local]
        if server.address in remote_addresses:
            error_msg = (
                f"A remote Agent OS server with address {server.address!r} is "
                "already registered. Each remote server must have a unique address."
            )
            raise ValueError(error_msg)

    registered.append(server)
    if self._active_session_guid is None:
        # The first server ever added becomes the active one.
        self._active_session_guid = server.session_guid
    return server
+ """ + return self._servers[self._index_of(computer_id)] + + def get_by_session_guid(self, session_guid: str) -> AgentOsServer | None: + """ + Return the registered server with the given `session_guid`, or `None` if + no such server is registered. + + Intended for internal lookups (e.g. mapping a gRPC session GUID back to + its server during teardown). For user-facing selection, prefer `get`. + """ + for server in self._servers: + if server.session_guid == session_guid: + return server + return None + + def switch(self, computer_id: str) -> AgentOsServer: + """ + Set the active server by its `computer_id`. + + Args: + computer_id (str): The computer id of the server to activate. + + Returns: + AgentOsServer: The newly active server. + + Raises: + KeyError: If no server with the given computer id is registered. + """ + server = self.get(computer_id) + self._active_session_guid = server.session_guid + return server + + @property + def active(self) -> AgentOsServer | None: + """The currently active server, or `None` if no servers are registered.""" + if self._active_session_guid is None: + return None + return self.get_by_session_guid(self._active_session_guid) + + def __len__(self) -> int: + return len(self._servers) + + def _index_of(self, computer_id: str) -> int: + for i, server in enumerate(self._servers): + if server.computer_id == computer_id: + return i + registered = ", ".join(repr(s.computer_id) for s in self._servers) or "none" + error_msg = ( + f"No Agent OS server with computer_id={computer_id!r} is registered. " + f"Registered computer ids: {registered}. Use " + "`list_agent_os_servers()` to inspect the registered servers." 
+ ) + raise KeyError(error_msg) + + +__all__ = ["AgentOsServerManager"] diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 5e8814bf..b20f3dc7 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -1,1296 +1,1543 @@ -import logging -import pathlib -import subprocess -import sys -import time -import types -import uuid -from typing import Literal, Type - -import grpc -from google.protobuf.json_format import MessageToDict -from PIL import Image -from typing_extensions import Self, override - -from askui.container import telemetry -from askui.reporting import NULL_REPORTER, Reporter -from askui.tools.agent_os import ( - AgentOs, - Coordinate, - Display, - DisplaysListResponse, - ModifierKey, - PcKey, -) -from askui.tools.askui.askui_controller_client_settings import ( - AskUiControllerClientSettings, -) -from askui.tools.askui.askui_controller_settings import AskUiControllerSettings -from askui.tools.askui.askui_ui_controller_grpc.desktop_agent_os_error import ( - DesktopAgentOsError, -) -from askui.tools.askui.askui_ui_controller_grpc.generated import ( - Controller_V1_pb2 as controller_v1_pbs, -) -from askui.tools.askui.askui_ui_controller_grpc.generated import ( - Controller_V1_pb2_grpc as controller_v1, -) -from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Request_2501 import ( # noqa: E501 - AddRenderObjectCommand, - AskUIAgentOSSendRequestSchema, - ClearRenderObjectsCommand, - Command, - DeleteRenderObjectCommand, - GetActiveProcessCommand, - GetActiveWindowCommand, - GetMousePositionCommand, - GetSystemInfoCommand, - Guid, - Header, - Length, - Location, - Message, - Parameter3, - RenderImage, - RenderObjectId, - RenderObjectStyle, - RenderText, - SetActiveProcessCommand, - SetActiveWindowCommand, - SetMousePositionCommand, - UpdateRenderObjectCommand, -) -from 
askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Response_2501 import ( # noqa: E501 - AskUIAgentOSSendResponseSchema, - GetActiveProcessResponse, - GetActiveProcessResponseModel, - GetActiveWindowResponse, - GetActiveWindowResponseModel, - GetSystemInfoResponse, - GetSystemInfoResponseModel, -) -from askui.utils.annotated_image import AnnotatedImage - -from ..utils import process_exists, wait_for_port -from .exceptions import ( - AskUiControllerError, - AskUiControllerInvalidCommandError, - AskUiControllerOperationTimeoutError, -) - -logger = logging.getLogger(__name__) - - -class AskUiControllerServer: - """ - Concrete implementation of `ControllerServer` for managing the AskUI Remote Device - Controller process. - Handles process discovery, startup, and shutdown for the native controller binary. - - Args: - settings (AskUiControllerSettings | None, optional): Settings for the AskUI. - """ - - def __init__(self, settings: AskUiControllerSettings | None = None) -> None: - self._process: subprocess.Popen[bytes] | None = None - self._settings = settings or AskUiControllerSettings() - - def _start_process( - self, - path: pathlib.Path, - args: str | None = None, - ) -> None: - commands = [str(path)] - if args: - commands.extend(args.split()) - if not logger.isEnabledFor(logging.DEBUG): - self._process = subprocess.Popen( - commands, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) - else: - self._process = subprocess.Popen(commands) - wait_for_port(23000) - - def start(self, clean_up: bool = False) -> None: - """ - Start the controller process. - - Args: - clean_up (bool, optional): Whether to clean up existing processes - (only on Windows) before starting. Defaults to `False`. 
- """ - if ( - sys.platform == "win32" - and clean_up - and process_exists("AskuiRemoteDeviceController.exe") - ): - self.clean_up() - logger.debug( - "Starting AskUI Remote Device Controller", - extra={"path": str(self._settings.controller_path)}, - ) - self._start_process( - self._settings.controller_path, self._settings.controller_args - ) - time.sleep(0.5) - - def clean_up(self) -> None: - subprocess.run("taskkill.exe /IM AskUI*") - time.sleep(0.1) - - def stop(self, force: bool = False) -> None: - """ - Stop the controller process. - - Args: - force (bool, optional): Whether to forcefully terminate the process. - Defaults to `False`. - """ - if self._process is None: - return # Nothing to stop - - try: - if force: - self._process.kill() - if sys.platform == "win32": - self.clean_up() - else: - self._process.terminate() - except Exception: # noqa: BLE001 - We want to catch all other exceptions here - logger.exception("Controller error") - finally: - self._process = None - - -class AskUiControllerClient(AgentOs): - """ - Implementation of `AgentOs` that communicates with the AskUI Remote Device - Controller via gRPC. - - Args: - reporter (Reporter): Reporter used for reporting with the `"AgentOs"`. - display (int, optional): Display number to use. Defaults to `1`. - controller_server (AskUiControllerServer | None, optional): Custom controller - server. Defaults to `ControllerServer`. 
- """ - - @telemetry.record_call(exclude={"reporter", "controller_server"}) - def __init__( - self, - reporter: Reporter = NULL_REPORTER, - display: int = 1, - controller_server: AskUiControllerServer | None = None, - settings: AskUiControllerClientSettings | None = None, - ) -> None: - self._stub: controller_v1.ControllerAPIStub | None = None - self._channel: grpc.Channel | None = None - self._session_info: controller_v1_pbs.SessionInfo | None = None - self._pre_action_wait = 0 - self._post_action_wait = 0.05 - self._max_retries = 10 - self._display = display - self._reporter = reporter - self._controller_server = controller_server or AskUiControllerServer() - self._session_guid = "{" + str(uuid.uuid4()) + "}" - self._settings = settings or AskUiControllerClientSettings() - - @telemetry.record_call() - @override - def connect(self) -> None: - """ - Establishes a connection to the AskUI Remote Device Controller. - - This method starts the controller server, establishes a gRPC channel, - creates a session, and sets up the initial display. - """ - if self._settings.server_autostart: - self._controller_server.start() - self._channel = grpc.insecure_channel( - self._settings.server_address, - options=[ - ("grpc.max_send_message_length", 2**30), - ("grpc.max_receive_message_length", 2**30), - ("grpc.default_deadline", 300000), - ], - ) - self._stub = controller_v1.ControllerAPIStub(self._channel) - self._start_session() - self._start_execution() - self.set_display(self._display) - - def _get_stub(self) -> controller_v1.ControllerAPIStub: - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized. Call `connect()` first." 
- ) - return self._stub - - def _run_recorder_action( - self, - acion_class_id: controller_v1_pbs.ActionClassID, - action_parameters: controller_v1_pbs.ActionParameters, - ) -> controller_v1_pbs.Response_RunRecordedAction: - time.sleep(self._pre_action_wait) - response: controller_v1_pbs.Response_RunRecordedAction = ( - self._get_stub().RunRecordedAction( - controller_v1_pbs.Request_RunRecordedAction( - sessionInfo=self._session_info, - actionClassID=acion_class_id, - actionParameters=action_parameters, - ) - ) - ) - - time.sleep((response.requiredMilliseconds / 1000)) - num_retries = 0 - for _ in range(self._max_retries): - poll_response: controller_v1_pbs.Response_Poll = self._get_stub().Poll( - controller_v1_pbs.Request_Poll( - sessionInfo=self._session_info, - pollEventID=controller_v1_pbs.PollEventID.PollEventID_ActionFinished, - ) - ) - if ( - poll_response.pollEventParameters.actionFinished.actionID - == response.actionID - ): - break - time.sleep(self._post_action_wait) - num_retries += 1 - if num_retries == self._max_retries - 1: - raise AskUiControllerOperationTimeoutError - return response - - @telemetry.record_call() - @override - def disconnect(self) -> None: - """ - Terminates the connection to the AskUI Remote Device Controller. - - This method stops the execution, ends the session, closes the gRPC channel, - and stops the controller server. - """ - try: - self._stop_execution() - self._stop_session() - if self._channel is not None: - self._channel.close() - self._controller_server.stop() - except Exception as e: # noqa: BLE001 - # We want to catch all other exceptions here and not re-raise them - msg = ( - "Error while disconnecting from the AskUI Remote Device Controller" - f" Error: {e}" - ) - logger.exception(msg) - - @telemetry.record_call() - def __enter__(self) -> Self: - """ - Context manager entry point that establishes the connection. - - Returns: - Self: The instance of AskUiControllerClient. 
- """ - self.connect() - return self - - @telemetry.record_call(exclude={"exc_value", "traceback"}) - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_value: BaseException | None, - traceback: types.TracebackType | None, - ) -> None: - """ - Context manager exit point that disconnects the client. - - Args: - exc_type: The exception type if an exception was raised. - exc_value: The exception value if an exception was raised. - traceback: The traceback if an exception was raised. - """ - self.disconnect() - - def _start_session(self) -> None: - response = self._get_stub().StartSession( - controller_v1_pbs.Request_StartSession( - sessionGUID=self._session_guid, immediateExecution=True - ) - ) - self._session_info = response.sessionInfo - - def _stop_session(self) -> None: - self._get_stub().EndSession( - controller_v1_pbs.Request_EndSession(sessionInfo=self._session_info) - ) - - def _start_execution(self) -> None: - self._get_stub().StartExecution( - controller_v1_pbs.Request_StartExecution(sessionInfo=self._session_info) - ) - - def _stop_execution(self) -> None: - self._get_stub().StopExecution( - controller_v1_pbs.Request_StopExecution(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - @override - def screenshot(self, report: bool = True) -> Image.Image: - """ - Take a screenshot of the current screen. - - Args: - report (bool, optional): Whether to include the screenshot in reporting. - Defaults to `True`. - - Returns: - Image.Image: A PIL Image object containing the screenshot. 
- - """ - screenResponse = self._get_stub().CaptureScreen( - controller_v1_pbs.Request_CaptureScreen( - sessionInfo=self._session_info, - captureParameters=controller_v1_pbs.CaptureParameters( - displayID=self._display - ), - ) - ) - r, g, b, _ = Image.frombytes( - "RGBA", - (screenResponse.bitmap.width, screenResponse.bitmap.height), - screenResponse.bitmap.data, - ).split() - image = Image.merge("RGB", (b, g, r)) - self._reporter.add_message("AgentOS", "screenshot()", image) - return image - - @telemetry.record_call() - @override - def mouse_move(self, x: int, y: int, duration: int = 500) -> None: - """ - Moves the mouse cursor to specified screen coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to move to. - y (int): The vertical coordinate (in pixels) to move to. - duration (int): The duration (in ms) the movement should take. - """ - self._reporter.add_message( - "AgentOS", - f"mouse_move({x}, {y}, duration={duration})", - AnnotatedImage(lambda: self.screenshot(report=False), point_list=[(x, y)]), - ) - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseMove, - action_parameters=controller_v1_pbs.ActionParameters( - mouseMove=controller_v1_pbs.ActionParameters_MouseMove( - position=controller_v1_pbs.Coordinate2(x=x, y=y), - milliseconds=duration, - ) - ), - ) - - @telemetry.record_call(exclude={"text"}) - @override - def type(self, text: str, typing_speed: int = 50) -> None: - """ - Type text at current cursor position as if entered on a keyboard. - - Args: - text (str): The text to type. - typing_speed (int, optional): The speed of typing in characters per second. - Defaults to `50`. 
- """ - self._reporter.add_message("AgentOS", f'type("{text}", {typing_speed})') - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardType_UnicodeText, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardTypeUnicodeText=controller_v1_pbs.ActionParameters_KeyboardType_UnicodeText( - text=text.encode("utf-16-le"), - typingSpeed=typing_speed, - typingSpeedValue=controller_v1_pbs.TypingSpeedValue.TypingSpeedValue_CharactersPerSecond, - ) - ), - ) - - @telemetry.record_call() - @override - def click( - self, button: Literal["left", "middle", "right"] = "left", count: int = 1 - ) -> None: - """ - Click a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - click. Defaults to `"left"`. - count (int, optional): Number of times to click. Defaults to `1`. - """ - self._reporter.add_message("AgentOS", f'click("{button}", {count})') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonPressAndRelease=controller_v1_pbs.ActionParameters_MouseButton_PressAndRelease( - mouseButton=mouse_button, count=count - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: - """ - Press and hold a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - press. Defaults to `"left"`. 
- """ - self._reporter.add_message("AgentOS", f'mouse_down("{button}")') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Press, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonPress=controller_v1_pbs.ActionParameters_MouseButton_Press( - mouseButton=mouse_button - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: - """ - Release a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - release. Defaults to `"left"`. - """ - self._reporter.add_message("AgentOS", f'mouse_up("{button}")') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Release, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonRelease=controller_v1_pbs.ActionParameters_MouseButton_Release( - mouseButton=mouse_button - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_scroll(self, dx: int, dy: int) -> None: - """ - Scroll the mouse wheel. - - Args: - dx (int): The horizontal scroll amount. Positive values scroll right, - negative values scroll left. - dy (int): The vertical scroll amount. Positive values scroll down, - negative values scroll up. 
- """ - self._reporter.add_message("AgentOS", f"mouse_scroll({dx}, {dy})") - if dx != 0: - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, - action_parameters=controller_v1_pbs.ActionParameters( - mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( - direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Horizontal, - deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, - delta=dx, - milliseconds=50, - ) - ), - ) - if dy != 0: - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, - action_parameters=controller_v1_pbs.ActionParameters( - mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( - direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Vertical, - deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, - delta=dy, - milliseconds=50, - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_pressed( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None - ) -> None: - """ - Press and hold a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to press. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - press along with the main key. Defaults to `None`. - """ - self._reporter.add_message( - "AgentOS", f'keyboard_pressed("{key}", {modifier_keys})' - ) - if modifier_keys is None: - modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Press, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPress=controller_v1_pbs.ActionParameters_KeyboardKey_Press( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_release( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None - ) -> None: - """ - Release a keyboard key. 
- - Args: - key (PcKey | ModifierKey): The key to release. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - release along with the main key. Defaults to `None`. - """ - self._reporter.add_message( - "AgentOS", f'keyboard_release("{key}", {modifier_keys})' - ) - if modifier_keys is None: - modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Release, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyRelease=controller_v1_pbs.ActionParameters_KeyboardKey_Release( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_tap( - self, - key: PcKey | ModifierKey, - modifier_keys: list[ModifierKey] | None = None, - count: int = 1, - ) -> None: - """ - Press and immediately release a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to tap. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - press along with the main key. Defaults to `None`. - count (int, optional): The number of times to tap the key. Defaults to `1`. - """ - self._reporter.add_message( - "AgentOS", - f'keyboard_tap("{key}", {modifier_keys}, {count})', - ) - if modifier_keys is None: - modifier_keys = [] - for _ in range(count): - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def set_display(self, display: int = 1) -> None: - """ - Set the active display. - - Args: - display (int, optional): The display ID to set as active. - This can be either a real display ID or a virtual display ID. - Defaults to `1`. 
- """ - self._get_stub().SetActiveDisplay( - controller_v1_pbs.Request_SetActiveDisplay(displayID=display) - ) - self._display = display - self._reporter.add_message("AgentOS", f"set_display({display})") - - @telemetry.record_call(exclude={"command"}) - @override - def run_command(self, command: str, timeout_ms: int = 30000) -> None: - """ - Execute a shell command. - - Args: - command (str): The command to execute. - timeout_ms (int, optional): The timeout for command - execution in milliseconds. Defaults to `30000` (30 seconds). - """ - self._reporter.add_message("AgentOS", f'run_command("{command}", {timeout_ms})') - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_RunCommand, - action_parameters=controller_v1_pbs.ActionParameters( - runcommand=controller_v1_pbs.ActionParameters_RunCommand( - command=command, timeoutInMilliseconds=timeout_ms - ) - ), - ) - - @telemetry.record_call() - @override - def retrieve_active_display(self) -> Display: - """ - Retrieve the currently active display/screen. - - Returns: - Display: The currently active display/screen. - """ - self._reporter.add_message("AgentOS", "retrieve_active_display()") - displays_list_response = self.list_displays() - for display in displays_list_response.data: - if display.id == self._display: - self._reporter.add_message( - "AgentOS", f"retrieve_active_display() -> {display}" - ) - return display - error_msg = f"Display {self._display} not found" - raise ValueError(error_msg) - - @telemetry.record_call() - @override - def list_displays( - self, - ) -> DisplaysListResponse: - """ - List all available Displays from the controller. - It includes both real and virtual displays - without describing the type of display (virtual or real). 
- - Returns: - DisplaysListResponse - """ - - self._reporter.add_message("AgentOS", "list_displays()") - - response: controller_v1_pbs.Response_GetDisplayInformation = ( - self._get_stub().GetDisplayInformation(controller_v1_pbs.Request_Void()) - ) - - response_dict = MessageToDict( - response, - preserving_proto_field_name=True, - ) - - displays = DisplaysListResponse.model_validate(response_dict) - - self._reporter.add_message("AgentOS", f"list_displays() ->{str(displays)}") - - return displays - - @telemetry.record_call() - def get_process_list( - self, get_extended_info: bool = False - ) -> controller_v1_pbs.Response_GetProcessList: - """ - Get a list of running processes. - - Args: - get_extended_info (bool, optional): Whether to include - extended process information. - Defaults to `False`. - - Returns: - controller_v1_pbs.Response_GetProcessList: Process list response containing: - - processes: List of ProcessInfo objects - """ - - self._reporter.add_message("AgentOS", f"get_process_list({get_extended_info})") - - response: controller_v1_pbs.Response_GetProcessList = ( - self._get_stub().GetProcessList( - controller_v1_pbs.Request_GetProcessList( - getExtendedInfo=get_extended_info - ) - ) - ) - self._reporter.add_message( - "AgentOS", f"get_process_list({get_extended_info}) -> {response}" - ) - - return response - - @telemetry.record_call() - def get_window_list( - self, process_id: int - ) -> controller_v1_pbs.Response_GetWindowList: - """ - Get a list of windows for a specific process. - - Args: - process_id (int): The ID of the process to get windows for. 
- - Returns: - controller_v1_pbs.Response_GetWindowList: Window list response containing: - - windows: List of WindowInfo objects with ID and name - """ - - self._reporter.add_message("AgentOS", f"get_window_list({process_id})") - - response: controller_v1_pbs.Response_GetWindowList = ( - self._get_stub().GetWindowList( - controller_v1_pbs.Request_GetWindowList(processID=process_id) - ) - ) - - self._reporter.add_message( - "AgentOS", f"get_window_list({process_id}) -> {response}" - ) - - return response - - @telemetry.record_call() - def get_automation_target_list( - self, - ) -> controller_v1_pbs.Response_GetAutomationTargetList: - """ - Get a list of available automation targets. - - Returns: - controller_v1_pbs.Response_GetAutomationTargetList: - Automation target list response: - - targets: List of AutomationTarget objects - """ - - self._reporter.add_message("AgentOS", "get_automation_target_list()") - - response: controller_v1_pbs.Response_GetAutomationTargetList = ( - self._get_stub().GetAutomationTargetList(controller_v1_pbs.Request_Void()) - ) - self._reporter.add_message( - "AgentOS", f"get_automation_target_list() -> {response}" - ) - - return response - - @telemetry.record_call() - def set_mouse_delay(self, delay_ms: int) -> None: - """ - Configure mouse action delay. - - Args: - delay_ms (int): The delay in milliseconds to set for mouse actions. - """ - - self._reporter.add_message("AgentOS", f"set_mouse_delay({delay_ms})") - - self._get_stub().SetMouseDelay( - controller_v1_pbs.Request_SetMouseDelay( - sessionInfo=self._session_info, delayInMilliseconds=delay_ms - ) - ) - - @telemetry.record_call() - def set_keyboard_delay(self, delay_ms: int) -> None: - """ - Configure keyboard action delay. - - Args: - delay_ms (int): The delay in milliseconds to set for keyboard actions. 
- """ - - self._reporter.add_message("AgentOS", f"set_keyboard_delay({delay_ms})") - - self._get_stub().SetKeyboardDelay( - controller_v1_pbs.Request_SetKeyboardDelay( - sessionInfo=self._session_info, delayInMilliseconds=delay_ms - ) - ) - - @telemetry.record_call() - def set_active_window(self, process_id: int, window_id: int) -> int: - """ - Set the active window for automation. - Adds the window as a virtual display and returns the display ID. - It raises an error if display length is not increased after adding the window. - - Args: - process_id (int): The ID of the process that owns the window. - window_id (int): The ID of the window to set as active. - - returns: - int: The new Display ID. - Raises: - AskUiControllerError: - If display length is not increased after adding the window. - """ - - self._reporter.add_message( - "AgentOS", f"set_active_window({process_id}, {window_id})" - ) - - display_length_before_adding_window = len(self.list_displays().data) - - self._get_stub().SetActiveWindow( - controller_v1_pbs.Request_SetActiveWindow( - processID=process_id, windowID=window_id - ) - ) - new_display_length = len(self.list_displays().data) - if new_display_length <= display_length_before_adding_window: - msg = f"Failed to set active window {window_id} for process {process_id}" - raise AskUiControllerError(msg) - self._reporter.add_message( - "AgentOS", - f"set_active_window({process_id}, {window_id}) -> {new_display_length}", - ) - return new_display_length - - @telemetry.record_call() - def set_active_automation_target(self, target_id: int) -> None: - """ - Set the active automation target. - - Args: - target_id (int): The ID of the automation target to set as active. 
- """ - - self._reporter.add_message( - "AgentOS", f"set_active_automation_target({target_id})" - ) - - self._get_stub().SetActiveAutomationTarget( - controller_v1_pbs.Request_SetActiveAutomationTarget(ID=target_id) - ) - - @telemetry.record_call() - def schedule_batched_action( - self, - action_class_id: controller_v1_pbs.ActionClassID, - action_parameters: controller_v1_pbs.ActionParameters, - ) -> controller_v1_pbs.Response_ScheduleBatchedAction: - """ - Schedule an action for batch execution. - - Args: - action_class_id (controller_v1_pbs.ActionClassID): The class ID - of the action to schedule. - action_parameters (controller_v1_pbs.ActionParameters): - Parameters for the action. - - Returns: - controller_v1_pbs.Response_ScheduleBatchedAction: Response containing - the scheduled action ID. - """ - - self._reporter.add_message( - "AgentOS", - f"schedule_batched_action({action_class_id}, {action_parameters})", - ) - - response: controller_v1_pbs.Response_ScheduleBatchedAction = ( - self._get_stub().ScheduleBatchedAction( - controller_v1_pbs.Request_ScheduleBatchedAction( - sessionInfo=self._session_info, - actionClassID=action_class_id, - actionParameters=action_parameters, - ) - ) - ) - - return response - - @telemetry.record_call() - def start_batch_run(self) -> None: - """ - Start executing batched actions. - """ - - self._reporter.add_message("AgentOS", "start_batch_run()") - - self._get_stub().StartBatchRun( - controller_v1_pbs.Request_StartBatchRun(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - def stop_batch_run(self) -> None: - """ - Stop executing batched actions. - """ - - self._reporter.add_message("AgentOS", "stop_batch_run()") - - self._get_stub().StopBatchRun( - controller_v1_pbs.Request_StopBatchRun(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: - """ - Get the count of recorded or batched actions. 
- - Returns: - controller_v1_pbs.Response_GetActionCount: Response - containing the action count. - """ - - response: controller_v1_pbs.Response_GetActionCount = ( - self._get_stub().GetActionCount( - controller_v1_pbs.Request_GetActionCount(sessionInfo=self._session_info) - ) - ) - self._reporter.add_message("AgentOS", f"get_action_count() -> {response}") - return response - - @telemetry.record_call() - def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: - """ - Get a specific action by its index. - - Args: - action_index (int): The index of the action to retrieve. - - Returns: - controller_v1_pbs.Response_GetAction: Action information containing: - - actionID: The action ID - - actionClassID: The action class ID - - actionParameters: The action parameters - """ - - self._reporter.add_message("AgentOS", f"get_action({action_index})") - - response: controller_v1_pbs.Response_GetAction = self._get_stub().GetAction( - controller_v1_pbs.Request_GetAction( - sessionInfo=self._session_info, actionIndex=action_index - ) - ) - - return response - - @telemetry.record_call() - def remove_action(self, action_id: int) -> None: - """ - Remove a specific action by its ID. - - Args: - action_id (int): The ID of the action to remove. - """ - - self._reporter.add_message("AgentOS", f"remove_action({action_id})") - - self._get_stub().RemoveAction( - controller_v1_pbs.Request_RemoveAction( - sessionInfo=self._session_info, actionID=action_id - ) - ) - - @telemetry.record_call() - def remove_all_actions(self) -> None: - """ - Clear all recorded or batched actions. - """ - - self._reporter.add_message("AgentOS", "remove_all_actions()") - - self._get_stub().RemoveAllActions( - controller_v1_pbs.Request_RemoveAllActions(sessionInfo=self._session_info) - ) - - def _send_command(self, command: Command) -> AskUIAgentOSSendResponseSchema: - """ - Send a general command to the controller. - - Args: - command (Command): The command to send to the controller. 
- - Returns: - AskUIAgentOSSendResponseSchema: Response containing - the message from the controller. - - Raises: - AskUiControllerInvalidCommandError: If the command fails schema validation - on the server side. - """ - - header = Header(authentication=Guid(root=self._session_guid)) - message = Message(header=header, command=command) - - request = AskUIAgentOSSendRequestSchema(message=message) - - request_str = request.model_dump_json(exclude_none=True, by_alias=True) - - try: - response: controller_v1_pbs.Response_Send = self._get_stub().Send( - controller_v1_pbs.Request_Send(message=request_str) - ) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - details = e.details() or None - raise AskUiControllerInvalidCommandError(details) from e - raise - - return AskUIAgentOSSendResponseSchema.model_validate_json(response.message) - - @telemetry.record_call() - def get_mouse_position(self) -> Coordinate: - """ - Get the mouse cursor position - - Returns: - Coordinate: Response containing the result of the mouse position change. - """ - self._reporter.add_message("AgentOS", "get_mouse_position()") - res = self._send_command(GetMousePositionCommand()) - coordinate = Coordinate( - x=res.message.command.response.position.x.root, # type: ignore[union-attr] - y=res.message.command.response.position.y.root, # type: ignore[union-attr] - ) - self._reporter.add_message("AgentOS", f"get_mouse_position() -> {coordinate}") - return coordinate - - @telemetry.record_call() - def set_mouse_position(self, x: int, y: int) -> None: - """ - Set the mouse cursor position to specific coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to set the cursor to. - y (int): The vertical coordinate (in pixels) to set the cursor to. 
- """ - location = Location(x=Length(root=x), y=Length(root=y)) - command = SetMousePositionCommand(parameters=[location]) - self._reporter.add_message("AgentOS", f"set_mouse_position({x},{y})") - self._send_command(command) - - @telemetry.record_call() - def render_quad(self, style: RenderObjectStyle) -> int: - """ - Render a quad object to the display. - - Args: - style (RenderObjectStyle): The style properties for the quad. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_quad({style})") - command = AddRenderObjectCommand(parameters=["Quad", style]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: - """ - Render a line object to the display. - - Args: - style (RenderObjectStyle): The style properties for the line. - points (list[Coordinates]): The points defining the line. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_line({style}, {points})") - command = AddRenderObjectCommand(parameters=["Line", style, points]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call(exclude={"image_data"}) - def render_image(self, style: RenderObjectStyle, image_data: str) -> int: - """ - Render an image object to the display. - - Args: - style (RenderObjectStyle): The style properties for the image. - image_data (str): The base64-encoded image data. - - Returns: - int: Object ID. 
- """ - self._reporter.add_message("AgentOS", f"render_image({style}, [image_data])") - image = RenderImage(root=image_data) - command = AddRenderObjectCommand(parameters=["Image", style, image]) - res = self._send_command(command) - - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def render_text(self, style: RenderObjectStyle, content: str) -> int: - """ - Render a text object to the display. - - Args: - style (RenderObjectStyle): The style properties for the text. - content (str): The text content to display. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_text({style}, {content})") - text = RenderText(root=content) - command = AddRenderObjectCommand(parameters=["Text", style, text]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def update_render_object(self, object_id: int, style: RenderObjectStyle) -> None: - """ - Update styling properties of an existing render object. - - Args: - object_id (float): The ID of the render object to update. - style (RenderObjectStyle): The new style properties. - - Returns: - int: Object ID. - """ - self._reporter.add_message( - "AgentOS", f"update_render_object({object_id}, {style})" - ) - render_object_id = RenderObjectId(root=object_id) - command = UpdateRenderObjectCommand(parameters=[render_object_id, style]) - self._send_command(command) - - @telemetry.record_call() - def delete_render_object(self, object_id: int) -> None: - """ - Delete an existing render object from the display. - - Args: - object_id (RenderObjectId): The ID of the render object to delete. 
- """ - self._reporter.add_message("AgentOS", f"delete_render_object({object_id})") - render_object_id = RenderObjectId(root=object_id) - command = DeleteRenderObjectCommand(parameters=[render_object_id]) - self._send_command(command) - - @telemetry.record_call() - def clear_render_objects(self) -> None: - """ - Clear all render objects from the display. - """ - self._reporter.add_message("AgentOS", "clear_render_objects()") - command = ClearRenderObjectsCommand() - self._send_command(command) - - def get_system_info(self) -> GetSystemInfoResponseModel: - """ - Get the system information. - - Returns: - SystemInfo: The system information. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_system_info()") - command = GetSystemInfoCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetSystemInfoResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_system_info() -> {res.response}") - return res.response - - def get_active_process(self) -> GetActiveProcessResponseModel: - """ - Get the active process. - - Returns: - GetActiveProcessResponseModel: The active process. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_active_process()") - command = GetActiveProcessCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetActiveProcessResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_active_process() -> {res.response}") - return res.response - - def set_active_process(self, process_id: int) -> None: - """ - Set the active process. - - Args: - process_id (int): The ID of the process to set as active. 
- """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", f"set_active_process({process_id})") - _process_id = Parameter3(root=process_id) - command = SetActiveProcessCommand(parameters=[_process_id]) - self._send_command(command) - - def get_active_window(self) -> GetActiveWindowResponseModel: - """ - Gets the window id and name in addition to the process id - and name of the currently active window (in focus). - - - Returns: - GetActiveWindowResponseModel: The active window. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_active_window()") - command = GetActiveWindowCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetActiveWindowResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_active_window() -> {res.response}") - return res.response - - def set_window_in_focus(self, process_id: int, window_id: int) -> None: - """ - Sets the window with the specified windowId of the process - with the specified processId active, - which brings it to the front and gives it focus. - - Args: - process_id (int): The ID of the process that owns the window. - window_id (int): The ID of the window to set as active. 
- """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message( - "AgentOS", f"set_window_in_focus({process_id}, {window_id})" - ) - _process_id = Parameter3(root=process_id) - _window_id = Parameter3(root=window_id) - command = SetActiveWindowCommand(parameters=[_process_id, _window_id]) - self._send_command(command) +import logging +import time +import types +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Literal, Type + +import grpc +from google.protobuf.json_format import MessageToDict +from PIL import Image +from typing_extensions import Self, override + +from askui.container import telemetry +from askui.reporting import NULL_REPORTER, Reporter +from askui.tools.agent_os import ( + AgentOs, + Coordinate, + Display, + DisplaysListResponse, + ModifierKey, + PcKey, +) +from askui.tools.askui.agent_os_server import ( + AgentOsServer, + LocalAgentOsServer, + RemoteAgentOsServer, +) +from askui.tools.askui.agent_os_server_manager import ( + AgentOsServerManager, +) +from askui.tools.askui.askui_ui_controller_grpc.desktop_agent_os_error import ( + DesktopAgentOsError, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2 as controller_v1_pbs, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2_grpc as controller_v1, +) +from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Request_2501 import ( # noqa: E501 + AddRenderObjectCommand, + AskUIAgentOSSendRequestSchema, + ClearRenderObjectsCommand, + Command, + DeleteRenderObjectCommand, + GetActiveProcessCommand, + GetActiveWindowCommand, + GetMousePositionCommand, + GetSystemInfoCommand, + Guid, + Header, + Length, + Location, + Message, + Parameter3, + RenderImage, + RenderObjectId, + RenderObjectStyle, + RenderText, + SetActiveProcessCommand, + SetActiveWindowCommand, + 
@dataclass
class _Connection:
    """gRPC connection state for a single Agent OS server.

    One instance is stored per connected server, keyed by the server's
    `session_guid`, so that everything opened while connecting can be released
    again when the server is disconnected.
    """

    # Channel opened to the server's address; closed again on disconnect.
    channel: grpc.Channel
    # Controller API stub created on top of `channel`; used for the RPCs sent
    # to this server.
    stub: controller_v1.ControllerAPIStub
    # Session info taken from the `StartSession` response; passed along with
    # session-scoped requests.
    session_info: controller_v1_pbs.SessionInfo
    # `True` when connecting started the local controller process, in which
    # case disconnecting also stops that process.
    started_process: bool
@telemetry.record_call(exclude={"reporter", "agent_os_servers"})
def __init__(
    self,
    reporter: Reporter = NULL_REPORTER,
    display: int = 1,
    agent_os_servers: list[AgentOsServer] | None = None,
) -> None:
    """
    Set up the client without opening any connection yet.

    Args:
        reporter (Reporter, optional): Reporter that receives the messages
            emitted by this client. Defaults to `NULL_REPORTER`.
        display (int, optional): Display used for the default local server.
            Only read when no `agent_os_servers` are given. Defaults to `1`.
        agent_os_servers (list[AgentOsServer] | None, optional): Servers to
            register. When `None` or empty, a single `LocalAgentOsServer` for
            `display` is registered instead.
    """
    # Fall back to one local server when the caller supplied nothing usable.
    servers = agent_os_servers or [LocalAgentOsServer(display=display)]

    # Open connections, keyed by the session GUID of the server they belong to.
    self._connections: dict[str, _Connection] = {}
    # Timing knobs for running and polling recorded actions (seconds).
    self._pre_action_wait = 0
    self._post_action_wait = 0.05
    self._max_retries = 10
    self._reporter = reporter
    self._manager = AgentOsServerManager(agent_os_servers=servers)
@telemetry.record_call(exclude={"agent_os_servers"})
@override
def reset_agent_os_servers(
    self,
    agent_os_servers: list[AgentOsServer] | None = None,
) -> None:
    """
    Disconnect (if connected) and replace the Agent-OS-server list.

    If the client was connected before the reset, it reconnects to the new
    servers afterwards. When the new list is empty there is nothing to
    reconnect to, so the client is left disconnected instead of failing
    inside `connect()`.

    Args:
        agent_os_servers (list[AgentOsServer] | None, optional):
            New list of Agent OS servers to register after the reset. If `None`
            (or empty), the list is left empty and a subsequent `connect()`
            will fail until at least one server has been registered again.
            Same validation rules as the constructor (at most one local,
            unique remote addresses).
    """
    self._reporter.add_message(
        self._REPORTER_SOURCE, f"reset_agent_os_servers({agent_os_servers!r})"
    )
    was_connected = self.is_connected
    if was_connected:
        self.disconnect()
    self._manager.reset()
    for server in agent_os_servers or []:
        self._manager.add(server)
    # `connect()` raises when no server is registered. Resetting to an empty
    # list is a documented use case, so only reconnect when there is a target.
    if was_connected and self._manager.list():
        self.connect()
+ """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"switch_agent_os_server({computer_id!r})" + ) + server = self._manager.switch(computer_id) + if self.is_connected and server.session_guid not in self._connections: + self._connect_server(server) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"switch_agent_os_server({computer_id!r}) -> {server!r}", + ) + return server + + @contextmanager + @override + def temporary_select(self, computer_id: str) -> Iterator[Self]: + previous = self._manager.active + self._reporter.add_message( + self._REPORTER_SOURCE, + f"temporary_select({computer_id!r}) [previous={previous!r}]", + ) + self.switch_agent_os_server(computer_id) + try: + yield self + finally: + if previous is not None and previous.computer_id != computer_id: + self.switch_agent_os_server(previous.computer_id) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"temporary_select({computer_id!r}) -> restored", + ) + + @telemetry.record_call() + @override + def connect(self) -> None: + """ + Open a gRPC channel and session to every registered Agent OS server. + + For each server: starts the local process when `is_local` and `is_service` + is `False`, opens an insecure gRPC channel, starts a session, starts + execution, and sets the configured display. Servers already connected are + skipped, so calling `connect()` twice is safe. + + On failure mid-loop, all servers connected so far are rolled back via + `disconnect()` before re-raising. + """ + if not self._manager.list(): + error_msg = ( + "Cannot connect: no Agent OS servers registered. Provide at least " + "one via the `AskUiControllerClient` constructor's `agent_os_servers` " + "argument, or call `add_agent_os_server()` / " + "`add_remote_agent_os_server()` before `connect()`." 
def _run_recorder_action(
    self,
    acion_class_id: controller_v1_pbs.ActionClassID,
    action_parameters: controller_v1_pbs.ActionParameters,
) -> controller_v1_pbs.Response_RunRecordedAction:
    """
    Run one recorded action on the active Agent OS server and wait until the
    server reports it as finished.

    The action is submitted, the method sleeps for the duration the server
    announces (`requiredMilliseconds`), and then polls for the
    "action finished" event up to `self._max_retries` times, sleeping
    `self._post_action_wait` seconds after every unsuccessful poll.

    Args:
        acion_class_id (controller_v1_pbs.ActionClassID): Class id of the
            action to run. (The parameter name is kept as-is because callers
            pass it by keyword.)
        action_parameters (controller_v1_pbs.ActionParameters): Parameters
            of the action.

    Returns:
        controller_v1_pbs.Response_RunRecordedAction: The response to the
            submitted action.

    Raises:
        AskUiControllerError: If there is no active server or it is not
            connected.
        AskUiControllerOperationTimeoutError: If none of the polls reported
            the action as finished.
    """
    time.sleep(self._pre_action_wait)
    # Look the connection up once; every request of this action uses it.
    stub = self._get_stub()
    session_info = self._session_info
    response: controller_v1_pbs.Response_RunRecordedAction = stub.RunRecordedAction(
        controller_v1_pbs.Request_RunRecordedAction(
            sessionInfo=session_info,
            actionClassID=acion_class_id,
            actionParameters=action_parameters,
        )
    )

    time.sleep(response.requiredMilliseconds / 1000)
    for _ in range(self._max_retries):
        poll_response: controller_v1_pbs.Response_Poll = stub.Poll(
            controller_v1_pbs.Request_Poll(
                sessionInfo=session_info,
                pollEventID=controller_v1_pbs.PollEventID.PollEventID_ActionFinished,
            )
        )
        if (
            poll_response.pollEventParameters.actionFinished.actionID
            == response.actionID
        ):
            break
        time.sleep(self._post_action_wait)
    else:
        # Only reached when the loop ran out of polls without a `break`, so a
        # success on the very last poll is never mistaken for a timeout.
        server = self._require_active_server()
        timeout_seconds = self._max_retries * self._post_action_wait
        timeout_msg = (
            f"Action did not finish on Agent OS server {server.description!r} "
            f"(session_guid={server.session_guid}) within "
            f"{timeout_seconds:.2f}s ({self._max_retries} polls of "
            f"{self._post_action_wait:.2f}s). "
            f"Action class id: {acion_class_id}."
        )
        raise AskUiControllerOperationTimeoutError(
            message=timeout_msg, timeout_seconds=timeout_seconds
        )
    return response
@telemetry.record_call()
@override
def screenshot(self, report: bool = True) -> Image.Image:
    """
    Take a screenshot of the display configured on the active Agent OS server.

    Args:
        report (bool, optional): Whether to include the screenshot in reporting.
            When `False`, nothing is sent to the reporter (used e.g. when the
            screenshot only serves as background for another report message).
            Defaults to `True`.

    Returns:
        Image.Image: A PIL Image object (mode `"RGB"`) containing the screenshot.

    Raises:
        AskUiControllerError: If there is no active server or it is not
            connected.
    """
    screen_response = self._get_stub().CaptureScreen(
        controller_v1_pbs.Request_CaptureScreen(
            sessionInfo=self._session_info,
            captureParameters=controller_v1_pbs.CaptureParameters(
                displayID=self._require_active_server().display
            ),
        )
    )
    bitmap = screen_response.bitmap
    # The raw data is decoded as four channels, then the first and third are
    # swapped and the fourth dropped - presumably the controller delivers
    # BGRA pixels (TODO confirm against the controller API).
    r, g, b, _ = Image.frombytes(
        "RGBA",
        (bitmap.width, bitmap.height),
        bitmap.data,
    ).split()
    image = Image.merge("RGB", (b, g, r))
    if report:
        self._reporter.add_message(self._REPORTER_SOURCE, "screenshot()", image)
    return image
@telemetry.record_call()
@override
def click(
    self, button: Literal["left", "middle", "right"] = "left", count: int = 1
) -> None:
    """
    Press and release a mouse button one or more times.

    Args:
        button (Literal["left", "middle", "right"], optional): Which mouse
            button to click. Defaults to `"left"`.
        count (int, optional): How many clicks to perform. Defaults to `1`.
    """
    self._reporter.add_message(self._REPORTER_SOURCE, f'click("{button}", {count})')
    # Translate the button name into the controller's enum value; an
    # unrecognised name is forwarded as `None`.
    if button == "left":
        pb_button = controller_v1_pbs.MouseButton_Left
    elif button == "middle":
        pb_button = controller_v1_pbs.MouseButton_Middle
    elif button == "right":
        pb_button = controller_v1_pbs.MouseButton_Right
    else:
        pb_button = None
    self._run_recorder_action(
        acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_PressAndRelease,
        action_parameters=controller_v1_pbs.ActionParameters(
            mouseButtonPressAndRelease=controller_v1_pbs.ActionParameters_MouseButton_PressAndRelease(
                mouseButton=pb_button, count=count
            )
        ),
    )
@telemetry.record_call()
@override
def mouse_scroll(self, dx: int, dy: int) -> None:
    """
    Turn the mouse wheel horizontally and/or vertically.

    The horizontal scroll (if any) is sent first, followed by the vertical
    one. An axis whose amount is `0` is skipped entirely.

    Args:
        dx (int): The horizontal scroll amount. Positive values scroll right,
            negative values scroll left.
        dy (int): The vertical scroll amount. Positive values scroll down,
            negative values scroll up.
    """
    self._reporter.add_message(self._REPORTER_SOURCE, f"mouse_scroll({dx}, {dy})")
    directions = controller_v1_pbs.MouseWheelScrollDirection
    per_axis = (
        (directions.MouseWheelScrollDirection_Horizontal, dx),
        (directions.MouseWheelScrollDirection_Vertical, dy),
    )
    for direction, amount in per_axis:
        if amount == 0:
            continue
        self._run_recorder_action(
            acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll,
            action_parameters=controller_v1_pbs.ActionParameters(
                mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll(
                    direction=direction,
                    deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw,
                    delta=amount,
                    milliseconds=50,
                )
            ),
        )
@telemetry.record_call()
@override
def keyboard_release(
    self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None
) -> None:
    """
    Let go of a keyboard key, optionally together with modifier keys.

    Args:
        key (PcKey | ModifierKey): The key to release.
        modifier_keys (list[ModifierKey] | None, optional): Modifier keys that
            are released along with `key`. `None` means no modifiers.
            Defaults to `None`.
    """
    # The report shows the argument exactly as passed (possibly `None`).
    self._reporter.add_message(
        self._REPORTER_SOURCE, f'keyboard_release("{key}", {modifier_keys})'
    )
    modifiers = [] if modifier_keys is None else modifier_keys
    action_class = controller_v1_pbs.ActionClassID_KeyboardKey_Release
    release_parameters = controller_v1_pbs.ActionParameters_KeyboardKey_Release(
        keyName=key, modifierKeyNames=modifiers
    )
    self._run_recorder_action(
        acion_class_id=action_class,
        action_parameters=controller_v1_pbs.ActionParameters(
            keyboardKeyRelease=release_parameters
        ),
    )
@telemetry.record_call()
@override
def retrieve_active_display(self) -> Display:
    """
    Look up the display that is configured on the active Agent OS server.

    The display id stored on the active server is matched against the
    displays returned by `list_displays()`.

    Returns:
        Display: The display whose id equals the active server's display id.

    Raises:
        ValueError: If no listed display has that id.
    """
    self._reporter.add_message(self._REPORTER_SOURCE, "retrieve_active_display()")
    server = self._require_active_server()
    active_display_id = server.display
    listed = self.list_displays()
    found = next((d for d in listed.data if d.id == active_display_id), None)
    if found is not None:
        self._reporter.add_message(
            self._REPORTER_SOURCE, f"retrieve_active_display() -> {found}"
        )
        return found
    # Nothing matched: name the ids that do exist to make the error actionable.
    available_ids = ", ".join(str(d.id) for d in listed.data) or "none"
    error_msg = (
        f"Display {active_display_id} not found on Agent OS server "
        f"{server.description!r} (session_guid={server.session_guid}). "
        f"Available display ids: {available_ids}. "
        "Call `set_display()` with a valid id, or `list_displays()` to inspect."
    )
    raise ValueError(error_msg)
+ + Returns: + controller_v1_pbs.Response_GetProcessList: Process list response containing: + - processes: List of ProcessInfo objects + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_process_list({get_extended_info})" + ) + + response: controller_v1_pbs.Response_GetProcessList = ( + self._get_stub().GetProcessList( + controller_v1_pbs.Request_GetProcessList( + getExtendedInfo=get_extended_info + ) + ) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_process_list({get_extended_info}) -> {response}", + ) + + return response + + @telemetry.record_call() + def get_window_list( + self, process_id: int + ) -> controller_v1_pbs.Response_GetWindowList: + """ + Get a list of windows for a specific process. + + Args: + process_id (int): The ID of the process to get windows for. + + Returns: + controller_v1_pbs.Response_GetWindowList: Window list response containing: + - windows: List of WindowInfo objects with ID and name + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_window_list({process_id})" + ) + + response: controller_v1_pbs.Response_GetWindowList = ( + self._get_stub().GetWindowList( + controller_v1_pbs.Request_GetWindowList(processID=process_id) + ) + ) + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_window_list({process_id}) -> {response}" + ) + + return response + + @telemetry.record_call() + def get_automation_target_list( + self, + ) -> controller_v1_pbs.Response_GetAutomationTargetList: + """ + Get a list of available automation targets. 
+ + Returns: + controller_v1_pbs.Response_GetAutomationTargetList: + Automation target list response: + - targets: List of AutomationTarget objects + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, "get_automation_target_list()" + ) + + response: controller_v1_pbs.Response_GetAutomationTargetList = ( + self._get_stub().GetAutomationTargetList(controller_v1_pbs.Request_Void()) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_automation_target_list() -> {response}" + ) + + return response + + @telemetry.record_call() + def set_mouse_delay(self, delay_ms: int) -> None: + """ + Configure mouse action delay. + + Args: + delay_ms (int): The delay in milliseconds to set for mouse actions. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_mouse_delay({delay_ms})" + ) + + self._get_stub().SetMouseDelay( + controller_v1_pbs.Request_SetMouseDelay( + sessionInfo=self._session_info, delayInMilliseconds=delay_ms + ) + ) + + @telemetry.record_call() + def set_keyboard_delay(self, delay_ms: int) -> None: + """ + Configure keyboard action delay. + + Args: + delay_ms (int): The delay in milliseconds to set for keyboard actions. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_keyboard_delay({delay_ms})" + ) + + self._get_stub().SetKeyboardDelay( + controller_v1_pbs.Request_SetKeyboardDelay( + sessionInfo=self._session_info, delayInMilliseconds=delay_ms + ) + ) + + @telemetry.record_call() + def set_active_window(self, process_id: int, window_id: int) -> int: + """ + Set the active window for automation. + Adds the window as a virtual display and returns the display ID. + It raises an error if display length is not increased after adding the window. + + Args: + process_id (int): The ID of the process that owns the window. + window_id (int): The ID of the window to set as active. + + returns: + int: The new Display ID. + Raises: + AskUiControllerError: + If display length is not increased after adding the window. 
+ """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_window({process_id}, {window_id})" + ) + + display_length_before_adding_window = len(self.list_displays().data) + + self._get_stub().SetActiveWindow( + controller_v1_pbs.Request_SetActiveWindow( + processID=process_id, windowID=window_id + ) + ) + new_display_length = len(self.list_displays().data) + if new_display_length <= display_length_before_adding_window: + msg = ( + f"Failed to add window {window_id} of process {process_id} as a " + f"virtual display: display count did not increase " + f"({display_length_before_adding_window} -> {new_display_length}). " + "Verify the process and window ids exist and are valid for the " + "active Agent OS server." + ) + raise AskUiControllerError(msg) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"set_active_window({process_id}, {window_id}) -> {new_display_length}", + ) + return new_display_length + + @telemetry.record_call() + def set_active_automation_target(self, target_id: int) -> None: + """ + Set the active automation target. + + Args: + target_id (int): The ID of the automation target to set as active. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_automation_target({target_id})" + ) + + self._get_stub().SetActiveAutomationTarget( + controller_v1_pbs.Request_SetActiveAutomationTarget(ID=target_id) + ) + + @telemetry.record_call() + def schedule_batched_action( + self, + action_class_id: controller_v1_pbs.ActionClassID, + action_parameters: controller_v1_pbs.ActionParameters, + ) -> controller_v1_pbs.Response_ScheduleBatchedAction: + """ + Schedule an action for batch execution. + + Args: + action_class_id (controller_v1_pbs.ActionClassID): The class ID + of the action to schedule. + action_parameters (controller_v1_pbs.ActionParameters): + Parameters for the action. + + Returns: + controller_v1_pbs.Response_ScheduleBatchedAction: Response containing + the scheduled action ID. 
+ """ + + self._reporter.add_message( + self._REPORTER_SOURCE, + f"schedule_batched_action({action_class_id}, {action_parameters})", + ) + + response: controller_v1_pbs.Response_ScheduleBatchedAction = ( + self._get_stub().ScheduleBatchedAction( + controller_v1_pbs.Request_ScheduleBatchedAction( + sessionInfo=self._session_info, + actionClassID=action_class_id, + actionParameters=action_parameters, + ) + ) + ) + + return response + + @telemetry.record_call() + def start_batch_run(self) -> None: + """ + Start executing batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "start_batch_run()") + + self._get_stub().StartBatchRun( + controller_v1_pbs.Request_StartBatchRun(sessionInfo=self._session_info) + ) + + @telemetry.record_call() + def stop_batch_run(self) -> None: + """ + Stop executing batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "stop_batch_run()") + + self._get_stub().StopBatchRun( + controller_v1_pbs.Request_StopBatchRun(sessionInfo=self._session_info) + ) + + @telemetry.record_call() + def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: + """ + Get the count of recorded or batched actions. + + Returns: + controller_v1_pbs.Response_GetActionCount: Response + containing the action count. + """ + + response: controller_v1_pbs.Response_GetActionCount = ( + self._get_stub().GetActionCount( + controller_v1_pbs.Request_GetActionCount(sessionInfo=self._session_info) + ) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_action_count() -> {response}" + ) + return response + + @telemetry.record_call() + def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: + """ + Get a specific action by its index. + + Args: + action_index (int): The index of the action to retrieve. 
+ + Returns: + controller_v1_pbs.Response_GetAction: Action information containing: + - actionID: The action ID + - actionClassID: The action class ID + - actionParameters: The action parameters + """ + + self._reporter.add_message(self._REPORTER_SOURCE, f"get_action({action_index})") + + response: controller_v1_pbs.Response_GetAction = self._get_stub().GetAction( + controller_v1_pbs.Request_GetAction( + sessionInfo=self._session_info, actionIndex=action_index + ) + ) + + return response + + @telemetry.record_call() + def remove_action(self, action_id: int) -> None: + """ + Remove a specific action by its ID. + + Args: + action_id (int): The ID of the action to remove. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, f"remove_action({action_id})") + + self._get_stub().RemoveAction( + controller_v1_pbs.Request_RemoveAction( + sessionInfo=self._session_info, actionID=action_id + ) + ) + + @telemetry.record_call() + def remove_all_actions(self) -> None: + """ + Clear all recorded or batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "remove_all_actions()") + + self._get_stub().RemoveAllActions( + controller_v1_pbs.Request_RemoveAllActions(sessionInfo=self._session_info) + ) + + def _send_command(self, command: Command) -> AskUIAgentOSSendResponseSchema: + """ + Send a general command to the controller. + + Args: + command (Command): The command to send to the controller. + + Returns: + AskUIAgentOSSendResponseSchema: Response containing + the message from the controller. + + Raises: + AskUiControllerInvalidCommandError: If the command fails schema validation + on the server side. 
+ """ + + server = self._require_active_server() + header = Header(authentication=Guid(root=server.session_guid)) + message = Message(header=header, command=command) + + request = AskUIAgentOSSendRequestSchema(message=message) + + request_str = request.model_dump_json(exclude_none=True, by_alias=True) + + try: + response: controller_v1_pbs.Response_Send = self._get_stub().Send( + controller_v1_pbs.Request_Send(message=request_str) + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.INVALID_ARGUMENT: + details = e.details() or None + raise AskUiControllerInvalidCommandError(details) from e + raise + + return AskUIAgentOSSendResponseSchema.model_validate_json(response.message) + + @telemetry.record_call() + def get_mouse_position(self) -> Coordinate: + """ + Get the mouse cursor position + + Returns: + Coordinate: Response containing the result of the mouse position change. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_mouse_position()") + res = self._send_command(GetMousePositionCommand()) + coordinate = Coordinate( + x=res.message.command.response.position.x.root, # type: ignore[union-attr] + y=res.message.command.response.position.y.root, # type: ignore[union-attr] + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_mouse_position() -> {coordinate}" + ) + return coordinate + + @telemetry.record_call() + def set_mouse_position(self, x: int, y: int) -> None: + """ + Set the mouse cursor position to specific coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to set the cursor to. + y (int): The vertical coordinate (in pixels) to set the cursor to. 
+ """ + location = Location(x=Length(root=x), y=Length(root=y)) + command = SetMousePositionCommand(parameters=[location]) + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_mouse_position({x},{y})" + ) + self._send_command(command) + + @telemetry.record_call() + def render_quad(self, style: RenderObjectStyle) -> int: + """ + Render a quad object to the display. + + Args: + style (RenderObjectStyle): The style properties for the quad. + + Returns: + int: Object ID. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"render_quad({style})") + command = AddRenderObjectCommand(parameters=["Quad", style]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: + """ + Render a line object to the display. + + Args: + style (RenderObjectStyle): The style properties for the line. + points (list[Coordinates]): The points defining the line. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_line({style}, {points})" + ) + command = AddRenderObjectCommand(parameters=["Line", style, points]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call(exclude={"image_data"}) + def render_image(self, style: RenderObjectStyle, image_data: str) -> int: + """ + Render an image object to the display. + + Args: + style (RenderObjectStyle): The style properties for the image. + image_data (str): The base64-encoded image data. + + Returns: + int: Object ID. 
+ """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_image({style}, [image_data])" + ) + image = RenderImage(root=image_data) + command = AddRenderObjectCommand(parameters=["Image", style, image]) + res = self._send_command(command) + + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def render_text(self, style: RenderObjectStyle, content: str) -> int: + """ + Render a text object to the display. + + Args: + style (RenderObjectStyle): The style properties for the text. + content (str): The text content to display. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_text({style}, {content})" + ) + text = RenderText(root=content) + command = AddRenderObjectCommand(parameters=["Text", style, text]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def update_render_object(self, object_id: int, style: RenderObjectStyle) -> None: + """ + Update styling properties of an existing render object. + + Args: + object_id (float): The ID of the render object to update. + style (RenderObjectStyle): The new style properties. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"update_render_object({object_id}, {style})" + ) + render_object_id = RenderObjectId(root=object_id) + command = UpdateRenderObjectCommand(parameters=[render_object_id, style]) + self._send_command(command) + + @telemetry.record_call() + def delete_render_object(self, object_id: int) -> None: + """ + Delete an existing render object from the display. + + Args: + object_id (RenderObjectId): The ID of the render object to delete. 
+ """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"delete_render_object({object_id})" + ) + render_object_id = RenderObjectId(root=object_id) + command = DeleteRenderObjectCommand(parameters=[render_object_id]) + self._send_command(command) + + @telemetry.record_call() + def clear_render_objects(self) -> None: + """ + Clear all render objects from the display. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "clear_render_objects()") + command = ClearRenderObjectsCommand() + self._send_command(command) + + def get_system_info(self) -> GetSystemInfoResponseModel: + """ + Get the system information. + + Returns: + SystemInfo: The system information. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_system_info()") + command = GetSystemInfoCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetSystemInfoResponse): + message = ( + f"get_system_info: expected GetSystemInfoResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_system_info() -> {res.response}" + ) + return res.response + + def get_active_process(self) -> GetActiveProcessResponseModel: + """ + Get the active process. + + Returns: + GetActiveProcessResponseModel: The active process. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_active_process()") + command = GetActiveProcessCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetActiveProcessResponse): + message = ( + f"get_active_process: expected GetActiveProcessResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_active_process() -> {res.response}" + ) + return res.response + + def set_active_process(self, process_id: int) -> None: + """ + Set the active process. 
+ + Args: + process_id (int): The ID of the process to set as active. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_process({process_id})" + ) + _process_id = Parameter3(root=process_id) + command = SetActiveProcessCommand(parameters=[_process_id]) + self._send_command(command) + + def get_active_window(self) -> GetActiveWindowResponseModel: + """ + Gets the window id and name in addition to the process id + and name of the currently active window (in focus). + + + Returns: + GetActiveWindowResponseModel: The active window. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_active_window()") + command = GetActiveWindowCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetActiveWindowResponse): + message = ( + f"get_active_window: expected GetActiveWindowResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_active_window() -> {res.response}" + ) + return res.response + + def set_window_in_focus(self, process_id: int, window_id: int) -> None: + """ + Sets the window with the specified windowId of the process + with the specified processId active, + which brings it to the front and gives it focus. + + Args: + process_id (int): The ID of the process that owns the window. + window_id (int): The ID of the window to set as active. 
+ """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_window_in_focus({process_id}, {window_id})" + ) + _process_id = Parameter3(root=process_id) + _window_id = Parameter3(root=window_id) + command = SetActiveWindowCommand(parameters=[_process_id, _window_id]) + self._send_command(command) diff --git a/src/askui/tools/askui/askui_controller_client_settings.py b/src/askui/tools/askui/askui_controller_client_settings.py deleted file mode 100644 index 28db94d7..00000000 --- a/src/askui/tools/askui/askui_controller_client_settings.py +++ /dev/null @@ -1,26 +0,0 @@ -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class AskUiControllerClientSettings(BaseSettings): - """ - Settings for the AskUI Remote Device Controller client. - """ - - model_config = SettingsConfigDict( - env_prefix="ASKUI_CONTROLLER_CLIENT_", - ) - - server_address: str = Field( - default="localhost:23000", - description="Address of the AskUI Remote Device Controller server.", - ) - - server_autostart: bool = Field( - default=True, - description="Whether to automatically start the AskUI Remote Device" - "Controller server. Defaults to True.", - ) - - -__all__ = ["AskUiControllerClientSettings"] diff --git a/src/askui/tools/askui/exceptions.py b/src/askui/tools/askui/exceptions.py index 1398ff2b..622d7d05 100644 --- a/src/askui/tools/askui/exceptions.py +++ b/src/askui/tools/askui/exceptions.py @@ -42,7 +42,11 @@ class AskUiControllerOperationTimeoutError(AskUiControllerError): """ def __init__( - self, message: str = "Action not yet done", timeout_seconds: float | None = None + self, + message: str = ( + "Controller action did not finish within the expected time window." + ), + timeout_seconds: float | None = None, ): super().__init__(message) self.timeout_seconds = timeout_seconds @@ -52,7 +56,7 @@ class AskUiControllerInvalidCommandError(AskUiControllerError): """Exception raised when a command sent to the controller is invalid. 
This exception is raised when a command fails schema validation on the - controller server side, typically due to malformed command structure or + Agent OS server side, typically due to malformed command structure or invalid parameters. Args: @@ -61,12 +65,13 @@ class AskUiControllerInvalidCommandError(AskUiControllerError): def __init__(self, details: str | None = None): error_msg = ( - "AgentOS: Command validation failed" - " This error may be resolved by updating the AskUI" - " controller to the latest version." + "AgentOS: command validation failed on the Agent OS server. " + "This is typically caused by a malformed command or a version " + "mismatch; updating the AskUI controller to the latest version " + "may resolve it." ) if details: - error_msg += f"\n{details}" + error_msg += f"\nServer details: {details}" super().__init__(error_msg) self.details = details diff --git a/src/askui/tools/computer/__init__.py b/src/askui/tools/computer/__init__.py index 0410151e..0cb7fd0a 100644 --- a/src/askui/tools/computer/__init__.py +++ b/src/askui/tools/computer/__init__.py @@ -1,10 +1,12 @@ from .connect_tool import ComputerConnectTool from .disconnect_tool import ComputerDisconnectTool +from .get_active_agent_os_server_tool import ComputerGetActiveAgentOsServerTool from .get_mouse_position_tool import ComputerGetMousePositionTool from .get_system_info_tool import ComputerGetSystemInfoTool from .keyboard_pressed_tool import ComputerKeyboardPressedTool from .keyboard_release_tool import ComputerKeyboardReleaseTool from .keyboard_tap_tool import ComputerKeyboardTapTool +from .list_agent_os_servers_tool import ComputerListAgentOsServersTool from .list_displays_tool import ComputerListDisplaysTool from .mouse_click_tool import ComputerMouseClickTool from .mouse_hold_down_tool import ComputerMouseHoldDownTool @@ -14,12 +16,14 @@ from .retrieve_active_display_tool import ComputerRetrieveActiveDisplayTool from .screenshot_tool import ComputerScreenshotTool from 
.set_active_display_tool import ComputerSetActiveDisplayTool +from .switch_agent_os_server_tool import ComputerSwitchAgentOsServerTool from .type_tool import ComputerTypeTool __all__ = [ "ComputerGetSystemInfoTool", "ComputerConnectTool", "ComputerDisconnectTool", + "ComputerGetActiveAgentOsServerTool", "ComputerGetMousePositionTool", "ComputerKeyboardPressedTool", "ComputerKeyboardReleaseTool", @@ -32,6 +36,8 @@ "ComputerScreenshotTool", "ComputerTypeTool", "ComputerListDisplaysTool", + "ComputerListAgentOsServersTool", "ComputerRetrieveActiveDisplayTool", "ComputerSetActiveDisplayTool", + "ComputerSwitchAgentOsServerTool", ] diff --git a/src/askui/tools/computer/get_active_agent_os_server_tool.py b/src/askui/tools/computer/get_active_agent_os_server_tool.py new file mode 100644 index 00000000..629c40c0 --- /dev/null +++ b/src/askui/tools/computer/get_active_agent_os_server_tool.py @@ -0,0 +1,18 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import AgentOs + + +class ComputerGetActiveAgentOsServerTool(ComputerBaseTool): + def __init__(self, agent_os: AgentOs | None = None) -> None: + super().__init__( + name="get_active_agent_os_server", + description=""" + Return the currently active Agent OS server that agent-os actions + are routed to. 
+ """, + agent_os=agent_os, + ) + self.is_cacheable = False + + def __call__(self) -> str: + return repr(self.agent_os.get_active_agent_os_server()) diff --git a/src/askui/tools/computer/get_mouse_position_tool.py b/src/askui/tools/computer/get_mouse_position_tool.py index 059822a5..068d14b7 100644 --- a/src/askui/tools/computer/get_mouse_position_tool.py +++ b/src/askui/tools/computer/get_mouse_position_tool.py @@ -8,12 +8,20 @@ class ComputerGetMousePositionTool(ComputerBaseTool): def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: super().__init__( name="get_mouse_position", - description="Get the current mouse position.", + description=( + "Get the current mouse position on the currently active Agent OS " + "server. The result is prefixed with the active Agent OS server " + "session GUID." + ), agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) self.is_cacheable = True def __call__(self) -> str: + server = self.agent_os.get_active_agent_os_server(report=False) cursor_position = self.agent_os.get_mouse_position() - return f"Mouse is at position ({cursor_position.x}, {cursor_position.y})." + return ( + f"[Server with id '{server.computer_id}']: Mouse is at position " + f"({cursor_position.x}, {cursor_position.y})." + ) diff --git a/src/askui/tools/computer/get_system_info_tool.py b/src/askui/tools/computer/get_system_info_tool.py index 7f68c07d..cc0872ca 100644 --- a/src/askui/tools/computer/get_system_info_tool.py +++ b/src/askui/tools/computer/get_system_info_tool.py @@ -4,8 +4,9 @@ class ComputerGetSystemInfoTool(ComputerBaseTool): """ - Get the system information. - This tool returns the system information as a JSON object. + Get the system information of the currently active Agent OS server. + This tool returns the system information as a JSON object prefixed with + the active Agent OS server session GUID. The JSON object contains the following fields: - platform: The operating system platform. 
- label: The operating system label. @@ -17,8 +18,10 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: super().__init__( name="get_system_info_tool", description=""" - Get the system information. - This tool returns the system information as a JSON object. + Get the system information of the currently active Agent OS server. + This tool returns the system information as a JSON object prefixed + with the active Agent OS server session GUID so it is clear which + server the info belongs to. The JSON object contains the following fields: - platform: The operating system platform. - label: The operating system label. @@ -29,4 +32,6 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: ) def __call__(self) -> str: - return str(self.agent_os.get_system_info().model_dump_json()) + server = self.agent_os.get_active_agent_os_server(report=False) + system_info_json = self.agent_os.get_system_info().model_dump_json() + return f"[Server with id '{server.computer_id}']: {system_info_json}" diff --git a/src/askui/tools/computer/list_agent_os_servers_tool.py b/src/askui/tools/computer/list_agent_os_servers_tool.py new file mode 100644 index 00000000..95dac4ff --- /dev/null +++ b/src/askui/tools/computer/list_agent_os_servers_tool.py @@ -0,0 +1,19 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import AgentOs + + +class ComputerListAgentOsServersTool(ComputerBaseTool): + def __init__(self, agent_os: AgentOs | None = None) -> None: + super().__init__( + name="list_agent_os_servers", + description=""" + List all the registered Agent OS servers that the agent can route + actions to. Each server has a unique session GUID that can be used + to switch between them. 
+ """, + agent_os=agent_os, + ) + + def __call__(self) -> str: + servers = self.agent_os.list_agent_os_servers() + return ",".join(repr(s) for s in servers) diff --git a/src/askui/tools/computer/list_displays_tool.py b/src/askui/tools/computer/list_displays_tool.py index 68f3c207..3cb30459 100644 --- a/src/askui/tools/computer/list_displays_tool.py +++ b/src/askui/tools/computer/list_displays_tool.py @@ -7,13 +7,17 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: super().__init__( name="list_displays", description=""" - List all the available displays on the computer. + List all the available displays on the currently active Agent OS + server. The result is prefixed with the active Agent OS server + session GUID so it is clear which server the displays belong to. """, agent_os=agent_os, ) self.is_cacheable = True def __call__(self) -> str: - return self.agent_os.list_displays().model_dump_json( + server = self.agent_os.get_active_agent_os_server(report=False) + displays_json = self.agent_os.list_displays().model_dump_json( exclude={"data": {"__all__": {"size"}}}, ) + return f"[Server with id '{server.computer_id}']: {displays_json}" diff --git a/src/askui/tools/computer/retrieve_active_display_tool.py b/src/askui/tools/computer/retrieve_active_display_tool.py index 7eef6cfd..941c9188 100644 --- a/src/askui/tools/computer/retrieve_active_display_tool.py +++ b/src/askui/tools/computer/retrieve_active_display_tool.py @@ -7,14 +7,18 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: super().__init__( name="retrieve_active_display", description=""" - Retrieve the currently active display on the computer. - The display is used to take screenshots and perform actions. + Retrieve the currently active display on the currently active Agent OS + server. The display is used to take screenshots and perform actions. + The result is prefixed with the active Agent OS server session GUID + so it is clear which server the display belongs to. 
""", agent_os=agent_os, ) self.is_cacheable = True def __call__(self) -> str: - return str( - self.agent_os.retrieve_active_display().model_dump_json(exclude={"size"}) + server = self.agent_os.get_active_agent_os_server(report=False) + display_json = self.agent_os.retrieve_active_display().model_dump_json( + exclude={"size"} ) + return f"[Server with id '{server.computer_id}']: {display_json}" diff --git a/src/askui/tools/computer/screenshot_tool.py b/src/askui/tools/computer/screenshot_tool.py index fcf46553..3bdcb40d 100644 --- a/src/askui/tools/computer/screenshot_tool.py +++ b/src/askui/tools/computer/screenshot_tool.py @@ -10,12 +10,21 @@ class ComputerScreenshotTool(ComputerBaseTool): def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: super().__init__( name="screenshot", - description="Take a screenshot of the current screen.", + description=( + "Take a screenshot of the current screen on the currently active " + "Agent OS server. The accompanying message is prefixed with the " + "active Agent OS server session GUID so it is clear which server " + "the screenshot was taken on." 
+ ), agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) self.is_cacheable = True def __call__(self) -> tuple[str, Image.Image]: + server = self.agent_os.get_active_agent_os_server(report=False) screenshot = self.agent_os.screenshot() - return "Screenshot was taken.", screenshot + return ( + f"[Server with id '{server.computer_id}']: Screenshot was taken.", + screenshot, + ) diff --git a/src/askui/tools/computer/switch_agent_os_server_tool.py b/src/askui/tools/computer/switch_agent_os_server_tool.py new file mode 100644 index 00000000..f1bb6bb8 --- /dev/null +++ b/src/askui/tools/computer/switch_agent_os_server_tool.py @@ -0,0 +1,27 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import AgentOs + + +class ComputerSwitchAgentOsServerTool(ComputerBaseTool): + def __init__(self, agent_os: AgentOs | None = None) -> None: + super().__init__( + name="switch_agent_os_server", + description=""" + Switch the active Agent OS server by its `computer_id`. Future + agent-os actions are routed to the newly selected server. Use + `list_agent_os_servers` to discover the available computer ids. 
+ """, + input_schema={ + "type": "object", + "properties": { + "computer_id": { + "type": "string", + }, + }, + "required": ["computer_id"], + }, + agent_os=agent_os, + ) + + def __call__(self, computer_id: str) -> str: + return repr(self.agent_os.switch_agent_os_server(computer_id)) diff --git a/src/askui/tools/computer_agent_os_facade.py b/src/askui/tools/computer_agent_os_facade.py index 3be1481e..e7cd5596 100644 --- a/src/askui/tools/computer_agent_os_facade.py +++ b/src/askui/tools/computer_agent_os_facade.py @@ -1,6 +1,9 @@ +from collections.abc import Iterator +from contextlib import contextmanager from typing import TYPE_CHECKING from PIL import Image +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags from askui.tools.agent_os import ( @@ -18,6 +21,10 @@ from askui.utils.image_utils import scale_coordinates, scale_image_to_fit if TYPE_CHECKING: + from askui.tools.askui.agent_os_server import ( + AgentOsServer, + RemoteAgentOsServer, + ) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -266,6 +273,44 @@ def set_window_in_focus(self, process_id: int, window_id: int) -> None: """ self._agent_os.set_window_in_focus(process_id, window_id) + def add_agent_os_server(self, server: "AgentOsServer") -> "AgentOsServer": + return self._agent_os.add_agent_os_server(server) + + def add_remote_agent_os_server( + self, + address: str, + description: str, + ) -> "RemoteAgentOsServer": + return self._agent_os.add_remote_agent_os_server( + address=address, description=description + ) + + def reset_agent_os_servers( + self, + agent_os_servers: "list[AgentOsServer] | None" = None, + ) -> None: + self._agent_os.reset_agent_os_servers(agent_os_servers) + + def list_agent_os_servers(self) -> "list[AgentOsServer]": + return self._agent_os.list_agent_os_servers() + + def get_active_agent_os_server(self, report: bool = True) -> "AgentOsServer": + return 
self._agent_os.get_active_agent_os_server(report=report) + + def switch_agent_os_server(self, computer_id: str) -> "AgentOsServer": + agent_os_server = self._agent_os.switch_agent_os_server(computer_id) + self._real_screen_resolution = None + return agent_os_server + + @contextmanager + def temporary_select(self, computer_id: str) -> Iterator[Self]: + with self._agent_os.temporary_select(computer_id): + self._real_screen_resolution = None + try: + yield self + finally: + self._real_screen_resolution = None + def _scale_coordinates_back( self, x: int, diff --git a/tests/conftest.py b/tests/conftest.py index 5eb112db..5d863453 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,6 @@ from pytest_mock import MockerFixture from askui.tools.agent_os import AgentOs, Display, DisplaySize -from askui.tools.toolbox import AgentToolbox @pytest.fixture @@ -97,9 +96,13 @@ def agent_os_mock(mocker: MockerFixture) -> AgentOs: @pytest.fixture -def agent_toolbox_mock(agent_os_mock: AgentOs) -> AgentToolbox: - """Fixture providing a mock agent toolbox.""" - return AgentToolbox(agent_os=agent_os_mock) +def agent_os_mock_patch(mocker: MockerFixture, agent_os_mock: AgentOs) -> AgentOs: + """Patches `AskUiControllerClient` so `ComputerAgent` uses `agent_os_mock`.""" + mocker.patch( + "askui.computer_agent.AskUiControllerClient", + return_value=agent_os_mock, + ) + return agent_os_mock @pytest.fixture(autouse=True) diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 19bdbaa6..e1f96ac8 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -27,7 +27,7 @@ from askui.models.shared.settings import LocateSettings from askui.models.types.geometry import PointList from askui.reporting import Reporter, SimpleHtmlReporter -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import AgentOs from askui.utils.image_utils import ImageSource @@ -98,7 +98,7 @@ def combo_locate_model(path_fixtures: pathlib.Path) -> 
LocateModel: @pytest.fixture def agent_with_pta_model( pta_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -106,7 +106,6 @@ def agent_with_pta_model( detection_provider=_LocateModelDetectionProvider(pta_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -114,7 +113,7 @@ def agent_with_pta_model( @pytest.fixture def agent_with_ocr_model( ocr_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -122,7 +121,6 @@ def agent_with_ocr_model( detection_provider=_LocateModelDetectionProvider(ocr_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -130,7 +128,7 @@ def agent_with_ocr_model( @pytest.fixture def agent_with_ai_element_model( ai_element_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -138,7 +136,6 @@ def agent_with_ai_element_model( detection_provider=_LocateModelDetectionProvider(ai_element_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -146,7 +143,7 @@ def agent_with_ai_element_model( @pytest.fixture def agent_with_combo_model( combo_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -154,19 +151,17 @@ def agent_with_combo_model( detection_provider=_LocateModelDetectionProvider(combo_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: 
yield agent @pytest.fixture def vision_agent( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: """Fixture providing a ComputerAgent instance.""" with ComputerAgent( reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index bae0d4e8..8156f4ad 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -19,7 +19,7 @@ from askui.models.shared.settings import GetSettings from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import AgentOs from askui.utils.source_utils import Source @@ -97,7 +97,7 @@ class BrowserContextResponse(ResponseSchemaBase): ) def test_get( vision_agent: ComputerAgent, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel | None, @@ -112,7 +112,6 @@ def test_get( settings=AgentSettings( image_qa_provider=_GetModelImageQAProvider(get_model) ), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: url = agent.get( @@ -142,14 +141,13 @@ def test_get( ], ) def test_get_with_pdf_with_gemini_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_pdf: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -180,7 +178,7 @@ def test_get_with_pdf_with_gemini_model( ], ) def test_get_with_pdf_too_large( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 
simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_pdf: pathlib.Path, @@ -189,7 +187,6 @@ def test_get_with_pdf_too_large( mocker.patch("askui.models.askui.get_model.MAX_FILE_SIZE_BYTES", 1) with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: with pytest.raises(ValueError, match="PDF file size exceeds the limit"): @@ -232,14 +229,13 @@ def test_get_with_pdf_too_large_with_default_model( ], ) def test_get_with_xlsx_with_gemini_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_excel: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -279,14 +275,13 @@ class SalaryResponse(ResponseSchemaBase): ], ) def test_get_with_xlsx_with_gemini_model_with_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_excel: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -325,7 +320,7 @@ def test_get_with_docs_with_default_model( def test_get_with_fallback_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, ) -> None: @@ -338,7 +333,6 @@ def test_get_with_fallback_model( image_qa_provider=_GetModelImageQAProvider(askui_get_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: url = agent.get( "What is the current url shown in the url bar?", 
@@ -393,7 +387,7 @@ def test_get_with_response_schema_with_default_value( ) def test_get_with_response_schema( vision_agent: ComputerAgent, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel | None, @@ -409,7 +403,6 @@ def test_get_with_response_schema( settings=AgentSettings( image_qa_provider=_GetModelImageQAProvider(get_model) ), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -434,14 +427,13 @@ def test_get_with_response_schema( ], ) def test_get_with_nested_and_inherited_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -473,14 +465,13 @@ class LinkedListNode(ResponseSchemaBase): ], ) def test_get_with_recursive_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: with pytest.raises( @@ -507,14 +498,13 @@ def test_get_with_recursive_response_schema( ], ) def test_get_with_string_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, 
reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -545,14 +535,13 @@ def test_get_with_string_schema( ], ) def test_get_with_boolean_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -577,14 +566,13 @@ def test_get_with_boolean_schema( ], ) def test_get_with_integer_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -609,14 +597,13 @@ def test_get_with_integer_schema( ], ) def test_get_with_float_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -641,14 +628,13 @@ def test_get_with_float_schema( ], ) def test_get_returns_str_when_no_schema_specified( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ 
-675,14 +661,13 @@ class Basis(ResponseSchemaBase): ], ) def test_get_with_basis_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -715,14 +700,13 @@ class BasisWithNestedRootModel(ResponseSchemaBase): ], ) def test_get_with_nested_root_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -774,7 +758,7 @@ class PageDom(ResponseSchemaBase): ], ) def test_get_with_deeply_nested_response_schema_with_model_that_does_not_support_recursion( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, @@ -786,7 +770,6 @@ def test_get_with_deeply_nested_response_schema_with_model_that_does_not_support """ with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( diff --git a/tests/e2e/test_telemetry.py b/tests/e2e/test_telemetry.py index 25b9202a..a2d3bae8 100644 --- a/tests/e2e/test_telemetry.py +++ b/tests/e2e/test_telemetry.py @@ -5,13 +5,13 @@ from askui import locators as loc from askui.container import telemetry from askui.telemetry.processors import Segment, SegmentSettings -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os 
import AgentOs @pytest.mark.timeout(60) def test_telemetry_with_nonexistent_domain_should_not_block( github_login_screenshot: Image.Image, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG001 ) -> None: telemetry.set_processors( [ @@ -23,6 +23,6 @@ def test_telemetry_with_nonexistent_domain_should_not_block( ) ] ) - with ComputerAgent(tools=agent_toolbox_mock) as agent: + with ComputerAgent() as agent: agent.locate(loc.Text(), screenshot=github_login_screenshot) assert True diff --git a/tests/e2e/tools/askui/test_askui_controller.py b/tests/e2e/tools/askui/test_askui_controller.py index bca9e591..1f100cf0 100644 --- a/tests/e2e/tools/askui/test_askui_controller.py +++ b/tests/e2e/tools/askui/test_askui_controller.py @@ -7,29 +7,29 @@ from askui.reporting import CompositeReporter from askui.tools.agent_os import Coordinate +from askui.tools.askui import LocalAgentOsServer from askui.tools.askui.askui_controller import ( AskUiControllerClient, - AskUiControllerServer, RenderObjectStyle, ) from askui.tools.askui.askui_controller_settings import AskUiControllerSettings @pytest.fixture -def controller_server() -> AskUiControllerServer: - return AskUiControllerServer( +def agent_os_server() -> LocalAgentOsServer: + return LocalAgentOsServer( settings=AskUiControllerSettings(controller_args="--showOverlay true") ) @pytest.fixture def controller_client( - controller_server: AskUiControllerServer, + agent_os_server: LocalAgentOsServer, ) -> AskUiControllerClient: return AskUiControllerClient( reporter=CompositeReporter(), display=1, - controller_server=controller_server, + agent_os_servers=[agent_os_server], ) diff --git a/tests/integration/agent/test_retry.py b/tests/integration/agent/test_retry.py index 8f08d51a..6edb85c4 100644 --- a/tests/integration/agent/test_retry.py +++ b/tests/integration/agent/test_retry.py @@ -10,7 +10,7 @@ from askui.models.exceptions import ElementNotFoundError, ModelNotFoundError from askui.models.shared.settings 
import LocateSettings from askui.models.types.geometry import PointList -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import AgentOs from askui.utils.image_utils import ImageSource @@ -58,21 +58,21 @@ def always_failing_provider() -> FailingDetectionProvider: @pytest.fixture def agent_with_retry( - failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + failing_provider: FailingDetectionProvider, + agent_os_mock_patch: AgentOs, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=failing_provider), - tools=agent_toolbox_mock, ) @pytest.fixture def agent_with_retry_on_multiple_exceptions( - failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + failing_provider: FailingDetectionProvider, + agent_os_mock_patch: AgentOs, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=failing_provider), - tools=agent_toolbox_mock, retry=ConfigurableRetry( on_exception_types=( ElementNotFoundError, @@ -88,11 +88,11 @@ def agent_with_retry_on_multiple_exceptions( @pytest.fixture def agent_always_fail( - always_failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + always_failing_provider: FailingDetectionProvider, + agent_os_mock_patch: AgentOs, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=always_failing_provider), - tools=agent_toolbox_mock, retry=ConfigurableRetry( on_exception_types=(ElementNotFoundError,), strategy="Fixed", diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 996f610a..55829cb5 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -26,7 +26,7 @@ from askui.models.shared.prompts import SystemPrompt from askui.models.shared.settings import GetSettings, LocateSettings from askui.models.shared.tools import ToolCollection -from 
askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import AgentOs from askui.utils.image_utils import ImageSource from askui.utils.source_utils import Source @@ -148,12 +148,11 @@ def detection_provider(self) -> SimpleDetectionProvider: def test_inject_and_use_custom_vlm_provider( self, vlm_provider: SimpleVlmProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 ) -> None: """Test injecting and using a custom VLM provider.""" with ComputerAgent( settings=AgentSettings(vlm_provider=vlm_provider), - tools=agent_toolbox_mock, ) as agent: agent.act("test goal") @@ -175,12 +174,11 @@ def test_inject_and_use_custom_vlm_provider( def test_inject_and_use_custom_image_qa_provider( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 ) -> None: """Test injecting and using a custom image Q&A provider.""" with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query") @@ -190,13 +188,12 @@ def test_inject_and_use_custom_image_qa_provider( def test_inject_and_use_custom_image_qa_provider_with_pdf( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 path_fixtures_dummy_pdf: pathlib.Path, ) -> None: """Test injecting and using a custom image Q&A provider with a PDF.""" with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query", source=path_fixtures_dummy_pdf) @@ -206,12 +203,11 @@ def test_inject_and_use_custom_image_qa_provider_with_pdf( def test_inject_and_use_custom_detection_provider( self, detection_provider: SimpleDetectionProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 ) -> None: """Test injecting and using a custom detection provider.""" 
with ComputerAgent( settings=AgentSettings(detection_provider=detection_provider), - tools=agent_toolbox_mock, ) as agent: agent.click("test element") @@ -222,7 +218,7 @@ def test_inject_all_custom_providers( vlm_provider: SimpleVlmProvider, image_qa_provider: SimpleImageQAProvider, detection_provider: SimpleDetectionProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 ) -> None: """Test injecting all custom providers at once.""" with ComputerAgent( @@ -231,7 +227,6 @@ def test_inject_all_custom_providers( image_qa_provider=image_qa_provider, detection_provider=detection_provider, ), - tools=agent_toolbox_mock, ) as agent: agent.act("test goal") result = agent.get("test query") @@ -258,7 +253,7 @@ def test_inject_all_custom_providers( def test_use_response_schema_with_custom_image_qa_provider( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 ) -> None: """Test using a response schema with a custom image Q&A provider.""" response = SimpleResponseSchema(value="test value") @@ -266,7 +261,6 @@ def test_use_response_schema_with_custom_image_qa_provider( with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query", response_schema=SimpleResponseSchema) @@ -276,8 +270,8 @@ def test_use_response_schema_with_custom_image_qa_provider( def test_defaults_to_built_in_providers_when_not_provided( self, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: AgentOs, # noqa: ARG002 ) -> None: """Test agent uses built-in defaults when custom ones not provided.""" - with ComputerAgent(tools=agent_toolbox_mock) as agent: + with ComputerAgent() as agent: assert agent is not None diff --git a/tests/unit/tools/askui/test_agent_os_server.py b/tests/unit/tools/askui/test_agent_os_server.py new file mode 100644 index 00000000..569bc9fb --- /dev/null +++ 
b/tests/unit/tools/askui/test_agent_os_server.py @@ -0,0 +1,140 @@ +import re +from typing import Callable + +import pytest + +from askui.tools.askui.agent_os_server import ( + AgentOsServer, + LocalAgentOsServer, + RemoteAgentOsServer, + _generate_session_guid, + _replace_port, +) + + +class TestSessionGuid: + def test_generated_guid_is_brace_wrapped_uuid(self) -> None: + guid = _generate_session_guid() + assert re.fullmatch( + r"\{[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\}", + guid, + ) + + def test_each_generated_guid_is_unique(self) -> None: + assert _generate_session_guid() != _generate_session_guid() + + +class TestReplacePort: + def test_replaces_port_on_bare_authority(self) -> None: + assert _replace_port("example.com:1234", 23000) == "example.com:23000" + + def test_replaces_port_on_url_with_scheme(self) -> None: + assert _replace_port("http://example.com:1234", 23000) == "example.com:23000" + + def test_falls_back_to_localhost_when_host_missing(self) -> None: + # A bare ":1234" has no hostname, so the helper falls back to "localhost". 
+ assert _replace_port(":1234", 23000) == "localhost:23000" + + +class TestAgentOsServer: + def test_session_guid_unique_per_instance(self) -> None: + a = RemoteAgentOsServer(address="1.2.3.4:23000", description="a") + b = RemoteAgentOsServer(address="5.6.7.8:23000", description="b") + assert a.session_guid != b.session_guid + + def test_computer_id_defaults_to_session_guid(self) -> None: + s = RemoteAgentOsServer(address="1.2.3.4:23000", description="a") + assert s.computer_id == s.session_guid + + def test_explicit_computer_id_is_preserved(self) -> None: + s = RemoteAgentOsServer( + address="1.2.3.4:23000", description="a", computer_id="laptop" + ) + assert s.computer_id == "laptop" + assert s.session_guid != "laptop" + + def test_display_defaults_to_one_and_is_settable(self) -> None: + s = RemoteAgentOsServer(address="1.2.3.4:23000", description="a") + assert s.display == 1 + s.display = 3 + assert s.display == 3 + + def test_explicit_display_is_preserved(self) -> None: + s = RemoteAgentOsServer(address="1.2.3.4:23000", description="a", display=2) + assert s.display == 2 + + def test_repr_contains_identity_fields(self) -> None: + s = RemoteAgentOsServer( + address="1.2.3.4:23000", + description="my rig", + display=2, + computer_id="rig", + ) + r = repr(s) + assert "RemoteAgentOsServer" in r + assert "computer_id='rig'" in r + assert "description='my rig'" in r + assert "display=2" in r + + def test_base_class_is_not_local(self) -> None: + s = RemoteAgentOsServer(address="1.2.3.4:23000", description="a") + assert s.is_local is False + + def test_start_and_stop_are_no_ops_on_remote(self) -> None: + s = RemoteAgentOsServer(address="1.2.3.4:23000", description="a") + s.start() + s.stop() + + +class TestLocalAgentOsServer: + def test_is_local(self) -> None: + s = LocalAgentOsServer(discover_service=False) + assert s.is_local is True + + def test_default_description(self) -> None: + s = LocalAgentOsServer(discover_service=False) + assert s.description == "Local Agent 
OS server" + + def test_default_address(self) -> None: + s = LocalAgentOsServer(discover_service=False) + assert s.address == "localhost:23000" + + def test_is_service_default_false(self) -> None: + s = LocalAgentOsServer(discover_service=False) + assert s.is_service is False + + def test_explicit_computer_id(self) -> None: + s = LocalAgentOsServer(discover_service=False, computer_id="my-laptop") + assert s.computer_id == "my-laptop" + + def test_parse_port_rejects_bad_address(self) -> None: + s = LocalAgentOsServer(discover_service=False, address="no-port-here") + with pytest.raises(ValueError, match="Could not parse port"): + s._parse_port() # noqa: SLF001 - intentional unit test against helper + + def test_parse_port_extracts_port(self) -> None: + s = LocalAgentOsServer(discover_service=False, address="localhost:24567") + assert s._parse_port() == 24567 # noqa: SLF001 + + +class TestSubclassesPassThroughDisplayAndId: + @pytest.mark.parametrize( + "factory", + [ + lambda: LocalAgentOsServer( + discover_service=False, display=4, computer_id="local" + ), + lambda: RemoteAgentOsServer( + address="1.2.3.4:23000", + description="r", + display=4, + computer_id="remote", + ), + ], + ) + def test_display_and_computer_id_round_trip( + self, factory: Callable[[], AgentOsServer] + ) -> None: + s: AgentOsServer = factory() + assert s.display == 4 + assert s.computer_id in {"local", "remote"} diff --git a/tests/unit/tools/askui/test_agent_os_server_manager.py b/tests/unit/tools/askui/test_agent_os_server_manager.py new file mode 100644 index 00000000..6a928010 --- /dev/null +++ b/tests/unit/tools/askui/test_agent_os_server_manager.py @@ -0,0 +1,187 @@ +import pytest + +from askui.tools.askui.agent_os_server import ( + LocalAgentOsServer, + RemoteAgentOsServer, +) +from askui.tools.askui.agent_os_server_manager import AgentOsServerManager + + +def _make_remote( + address: str = "1.2.3.4:23000", + description: str = "remote", + computer_id: str | None = None, +) -> 
RemoteAgentOsServer: + return RemoteAgentOsServer( + address=address, description=description, computer_id=computer_id + ) + + +def _make_local(computer_id: str | None = None) -> LocalAgentOsServer: + return LocalAgentOsServer(discover_service=False, computer_id=computer_id) + + +class TestConstruction: + def test_empty_constructor_yields_empty_manager(self) -> None: + m = AgentOsServerManager() + assert m.list() == [] + assert m.active is None + assert len(m) == 0 + + def test_constructor_registers_initial_servers_in_order(self) -> None: + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m = AgentOsServerManager(agent_os_servers=[a, b]) + assert m.list() == [a, b] + # First registered becomes active. + assert m.active is a + + def test_first_added_becomes_active(self) -> None: + m = AgentOsServerManager() + a = _make_remote(computer_id="a") + m.add(a) + assert m.active is a + + +class TestAddConstraints: + def test_rejects_second_local_server(self) -> None: + m = AgentOsServerManager() + m.add(_make_local(computer_id="first")) + with pytest.raises(ValueError, match="second local Agent OS server"): + m.add(_make_local(computer_id="second")) + + def test_rejects_duplicate_computer_id(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="rig")) + with pytest.raises(ValueError, match="computer_id='rig'"): + m.add(_make_remote(address="2.2.2.2:23000", computer_id="rig")) + + def test_rejects_duplicate_remote_address(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="a")) + with pytest.raises( + ValueError, match="remote Agent OS server with address '1.1.1.1:23000'" + ): + m.add(_make_remote(address="1.1.1.1:23000", computer_id="b")) + + def test_allows_local_plus_remote_with_same_address(self) -> None: + m = AgentOsServerManager() + m.add(_make_local(computer_id="local")) + # Local target's default 
address is 'localhost:23000' but the local/remote + # address-uniqueness rule only applies between remote servers. + m.add( + _make_remote( + address="localhost:23000", description="remote", computer_id="remote" + ) + ) + assert len(m) == 2 + + +class TestAddRemote: + def test_constructs_and_registers(self) -> None: + m = AgentOsServerManager() + server = m.add_remote(address="1.2.3.4:23000", description="r") + assert isinstance(server, RemoteAgentOsServer) + assert server.address == "1.2.3.4:23000" + assert server.description == "r" + assert m.list() == [server] + + +class TestGetAndSwitch: + def test_get_returns_server_by_computer_id(self) -> None: + m = AgentOsServerManager() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + m.add(a) + assert m.get("a") is a + + def test_get_raises_keyerror_with_registered_ids(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="a")) + with pytest.raises(KeyError) as exc_info: + m.get("missing") + message = str(exc_info.value) + assert "missing" in message + assert "'a'" in message # registered id surfaced + + def test_switch_changes_active(self) -> None: + m = AgentOsServerManager() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + assert m.active is a + m.switch("b") + assert m.active is b + + def test_switch_unknown_id_raises_keyerror(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(computer_id="a")) + with pytest.raises(KeyError, match="missing"): + m.switch("missing") + + +class TestGetBySessionGuid: + def test_returns_server_by_session_guid(self) -> None: + m = AgentOsServerManager() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + m.add(a) + assert m.get_by_session_guid(a.session_guid) is a + + def test_returns_none_for_unknown_session_guid(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(computer_id="a")) + assert 
m.get_by_session_guid("no-such-guid") is None + + +class TestRemove: + def test_remove_drops_server(self) -> None: + m = AgentOsServerManager() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + m.remove("a") + assert m.list() == [b] + + def test_remove_active_falls_back_to_first_remaining(self) -> None: + m = AgentOsServerManager() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + assert m.active is a + m.remove("a") + assert m.active is b + + def test_remove_last_clears_active(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(computer_id="a")) + m.remove("a") + assert m.active is None + assert len(m) == 0 + + def test_remove_inactive_keeps_active_unchanged(self) -> None: + m = AgentOsServerManager() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + m.remove("b") + assert m.active is a + + def test_remove_unknown_raises_keyerror(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(computer_id="a")) + with pytest.raises(KeyError): + m.remove("missing") + + +class TestReset: + def test_reset_clears_all(self) -> None: + m = AgentOsServerManager() + m.add(_make_remote(computer_id="a")) + m.add(_make_remote(address="2.2.2.2:23000", computer_id="b")) + m.reset() + assert m.list() == [] + assert m.active is None + assert len(m) == 0 diff --git a/tests/unit/tools/askui/test_askui_controller_client.py b/tests/unit/tools/askui/test_askui_controller_client.py new file mode 100644 index 00000000..455ec451 --- /dev/null +++ b/tests/unit/tools/askui/test_askui_controller_client.py @@ -0,0 +1,202 @@ +""" +Unit tests for `AskUiControllerClient`'s multi-server registration / routing +logic. 
These tests intentionally avoid exercising the gRPC code path (which +needs a real controller binary). They cover the in-memory bookkeeping done by +the client and its `AgentOsServerManager`. +""" + +import pytest + +from askui.tools.askui.agent_os_server import ( + LocalAgentOsServer, + RemoteAgentOsServer, +) +from askui.tools.askui.agent_os_server_manager import AgentOsServerManager +from askui.tools.askui.askui_controller import AskUiControllerClient +from askui.tools.askui.exceptions import AskUiControllerError + + +def _make_local( + description: str = "local", computer_id: str | None = None, display: int = 1 +) -> LocalAgentOsServer: + return LocalAgentOsServer( + description=description, + discover_service=False, + computer_id=computer_id, + display=display, + ) + + +def _make_remote( + address: str = "1.2.3.4:23000", + description: str = "remote", + computer_id: str | None = None, + display: int = 1, +) -> RemoteAgentOsServer: + return RemoteAgentOsServer( + address=address, + description=description, + computer_id=computer_id, + display=display, + ) + + +class TestConstruction: + def test_default_registers_single_local_server(self) -> None: + client = AskUiControllerClient() + servers = client.agent_os_server_manager.list() + assert len(servers) == 1 + assert isinstance(servers[0], LocalAgentOsServer) + + def test_default_propagates_display_to_default_local_server(self) -> None: + client = AskUiControllerClient(display=3) + active = client.agent_os_server_manager.active + assert active is not None + assert active.display == 3 + + def test_accepts_explicit_servers(self) -> None: + a = _make_local(computer_id="local") + b = _make_remote(computer_id="remote") + client = AskUiControllerClient(agent_os_servers=[a, b]) + assert client.agent_os_server_manager.list() == [a, b] + assert client.agent_os_server_manager.active is a + + def test_explicit_servers_keep_their_own_display(self) -> None: + """Constructor's display arg only seeds the auto-created default 
server.""" + a = _make_local(computer_id="local", display=2) + b = _make_remote(computer_id="remote", display=3) + client = AskUiControllerClient(display=5, agent_os_servers=[a, b]) + assert client.agent_os_server_manager.get("local").display == 2 + assert client.agent_os_server_manager.get("remote").display == 3 + + def test_is_connected_false_before_connect(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_remote()]) + assert client.is_connected is False + + +class TestActiveServer: + def test_get_active_returns_first_registered(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = AskUiControllerClient(agent_os_servers=[a, b]) + assert client.get_active_agent_os_server() is a + + def test_get_active_with_empty_manager_raises(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_remote()]) + client.agent_os_server_manager.reset() + with pytest.raises(AskUiControllerError, match="No active Agent OS server"): + client.get_active_agent_os_server(report=False) + + +class TestSwitchAgentOsServer: + def test_switch_changes_active_when_disconnected(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = AskUiControllerClient(agent_os_servers=[a, b]) + client.switch_agent_os_server("b") + assert client.agent_os_server_manager.active is b + + def test_switch_unknown_computer_id_raises_keyerror(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_local(computer_id="a")]) + with pytest.raises(KeyError, match="missing"): + client.switch_agent_os_server("missing") + + def test_switch_returns_the_new_active_server(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = AskUiControllerClient(agent_os_servers=[a, b]) + result = client.switch_agent_os_server("b") + assert result is b + + def test_per_server_display_preserved_across_switch(self) -> None: + a = _make_local(computer_id="a", display=1) 
+ b = _make_remote(computer_id="b", display=4) + client = AskUiControllerClient(agent_os_servers=[a, b]) + client.switch_agent_os_server("b") + active_b = client.agent_os_server_manager.active + assert active_b is not None + assert active_b.display == 4 + client.switch_agent_os_server("a") + active_a = client.agent_os_server_manager.active + assert active_a is not None + assert active_a.display == 1 + + +class TestListAndReset: + def test_list_returns_registered_servers(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = AskUiControllerClient(agent_os_servers=[a, b]) + assert client.list_agent_os_servers() == [a, b] + + def test_reset_with_no_args_leaves_manager_empty(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_remote(computer_id="r")]) + client.reset_agent_os_servers() + assert client.list_agent_os_servers() == [] + + def test_reset_with_new_list_replaces_registrations(self) -> None: + client = AskUiControllerClient( + agent_os_servers=[_make_remote(computer_id="old")] + ) + new_server = _make_remote(address="9.9.9.9:23000", computer_id="new") + client.reset_agent_os_servers([new_server]) + assert client.list_agent_os_servers() == [new_server] + assert client.agent_os_server_manager.active is new_server + + +class TestAddAgentOsServerWhileDisconnected: + def test_add_remote_appends_without_connecting(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_local(computer_id="l")]) + added = client.add_remote_agent_os_server( + address="2.2.2.2:23000", description="r" + ) + assert added in client.list_agent_os_servers() + assert client.is_connected is False + + def test_add_already_constructed_server(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_local(computer_id="l")]) + extra = _make_remote(address="2.2.2.2:23000", computer_id="r") + result = client.add_agent_os_server(extra) + assert result is extra + assert extra in client.list_agent_os_servers() + 
+ +class TestTemporarySelect: + def test_temporary_select_restores_previous_active(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = AskUiControllerClient(agent_os_servers=[a, b]) + manager = client.agent_os_server_manager + before = manager.active + assert before is a + with client.temporary_select("b"): + inside = manager.active + assert inside is b + after = manager.active + assert after is a + + def test_temporary_select_restores_previous_even_on_exception(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = AskUiControllerClient(agent_os_servers=[a, b]) + error_message = "boom" + with ( + pytest.raises(RuntimeError, match=error_message), + client.temporary_select("b"), + ): + assert client.agent_os_server_manager.active is b + raise RuntimeError(error_message) + assert client.agent_os_server_manager.active is a + + def test_temporary_select_same_id_is_a_noop_around_yield(self) -> None: + a = _make_local(computer_id="a") + client = AskUiControllerClient(agent_os_servers=[a]) + with client.temporary_select("a"): + assert client.agent_os_server_manager.active is a + assert client.agent_os_server_manager.active is a + + +class TestUsesAgentOsServerManager: + def test_underlying_manager_is_an_agent_os_server_manager(self) -> None: + client = AskUiControllerClient(agent_os_servers=[_make_local(computer_id="l")]) + assert isinstance(client.agent_os_server_manager, AgentOsServerManager) diff --git a/tests/unit/tools/askui/test_askui_controller_client_settings.py b/tests/unit/tools/askui/test_askui_controller_client_settings.py deleted file mode 100644 index 3a086453..00000000 --- a/tests/unit/tools/askui/test_askui_controller_client_settings.py +++ /dev/null @@ -1,73 +0,0 @@ -from unittest.mock import patch - -import pytest -from pydantic import ValidationError - -from askui.tools.askui.askui_controller_client_settings import ( - AskUiControllerClientSettings, -) - - -class 
TestAskUiControllerClientSettings: - """Test suite for AskUiControllerClientSettings.""" - - def test_defaults(self) -> None: - """Defaults are applied when no environment variables are set.""" - with patch.dict("os.environ", {}, clear=True): - settings = AskUiControllerClientSettings() - assert settings.server_address == "localhost:23000" - assert settings.server_autostart is True - - def test_server_address_from_env(self) -> None: - """ - `ASKUI_CONTROLLER_CLIENT_SERVER_ADDRESS` overrides default for `server_address`. - """ - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_ADDRESS": "127.0.0.1:24000"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_address == "127.0.0.1:24000" - - def test_server_autostart_from_env_false(self) -> None: - """`ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART` parses boolean from env.""" - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "False"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_autostart is False - - def test_server_autostart_from_env_true(self) -> None: - """Boolean true value is parsed correctly from environment variable.""" - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "true"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_autostart is True - - def test_server_address_from_constructor(self) -> None: - """`server_address` is set correctly from constructor.""" - settings = AskUiControllerClientSettings(server_address="127.0.0.1:24000") - assert settings.server_address == "127.0.0.1:24000" - - def test_server_autostart_from_constructor(self) -> None: - """`server_autostart` is set correctly from constructor.""" - settings = AskUiControllerClientSettings(server_autostart=False) - assert settings.server_autostart is False - - def test_autostart_from_env_with_invalid_value(self) -> None: - """ - Test that 
ValidationError is raised when environment variable is invalid. - """ - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "invalid"}, - clear=True, - ): - with pytest.raises(ValidationError): - AskUiControllerClientSettings() diff --git a/tests/unit/tools/computer/__init__.py b/tests/unit/tools/computer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/tools/computer/test_agent_os_server_tools.py b/tests/unit/tools/computer/test_agent_os_server_tools.py new file mode 100644 index 00000000..c840fd3e --- /dev/null +++ b/tests/unit/tools/computer/test_agent_os_server_tools.py @@ -0,0 +1,86 @@ +from unittest.mock import MagicMock + +import pytest + +from askui.tools.agent_os import AgentOs +from askui.tools.askui.agent_os_server import RemoteAgentOsServer +from askui.tools.computer import ( + ComputerGetActiveAgentOsServerTool, + ComputerListAgentOsServersTool, + ComputerSwitchAgentOsServerTool, +) + + +@pytest.fixture +def fake_agent_os() -> MagicMock: + """A MagicMock that passes `isinstance(x, AgentOs)` checks.""" + return MagicMock(spec=AgentOs) + + +class TestComputerListAgentOsServersTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerListAgentOsServersTool(agent_os=fake_agent_os) + assert tool.base_name == "list_agent_os_servers" + + def test_returns_comma_separated_repr_of_servers( + self, fake_agent_os: MagicMock + ) -> None: + a = RemoteAgentOsServer( + address="1.1.1.1:23000", description="a", computer_id="a" + ) + b = RemoteAgentOsServer( + address="2.2.2.2:23000", description="b", computer_id="b" + ) + fake_agent_os.list_agent_os_servers.return_value = [a, b] + tool = ComputerListAgentOsServersTool(agent_os=fake_agent_os) + out = tool() + assert out == f"{a!r},{b!r}" + + def test_empty_list_yields_empty_string(self, fake_agent_os: MagicMock) -> None: + fake_agent_os.list_agent_os_servers.return_value = [] + tool = 
ComputerListAgentOsServersTool(agent_os=fake_agent_os) + assert tool() == "" + + +class TestComputerSwitchAgentOsServerTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerSwitchAgentOsServerTool(agent_os=fake_agent_os) + assert tool.base_name == "switch_agent_os_server" + + def test_input_schema_requires_computer_id(self, fake_agent_os: MagicMock) -> None: + tool = ComputerSwitchAgentOsServerTool(agent_os=fake_agent_os) + schema = tool.input_schema + assert "computer_id" in schema["properties"] + assert schema["required"] == ["computer_id"] + + def test_call_delegates_to_switch_agent_os_server( + self, fake_agent_os: MagicMock + ) -> None: + switched = RemoteAgentOsServer( + address="1.1.1.1:23000", description="new", computer_id="new" + ) + fake_agent_os.switch_agent_os_server.return_value = switched + tool = ComputerSwitchAgentOsServerTool(agent_os=fake_agent_os) + out = tool(computer_id="new") + fake_agent_os.switch_agent_os_server.assert_called_once_with("new") + assert out == repr(switched) + + +class TestComputerGetActiveAgentOsServerTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerGetActiveAgentOsServerTool(agent_os=fake_agent_os) + assert tool.base_name == "get_active_agent_os_server" + + def test_is_not_cacheable(self, fake_agent_os: MagicMock) -> None: + tool = ComputerGetActiveAgentOsServerTool(agent_os=fake_agent_os) + assert tool.is_cacheable is False + + def test_call_returns_active_server_repr(self, fake_agent_os: MagicMock) -> None: + active = RemoteAgentOsServer( + address="1.1.1.1:23000", description="a", computer_id="a" + ) + fake_agent_os.get_active_agent_os_server.return_value = active + tool = ComputerGetActiveAgentOsServerTool(agent_os=fake_agent_os) + out = tool() + fake_agent_os.get_active_agent_os_server.assert_called_once_with() + assert out == repr(active)