diff --git a/finder_commander/remote_client_agent.example.json b/finder_commander/remote_client_agent.example.json new file mode 100644 index 0000000..e1d5879 --- /dev/null +++ b/finder_commander/remote_client_agent.example.json @@ -0,0 +1,15 @@ +{ + "agent_access_token": "change-me-agent-token", + "client_id": "", + "display_name": "MacBook Pro van Jan", + "endpoint": "http://192.168.1.25:8765", + "heartbeat_interval_seconds": 20, + "platform": "macos", + "registration_token": "change-me-registration-token", + "shares": { + "downloads": "/Users/jan/Downloads", + "movies": "/Users/jan/Movies", + "pictures": "/Users/jan/Pictures" + }, + "webmanager_base_url": "http://127.0.0.1:8080" +} diff --git a/finder_commander/remote_client_agent.py b/finder_commander/remote_client_agent.py new file mode 100644 index 0000000..a21a621 --- /dev/null +++ b/finder_commander/remote_client_agent.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import argparse +import json +import sys +import time +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from urllib import error, request + + +AGENT_VERSION = "1.1.0-phase1" + + +@dataclass +class AgentConfig: + config_path: Path + webmanager_base_url: str + registration_token: str + agent_access_token: str + display_name: str + endpoint: str + shares: dict[str, str] + heartbeat_interval_seconds: int + client_id: str + platform: str = "macos" + + @property + def normalized_base_url(self) -> str: + return self.webmanager_base_url.rstrip("/") + + +def load_config(config_path: Path) -> AgentConfig: + raw = json.loads(config_path.read_text(encoding="utf-8")) + client_id = str(raw.get("client_id", "")).strip() + if not client_id: + client_id = str(uuid.uuid4()) + raw["client_id"] = client_id + config_path.write_text(json.dumps(raw, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + shares_raw = raw.get("shares") or {} + shares: dict[str, str] = {} + if isinstance(shares_raw, dict): + for key, value in shares_raw.items(): + normalized_key = str(key).strip() + normalized_value = str(value).strip() + if normalized_key and normalized_value: + shares[normalized_key] = normalized_value + + if not shares: + raise ValueError("config requires at least one share") + + return AgentConfig( + config_path=config_path, + webmanager_base_url=str(raw.get("webmanager_base_url", "")).strip(), + registration_token=str(raw.get("registration_token", "")).strip(), + agent_access_token=str(raw.get("agent_access_token", "")).strip(), + display_name=str(raw.get("display_name", "")).strip(), + endpoint=str(raw.get("public_endpoint", raw.get("endpoint", ""))).strip(), + shares=shares, + heartbeat_interval_seconds=max(5, int(raw.get("heartbeat_interval_seconds", 20))), + client_id=client_id, + platform=str(raw.get("platform", "macos")).strip() or "macos", + ) + + +def require_non_empty(value: str, field: str) -> str: + normalized = value.strip() + if not normalized: + raise ValueError(f"config field '{field}' is required") + return normalized + + +def build_register_payload(config: AgentConfig) -> dict[str, Any]: + return { + "client_id": config.client_id, + "display_name": config.display_name, + "platform": config.platform, + "agent_version": AGENT_VERSION, + "endpoint": config.endpoint, + "shares": [{"key": key, "label": key.capitalize()} for key in sorted(config.shares.keys())], + } + + +def build_heartbeat_payload(config: AgentConfig) -> dict[str, Any]: + return { + "client_id": config.client_id, + "agent_version": AGENT_VERSION, + } + + +def post_json(url: str, token: str, payload: dict[str, Any]) -> dict[str, Any]: + data = json.dumps(payload).encode("utf-8") + req = request.Request( + url, + method="POST", + data=data, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + }, + ) + with request.urlopen(req, timeout=10) as resp: + return json.loads(resp.read().decode("utf-8")) + + +def run(config: AgentConfig) -> None: + require_non_empty(config.webmanager_base_url, "webmanager_base_url") + require_non_empty(config.registration_token, "registration_token") + require_non_empty(config.agent_access_token, "agent_access_token") + require_non_empty(config.display_name, "display_name") + require_non_empty(config.endpoint, "public_endpoint") + + register_url = f"{config.normalized_base_url}/api/clients/register" + heartbeat_url = f"{config.normalized_base_url}/api/clients/heartbeat" + + print(f"Starting remote client agent for {config.display_name} ({config.client_id})", flush=True) + print("agent_access_token is configured for future authenticated agent endpoints", flush=True) + + while True: + try: + post_json(register_url, config.registration_token, build_register_payload(config)) + print("register ok", flush=True) + break + except error.HTTPError as exc: + print(f"register failed: HTTP {exc.code}", file=sys.stderr, flush=True) + except error.URLError as exc: + print(f"register failed: {exc.reason}", file=sys.stderr, flush=True) + time.sleep(config.heartbeat_interval_seconds) + + while True: + try: + post_json(heartbeat_url, config.registration_token, build_heartbeat_payload(config)) + print("heartbeat ok", flush=True) + except error.HTTPError as exc: + print(f"heartbeat failed: HTTP {exc.code}", file=sys.stderr, flush=True) + except error.URLError as exc: + print(f"heartbeat failed: {exc.reason}", file=sys.stderr, flush=True) + time.sleep(config.heartbeat_interval_seconds) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Remote client agent Phase 1 for WebManager MVP") + parser.add_argument( + "--config", + default=str(Path(__file__).resolve().with_name("remote_client_agent.example.json")), + help="Path to remote client agent config JSON", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + try: + config = load_config(Path(args.config).resolve()) + run(config) + except KeyboardInterrupt: + return 130 + except Exception as exc: + print(str(exc), file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/webui/backend/app/__pycache__/config.cpython-313.pyc b/webui/backend/app/__pycache__/config.cpython-313.pyc index 6439911..61a6cae 100644 Binary files a/webui/backend/app/__pycache__/config.cpython-313.pyc and b/webui/backend/app/__pycache__/config.cpython-313.pyc differ diff --git a/webui/backend/app/__pycache__/dependencies.cpython-313.pyc b/webui/backend/app/__pycache__/dependencies.cpython-313.pyc index 590bbb6..e016732 100644 Binary files a/webui/backend/app/__pycache__/dependencies.cpython-313.pyc and b/webui/backend/app/__pycache__/dependencies.cpython-313.pyc differ diff --git a/webui/backend/app/__pycache__/main.cpython-313.pyc b/webui/backend/app/__pycache__/main.cpython-313.pyc index cf50032..c179d0f 100644 Binary files a/webui/backend/app/__pycache__/main.cpython-313.pyc and b/webui/backend/app/__pycache__/main.cpython-313.pyc differ diff --git a/webui/backend/app/api/__pycache__/schemas.cpython-313.pyc b/webui/backend/app/api/__pycache__/schemas.cpython-313.pyc index ff6d8d5..89dba5f 100644 Binary files a/webui/backend/app/api/__pycache__/schemas.cpython-313.pyc and b/webui/backend/app/api/__pycache__/schemas.cpython-313.pyc differ diff --git a/webui/backend/app/api/routes_clients.py b/webui/backend/app/api/routes_clients.py new file mode 100644 index 0000000..539aa60 --- /dev/null +++ b/webui/backend/app/api/routes_clients.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Header + +from backend.app.api.schemas import ( + RemoteClientHeartbeatRequest, + RemoteClientItem, + RemoteClientListResponse, + RemoteClientRegisterRequest, +) +from backend.app.dependencies import get_remote_client_service +from backend.app.services.remote_client_service import RemoteClientService + +router = APIRouter(prefix="/clients") + + +@router.get("", response_model=RemoteClientListResponse) +async def list_clients( + service: RemoteClientService = Depends(get_remote_client_service), +) -> RemoteClientListResponse: + return service.list_clients() + + +@router.post("/register", response_model=RemoteClientItem) +async def register_client( + request: RemoteClientRegisterRequest, + authorization: str | None = Header(default=None), + service: RemoteClientService = Depends(get_remote_client_service), +) -> RemoteClientItem: + return service.register_client(authorization=authorization, request=request) + + +@router.post("/heartbeat", response_model=RemoteClientItem) +async def heartbeat( + request: RemoteClientHeartbeatRequest, + authorization: str | None = Header(default=None), + service: RemoteClientService = Depends(get_remote_client_service), +) -> RemoteClientItem: + return service.record_heartbeat(authorization=authorization, request=request) diff --git a/webui/backend/app/api/schemas.py b/webui/backend/app/api/schemas.py index e350fc8..24582ef 100644 --- a/webui/backend/app/api/schemas.py +++ b/webui/backend/app/api/schemas.py @@ -238,3 +238,41 @@ class SearchResultItem(BaseModel): class SearchResponse(BaseModel): items: list[SearchResultItem] truncated: bool + + +class RemoteClientShare(BaseModel): + key: str + label: str + + +class RemoteClientRegisterRequest(BaseModel): + client_id: str + display_name: str + platform: str + agent_version: str + endpoint: str + shares: list[RemoteClientShare] + + +class RemoteClientHeartbeatRequest(BaseModel): + client_id: str + agent_version: str + + +class RemoteClientItem(BaseModel): + client_id: str + display_name: str + platform: str + agent_version: str + endpoint: str + shares: list[RemoteClientShare] + last_seen: str | None = None + status: str + last_error: str | None = None + reachable_at: str | None = None + created_at: str + updated_at: str + + +class RemoteClientListResponse(BaseModel): + items: list[RemoteClientItem] diff --git a/webui/backend/app/config.py b/webui/backend/app/config.py index 5d50fc0..a3c8f8e 100644 --- a/webui/backend/app/config.py +++ b/webui/backend/app/config.py @@ -9,6 +9,10 @@ from pathlib import Path class Settings: root_aliases: dict[str, str] task_db_path: str + remote_client_registration_token: str + remote_client_offline_timeout_seconds: int + remote_client_agent_auth_header: str + remote_client_agent_auth_scheme: str DEFAULT_ROOT_ALIASES = { @@ -40,4 +44,17 @@ def get_settings() -> Settings: task_db_path = os.getenv("WEBMANAGER_TASK_DB_PATH", default_task_db_path).strip() if not task_db_path: task_db_path = default_task_db_path - return Settings(root_aliases=_load_root_aliases(), task_db_path=task_db_path) + raw_offline_timeout = os.getenv("WEBMANAGER_REMOTE_CLIENT_OFFLINE_TIMEOUT_SECONDS", "60").strip() + try: + remote_client_offline_timeout_seconds = max(1, int(raw_offline_timeout)) + except ValueError: + remote_client_offline_timeout_seconds = 60 + return Settings( + root_aliases=_load_root_aliases(), + task_db_path=task_db_path, + remote_client_registration_token=os.getenv("WEBMANAGER_REMOTE_CLIENT_REGISTRATION_TOKEN", "").strip(), + remote_client_offline_timeout_seconds=remote_client_offline_timeout_seconds, + remote_client_agent_auth_header=os.getenv("WEBMANAGER_REMOTE_CLIENT_AGENT_AUTH_HEADER", "Authorization").strip() + or "Authorization", + remote_client_agent_auth_scheme=os.getenv("WEBMANAGER_REMOTE_CLIENT_AGENT_AUTH_SCHEME", "Bearer").strip() or "Bearer", + ) diff --git a/webui/backend/app/db/remote_client_repository.py b/webui/backend/app/db/remote_client_repository.py new file mode 100644 index 0000000..b2bad1d --- /dev/null +++ b/webui/backend/app/db/remote_client_repository.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import json +import sqlite3 +from contextlib import contextmanager +from datetime import datetime, timezone +from pathlib import Path + + +class RemoteClientRepository: + def __init__(self, db_path: str): + self._db_path = db_path + self._ensure_schema() + + def upsert_client( + self, + *, + client_id: str, + display_name: str, + platform: str, + agent_version: str, + endpoint: str, + shares: list[dict[str, str]], + now_iso: str, + ) -> dict: + shares_json = self._encode_shares(shares) + with self._connection() as conn: + conn.execute( + """ + INSERT INTO remote_clients ( + client_id, display_name, platform, agent_version, endpoint, shares_json, + last_seen, status, last_error, reachable_at, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(client_id) DO UPDATE SET + display_name = excluded.display_name, + platform = excluded.platform, + agent_version = excluded.agent_version, + endpoint = excluded.endpoint, + shares_json = excluded.shares_json, + last_seen = excluded.last_seen, + status = excluded.status, + last_error = NULL, + updated_at = excluded.updated_at + """, + ( + client_id, + display_name, + platform, + agent_version, + endpoint, + shares_json, + now_iso, + "online", + None, + None, + now_iso, + now_iso, + ), + ) + row = conn.execute("SELECT * FROM remote_clients WHERE client_id = ?", (client_id,)).fetchone() + return self._to_dict(row) + + def record_heartbeat(self, *, client_id: str, agent_version: str, now_iso: str) -> dict | None: + with self._connection() as conn: + cursor = conn.execute( + """ + UPDATE remote_clients + SET agent_version = ?, last_seen = ?, status = ?, updated_at = ? + WHERE client_id = ? + """, + (agent_version, now_iso, "online", now_iso, client_id), + ) + if cursor.rowcount <= 0: + return None + row = conn.execute("SELECT * FROM remote_clients WHERE client_id = ?", (client_id,)).fetchone() + return self._to_dict(row) + + def mark_stale_clients_offline(self, *, cutoff_iso: str, now_iso: str) -> None: + with self._connection() as conn: + conn.execute( + """ + UPDATE remote_clients + SET status = ?, updated_at = ? + WHERE status != ? AND last_seen IS NOT NULL AND last_seen < ? + """, + ("offline", now_iso, "offline", cutoff_iso), + ) + + def list_clients(self) -> list[dict]: + with self._connection() as conn: + rows = conn.execute( + """ + SELECT * + FROM remote_clients + ORDER BY LOWER(display_name) ASC, client_id ASC + """ + ).fetchall() + return [self._to_dict(row) for row in rows] + + def _ensure_schema(self) -> None: + db_path = Path(self._db_path) + if db_path.parent and str(db_path.parent) not in {"", "."}: + db_path.parent.mkdir(parents=True, exist_ok=True) + with self._connection() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS remote_clients ( + client_id TEXT PRIMARY KEY, + display_name TEXT NOT NULL, + platform TEXT NOT NULL, + agent_version TEXT NOT NULL, + endpoint TEXT NOT NULL, + shares_json TEXT NOT NULL, + last_seen TEXT NULL, + status TEXT NOT NULL, + last_error TEXT NULL, + reachable_at TEXT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_remote_clients_display_name + ON remote_clients(display_name) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_remote_clients_last_seen + ON remote_clients(last_seen) + """ + ) + + @contextmanager + def _connection(self): + conn = sqlite3.connect(self._db_path) + conn.row_factory = sqlite3.Row + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + @classmethod + def _to_dict(cls, row: sqlite3.Row) -> dict: + return { + "client_id": row["client_id"], + "display_name": row["display_name"], + "platform": row["platform"], + "agent_version": row["agent_version"], + "endpoint": row["endpoint"], + "shares": cls._decode_shares(row["shares_json"]), + "last_seen": row["last_seen"], + "status": row["status"], + "last_error": row["last_error"], + "reachable_at": row["reachable_at"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + } + + @staticmethod + def _encode_shares(shares: list[dict[str, str]]) -> str: + return json.dumps(shares, separators=(",", ":"), sort_keys=True) + + @staticmethod + def _decode_shares(raw: str) -> list[dict[str, str]]: + parsed = json.loads(raw or "[]") + if not isinstance(parsed, list): + return [] + normalized: list[dict[str, str]] = [] + for item in parsed: + if not isinstance(item, dict): + continue + key = str(item.get("key", "")).strip() + label = str(item.get("label", "")).strip() + if key and label: + normalized.append({"key": key, "label": label}) + return normalized + + @staticmethod + def now_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") diff --git a/webui/backend/app/dependencies.py b/webui/backend/app/dependencies.py index 3db48b9..41aafd7 100644 --- a/webui/backend/app/dependencies.py +++ b/webui/backend/app/dependencies.py @@ -6,6 +6,7 @@ from pathlib import Path from backend.app.config import Settings, get_settings from backend.app.db.bookmark_repository import BookmarkRepository from backend.app.db.history_repository import HistoryRepository +from backend.app.db.remote_client_repository import RemoteClientRepository from backend.app.db.settings_repository import SettingsRepository from backend.app.db.task_repository import TaskRepository from backend.app.fs.filesystem_adapter import FilesystemAdapter @@ -19,6 +20,7 @@ from backend.app.services.duplicate_task_service import DuplicateTaskService from backend.app.services.file_ops_service import FileOpsService from backend.app.services.history_service import HistoryService from backend.app.services.move_task_service import MoveTaskService +from backend.app.services.remote_client_service import RemoteClientService from backend.app.services.search_service import SearchService from backend.app.services.settings_service import SettingsService from backend.app.services.task_service import TaskService @@ -59,6 +61,12 @@ def get_settings_repository() -> SettingsRepository: return SettingsRepository(db_path=settings.task_db_path) +@lru_cache(maxsize=1) +def get_remote_client_repository() -> RemoteClientRepository: + settings: Settings = get_settings() + return RemoteClientRepository(db_path=settings.task_db_path) + + @lru_cache(maxsize=1) def get_task_runner() -> TaskRunner: return TaskRunner( @@ -155,3 +163,12 @@ async def get_search_service() -> SearchService: async def get_settings_service() -> SettingsService: return SettingsService(repository=get_settings_repository(), path_guard=get_path_guard()) + + +async def get_remote_client_service() -> RemoteClientService: + settings: Settings = get_settings() + return RemoteClientService( + repository=get_remote_client_repository(), + registration_token=settings.remote_client_registration_token, + offline_timeout_seconds=settings.remote_client_offline_timeout_seconds, + ) diff --git a/webui/backend/app/main.py b/webui/backend/app/main.py index b28e714..ffff7b2 100644 --- a/webui/backend/app/main.py +++ b/webui/backend/app/main.py @@ -10,6 +10,7 @@ from backend.app.api.errors import AppError from backend.app.api.routes_bookmarks import router as bookmarks_router from backend.app.api.routes_browse import router as browse_router from backend.app.api.routes_copy import router as copy_router +from backend.app.api.routes_clients import router as clients_router from backend.app.api.routes_duplicate import router as duplicate_router from backend.app.api.routes_files import router as files_router from backend.app.api.routes_history import router as history_router @@ -33,6 +34,7 @@ app.mount("/ui", StaticFiles(directory=str(UI_DIR), html=True), name="ui") app.include_router(browse_router, prefix="/api") app.include_router(files_router, prefix="/api") app.include_router(copy_router, prefix="/api") +app.include_router(clients_router, prefix="/api") app.include_router(duplicate_router, prefix="/api") app.include_router(move_router, prefix="/api") app.include_router(search_router, prefix="/api") diff --git a/webui/backend/app/services/remote_client_service.py b/webui/backend/app/services/remote_client_service.py new file mode 100644 index 0000000..02237fc --- /dev/null +++ b/webui/backend/app/services/remote_client_service.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import Callable + +from backend.app.api.errors import AppError +from backend.app.api.schemas import ( + RemoteClientHeartbeatRequest, + RemoteClientItem, + RemoteClientListResponse, + RemoteClientRegisterRequest, +) +from backend.app.db.remote_client_repository import RemoteClientRepository + + +class RemoteClientService: + def __init__( + self, + repository: RemoteClientRepository, + registration_token: str, + offline_timeout_seconds: int, + now: Callable[[], datetime] | None = None, + ): + self._repository = repository + self._registration_token = registration_token.strip() + self._offline_timeout_seconds = max(1, int(offline_timeout_seconds)) + self._now = now or (lambda: datetime.now(tz=timezone.utc)) + + def list_clients(self) -> RemoteClientListResponse: + now = self._now() + self._repository.mark_stale_clients_offline( + cutoff_iso=self._to_iso(now - timedelta(seconds=self._offline_timeout_seconds)), + now_iso=self._to_iso(now), + ) + items = [RemoteClientItem(**row) for row in self._repository.list_clients()] + return RemoteClientListResponse(items=items) + + def register_client(self, authorization: str | None, request: RemoteClientRegisterRequest) -> RemoteClientItem: + self._require_registration_auth(authorization) + payload = self._normalize_register_request(request) + now_iso = self._to_iso(self._now()) + item = self._repository.upsert_client(now_iso=now_iso, **payload) + return RemoteClientItem(**item) + + def record_heartbeat(self, authorization: str | None, request: RemoteClientHeartbeatRequest) -> RemoteClientItem: + self._require_registration_auth(authorization) + client_id = (request.client_id or "").strip() + agent_version = (request.agent_version or "").strip() + if not client_id: + raise AppError( + code="invalid_request", + message="client_id is required", + status_code=400, + details={"client_id": request.client_id}, + ) + if not agent_version: + raise AppError( + code="invalid_request", + message="agent_version is required", + status_code=400, + details={"agent_version": request.agent_version}, + ) + item = self._repository.record_heartbeat( + client_id=client_id, + agent_version=agent_version, + now_iso=self._to_iso(self._now()), + ) + if item is None: + raise AppError( + code="path_not_found", + message="Remote client was not found", + status_code=404, + details={"client_id": client_id}, + ) + return RemoteClientItem(**item) + + def _require_registration_auth(self, authorization: str | None) -> None: + if not self._registration_token: + raise AppError( + code="remote_client_registration_disabled", + message="Remote client registration is not configured", + status_code=503, + ) + expected = f"Bearer {self._registration_token}" + if (authorization or "").strip() != expected: + raise AppError( + code="forbidden", + message="Invalid remote client registration token", + status_code=403, + ) + + def _normalize_register_request(self, request: RemoteClientRegisterRequest) -> dict: + client_id = (request.client_id or "").strip() + display_name = (request.display_name or "").strip() + platform = (request.platform or "").strip() + agent_version = (request.agent_version or "").strip() + endpoint = (request.endpoint or "").strip() + shares = [ + {"key": (item.key or "").strip(), "label": (item.label or "").strip()} + for item in request.shares + ] + shares = [item for item in shares if item["key"] and item["label"]] + + if not client_id: + raise AppError("invalid_request", "client_id is required", 400, {"client_id": request.client_id}) + if not display_name: + raise AppError("invalid_request", "display_name is required", 400, {"display_name": request.display_name}) + if not platform: + raise AppError("invalid_request", "platform is required", 400, {"platform": request.platform}) + if not agent_version: + raise AppError("invalid_request", "agent_version is required", 400, {"agent_version": request.agent_version}) + if not endpoint: + raise AppError("invalid_request", "endpoint is required", 400, {"endpoint": request.endpoint}) + if not shares: + raise AppError("invalid_request", "at least one share is required", 400, {"shares": "[]"}) + + return { + "client_id": client_id, + "display_name": display_name, + "platform": platform, + "agent_version": agent_version, + "endpoint": endpoint, + "shares": shares, + } + + @staticmethod + def _to_iso(value: datetime) -> str: + return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z") diff --git a/webui/backend/data/tasks.db b/webui/backend/data/tasks.db index 3869ff4..515638c 100644 Binary files a/webui/backend/data/tasks.db and b/webui/backend/data/tasks.db differ diff --git a/webui/backend/tests/golden/test_api_clients_golden.py b/webui/backend/tests/golden/test_api_clients_golden.py new file mode 100644 index 0000000..3c17e3f --- /dev/null +++ b/webui/backend/tests/golden/test_api_clients_golden.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import asyncio +import sys +import tempfile +import unittest +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import httpx + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) + +from backend.app.dependencies import get_remote_client_service +from backend.app.db.remote_client_repository import RemoteClientRepository +from backend.app.main import app +from backend.app.services.remote_client_service import RemoteClientService + + +class _Clock: + def __init__(self, current: datetime): + self.current = current + + def now(self) -> datetime: + return self.current + + def advance(self, *, seconds: int) -> None: + self.current += timedelta(seconds=seconds) + + +class RemoteClientsApiGoldenTest(unittest.TestCase): + def setUp(self) -> None: + self.temp_dir = tempfile.TemporaryDirectory() + self.clock = _Clock(datetime(2026, 3, 26, 12, 0, 0, tzinfo=timezone.utc)) + repository = RemoteClientRepository(str(Path(self.temp_dir.name) / "remote-clients.db")) + service = RemoteClientService( + repository=repository, + registration_token="secret-token", + offline_timeout_seconds=60, + now=self.clock.now, + ) + + async def _override_remote_client_service() -> RemoteClientService: + return service + + app.dependency_overrides[get_remote_client_service] = _override_remote_client_service + + def tearDown(self) -> None: + app.dependency_overrides.clear() + self.temp_dir.cleanup() + + def _request(self, method: str, url: str, payload: dict | None = None, token: str | None = None) -> httpx.Response: + async def _run() -> httpx.Response: + transport = httpx.ASGITransport(app=app) + headers = {} + if token is not None: + headers["Authorization"] = f"Bearer {token}" + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + if method == "GET": + return await client.get(url, headers=headers) + return await client.post(url, json=payload, headers=headers) + + return asyncio.run(_run()) + + @staticmethod + def _register_payload() -> dict: + return { + "client_id": "client-123", + "display_name": "Jan MacBook", + "platform": "macos", + "agent_version": "1.1.0", + "endpoint": "http://192.168.1.25:8765", + "shares": [{"key": "downloads", "label": "Downloads"}], + } + + def test_list_is_empty_by_default(self) -> None: + response = self._request("GET", "/api/clients") + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {"items": []}) + + def test_register_then_list_then_heartbeat_and_status_timeout(self) -> None: + register_response = self._request( + "POST", + "/api/clients/register", + self._register_payload(), + token="secret-token", + ) + + self.assertEqual(register_response.status_code, 200) + register_body = register_response.json() + self.assertEqual(register_body["client_id"], "client-123") + self.assertEqual(register_body["display_name"], "Jan MacBook") + self.assertEqual(register_body["status"], "online") + self.assertEqual(register_body["last_seen"], "2026-03-26T12:00:00Z") + self.assertIsNone(register_body["last_error"]) + self.assertIsNone(register_body["reachable_at"]) + + list_response = self._request("GET", "/api/clients") + self.assertEqual(list_response.status_code, 200) + self.assertEqual(len(list_response.json()["items"]), 1) + self.assertEqual(list_response.json()["items"][0]["status"], "online") + + self.clock.advance(seconds=30) + heartbeat_response = self._request( + "POST", + "/api/clients/heartbeat", + {"client_id": "client-123", "agent_version": "1.1.1"}, + token="secret-token", + ) + self.assertEqual(heartbeat_response.status_code, 200) + heartbeat_body = heartbeat_response.json() + self.assertEqual(heartbeat_body["agent_version"], "1.1.1") + self.assertEqual(heartbeat_body["last_seen"], "2026-03-26T12:00:30Z") + self.assertEqual(heartbeat_body["status"], "online") + + self.clock.advance(seconds=61) + timed_out_list = self._request("GET", "/api/clients") + self.assertEqual(timed_out_list.status_code, 200) + timed_out_item = timed_out_list.json()["items"][0] + self.assertEqual(timed_out_item["status"], "offline") + self.assertEqual(timed_out_item["last_seen"], "2026-03-26T12:00:30Z") + self.assertIsNone(timed_out_item["last_error"]) + self.assertIsNone(timed_out_item["reachable_at"]) + + def test_register_rejects_invalid_token(self) -> None: + response = self._request( + "POST", + "/api/clients/register", + self._register_payload(), + token="wrong-token", + ) + + self.assertEqual(response.status_code, 403) + self.assertEqual(response.json()["error"]["code"], "forbidden") + + +if __name__ == "__main__": + unittest.main()