from __future__ import annotations

import json
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path


class RemoteClientRepository:
    """SQLite-backed store for remote client registrations.

    Each public method opens its own short-lived connection, commits on
    success and rolls back on error. Rows are returned as plain dicts with
    the ``shares_json`` column decoded into a ``shares`` list.

    NOTE: because every operation opens a fresh connection to ``db_path``,
    a non-shared in-memory path such as ``":memory:"`` will not keep the
    schema between calls; use a file path.
    """

    def __init__(self, db_path: str):
        """Remember *db_path* and create the table/indexes if missing."""
        self._db_path = db_path
        self._ensure_schema()

    def upsert_client(
        self,
        *,
        client_id: str,
        display_name: str,
        platform: str,
        agent_version: str,
        endpoint: str,
        shares: list[dict[str, str]],
        now_iso: str,
    ) -> dict:
        """Insert a client or refresh an existing one, and return its row.

        A new row starts as ``"online"`` with ``last_seen``, ``created_at``
        and ``updated_at`` set to *now_iso*. When *client_id* already exists
        the descriptive fields, ``last_seen``, ``status`` and ``updated_at``
        are overwritten and ``last_error`` is cleared; ``created_at`` and
        ``reachable_at`` keep their stored values.
        """
        shares_json = self._encode_shares(shares)
        with self._connection() as conn:
            conn.execute(
                """
                INSERT INTO remote_clients (
                    client_id, display_name, platform, agent_version,
                    endpoint, shares_json, last_seen, status, last_error,
                    reachable_at, created_at, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(client_id) DO UPDATE SET
                    display_name = excluded.display_name,
                    platform = excluded.platform,
                    agent_version = excluded.agent_version,
                    endpoint = excluded.endpoint,
                    shares_json = excluded.shares_json,
                    last_seen = excluded.last_seen,
                    status = excluded.status,
                    last_error = NULL,
                    updated_at = excluded.updated_at
                """,
                (
                    client_id,
                    display_name,
                    platform,
                    agent_version,
                    endpoint,
                    shares_json,
                    now_iso,
                    "online",
                    None,
                    None,
                    now_iso,
                    now_iso,
                ),
            )
            # Read back inside the same transaction so the caller sees the
            # merged row (including the preserved created_at).
            row = conn.execute(
                "SELECT * FROM remote_clients WHERE client_id = ?",
                (client_id,),
            ).fetchone()
        return self._to_dict(row)

    def record_heartbeat(
        self, *, client_id: str, agent_version: str, now_iso: str
    ) -> dict | None:
        """Mark a known client online and return its row.

        Updates ``agent_version``, ``last_seen``, ``status`` and
        ``updated_at``. Returns ``None`` when no row matches *client_id*.
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE remote_clients
                SET agent_version = ?, last_seen = ?, status = ?, updated_at = ?
                WHERE client_id = ?
                """,
                (agent_version, now_iso, "online", now_iso, client_id),
            )
            if cursor.rowcount <= 0:
                return None
            row = conn.execute(
                "SELECT * FROM remote_clients WHERE client_id = ?",
                (client_id,),
            ).fetchone()
        return self._to_dict(row)

    def mark_stale_clients_offline(self, *, cutoff_iso: str, now_iso: str) -> None:
        """Set ``status`` to ``"offline"`` for clients last seen before *cutoff_iso*.

        Rows that are already offline or have no ``last_seen`` are left
        untouched. The comparison is done by SQLite on the stored text.
        """
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE remote_clients
                SET status = ?, updated_at = ?
                WHERE status != ?
                  AND last_seen IS NOT NULL
                  AND last_seen < ?
                """,
                ("offline", now_iso, "offline", cutoff_iso),
            )

    def list_clients(self) -> list[dict]:
        """Return all clients ordered by lower-cased display name, then id."""
        with self._connection() as conn:
            rows = conn.execute(
                """
                SELECT * FROM remote_clients
                ORDER BY LOWER(display_name) ASC, client_id ASC
                """
            ).fetchall()
        return [self._to_dict(row) for row in rows]

    def get_client(self, client_id: str) -> dict | None:
        """Return the client with *client_id*, or ``None`` if it is unknown."""
        with self._connection() as conn:
            row = conn.execute(
                """
                SELECT * FROM remote_clients
                WHERE client_id = ?
                """,
                (client_id,),
            ).fetchone()
        if row is None:
            return None
        return self._to_dict(row)

    def _ensure_schema(self) -> None:
        """Create the parent directory, table and indexes when absent."""
        parent = Path(self._db_path).parent
        # A bare filename has parent "." (nothing to create).
        if str(parent) not in {"", "."}:
            parent.mkdir(parents=True, exist_ok=True)
        with self._connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS remote_clients (
                    client_id TEXT PRIMARY KEY,
                    display_name TEXT NOT NULL,
                    platform TEXT NOT NULL,
                    agent_version TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    shares_json TEXT NOT NULL,
                    last_seen TEXT NULL,
                    status TEXT NOT NULL,
                    last_error TEXT NULL,
                    reachable_at TEXT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_remote_clients_display_name
                ON remote_clients(display_name)
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_remote_clients_last_seen
                ON remote_clients(last_seen)
                """
            )

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success and rolls back on error.

        The connection uses ``sqlite3.Row`` rows and is always closed on exit.
        """
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @classmethod
    def _to_dict(cls, row: sqlite3.Row) -> dict:
        """Convert a ``remote_clients`` row to a dict with decoded ``shares``."""
        return {
            "client_id": row["client_id"],
            "display_name": row["display_name"],
            "platform": row["platform"],
            "agent_version": row["agent_version"],
            "endpoint": row["endpoint"],
            "shares": cls._decode_shares(row["shares_json"]),
            "last_seen": row["last_seen"],
            "status": row["status"],
            "last_error": row["last_error"],
            "reachable_at": row["reachable_at"],
            "created_at": row["created_at"],
            "updated_at": row["updated_at"],
        }

    @staticmethod
    def _encode_shares(shares: list[dict[str, str]]) -> str:
        """Serialize *shares* as compact JSON with sorted keys."""
        return json.dumps(shares, separators=(",", ":"), sort_keys=True)

    @staticmethod
    def _decode_shares(raw: str) -> list[dict[str, str]]:
        """Parse stored shares JSON into a list of ``{"key", "label"}`` dicts.

        Decoding is deliberately tolerant so that one bad row cannot break
        reads: empty input, malformed JSON and non-list payloads all yield
        ``[]``. Entries that are not objects, or whose ``key``/``label`` is
        missing, null or blank after stripping, are skipped.
        """
        try:
            parsed = json.loads(raw or "[]")
        except (TypeError, ValueError):
            # json.JSONDecodeError subclasses ValueError; TypeError covers a
            # non-text column value.
            return []
        if not isinstance(parsed, list):
            return []

        normalized: list[dict[str, str]] = []
        for item in parsed:
            if not isinstance(item, dict):
                continue
            raw_key = item.get("key")
            raw_label = item.get("label")
            # Treat JSON null like a missing field instead of the text "None".
            key = "" if raw_key is None else str(raw_key).strip()
            label = "" if raw_label is None else str(raw_label).strip()
            if key and label:
                normalized.append({"key": key, "label": label})
        return normalized

    @staticmethod
    def now_iso() -> str:
        """Return the current UTC time as ISO 8601 text with a ``Z`` suffix."""
        return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")