feat: remote client deel 1
This commit is contained in:
@@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"agent_access_token": "change-me-agent-token",
|
||||||
|
"client_id": "",
|
||||||
|
"display_name": "MacBook Pro van Jan",
|
||||||
|
"endpoint": "http://192.168.1.25:8765",
|
||||||
|
"heartbeat_interval_seconds": 20,
|
||||||
|
"platform": "macos",
|
||||||
|
"registration_token": "change-me-registration-token",
|
||||||
|
"shares": {
|
||||||
|
"downloads": "/Users/jan/Downloads",
|
||||||
|
"movies": "/Users/jan/Movies",
|
||||||
|
"pictures": "/Users/jan/Pictures"
|
||||||
|
},
|
||||||
|
"webmanager_base_url": "http://127.0.0.1:8080"
|
||||||
|
}
|
||||||
@@ -0,0 +1,168 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from urllib import error, request
|
||||||
|
|
||||||
|
|
||||||
|
AGENT_VERSION = "1.1.0-phase1"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentConfig:
|
||||||
|
config_path: Path
|
||||||
|
webmanager_base_url: str
|
||||||
|
registration_token: str
|
||||||
|
agent_access_token: str
|
||||||
|
display_name: str
|
||||||
|
endpoint: str
|
||||||
|
shares: dict[str, str]
|
||||||
|
heartbeat_interval_seconds: int
|
||||||
|
client_id: str
|
||||||
|
platform: str = "macos"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def normalized_base_url(self) -> str:
|
||||||
|
return self.webmanager_base_url.rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path: Path) -> AgentConfig:
|
||||||
|
raw = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
client_id = str(raw.get("client_id", "")).strip()
|
||||||
|
if not client_id:
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
raw["client_id"] = client_id
|
||||||
|
config_path.write_text(json.dumps(raw, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
shares_raw = raw.get("shares") or {}
|
||||||
|
shares: dict[str, str] = {}
|
||||||
|
if isinstance(shares_raw, dict):
|
||||||
|
for key, value in shares_raw.items():
|
||||||
|
normalized_key = str(key).strip()
|
||||||
|
normalized_value = str(value).strip()
|
||||||
|
if normalized_key and normalized_value:
|
||||||
|
shares[normalized_key] = normalized_value
|
||||||
|
|
||||||
|
if not shares:
|
||||||
|
raise ValueError("config requires at least one share")
|
||||||
|
|
||||||
|
return AgentConfig(
|
||||||
|
config_path=config_path,
|
||||||
|
webmanager_base_url=str(raw.get("webmanager_base_url", "")).strip(),
|
||||||
|
registration_token=str(raw.get("registration_token", "")).strip(),
|
||||||
|
agent_access_token=str(raw.get("agent_access_token", "")).strip(),
|
||||||
|
display_name=str(raw.get("display_name", "")).strip(),
|
||||||
|
endpoint=str(raw.get("public_endpoint", raw.get("endpoint", ""))).strip(),
|
||||||
|
shares=shares,
|
||||||
|
heartbeat_interval_seconds=max(5, int(raw.get("heartbeat_interval_seconds", 20))),
|
||||||
|
client_id=client_id,
|
||||||
|
platform=str(raw.get("platform", "macos")).strip() or "macos",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def require_non_empty(value: str, field: str) -> str:
|
||||||
|
normalized = value.strip()
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError(f"config field '{field}' is required")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def build_register_payload(config: AgentConfig) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"client_id": config.client_id,
|
||||||
|
"display_name": config.display_name,
|
||||||
|
"platform": config.platform,
|
||||||
|
"agent_version": AGENT_VERSION,
|
||||||
|
"endpoint": config.endpoint,
|
||||||
|
"shares": [{"key": key, "label": key.capitalize()} for key in sorted(config.shares.keys())],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_heartbeat_payload(config: AgentConfig) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"client_id": config.client_id,
|
||||||
|
"agent_version": AGENT_VERSION,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def post_json(url: str, token: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
data = json.dumps(payload).encode("utf-8")
|
||||||
|
req = request.Request(
|
||||||
|
url,
|
||||||
|
method="POST",
|
||||||
|
data=data,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with request.urlopen(req, timeout=10) as resp:
|
||||||
|
return json.loads(resp.read().decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def run(config: AgentConfig) -> None:
|
||||||
|
require_non_empty(config.webmanager_base_url, "webmanager_base_url")
|
||||||
|
require_non_empty(config.registration_token, "registration_token")
|
||||||
|
require_non_empty(config.agent_access_token, "agent_access_token")
|
||||||
|
require_non_empty(config.display_name, "display_name")
|
||||||
|
require_non_empty(config.endpoint, "public_endpoint")
|
||||||
|
|
||||||
|
register_url = f"{config.normalized_base_url}/api/clients/register"
|
||||||
|
heartbeat_url = f"{config.normalized_base_url}/api/clients/heartbeat"
|
||||||
|
|
||||||
|
print(f"Starting remote client agent for {config.display_name} ({config.client_id})", flush=True)
|
||||||
|
print("agent_access_token is configured for future authenticated agent endpoints", flush=True)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
post_json(register_url, config.registration_token, build_register_payload(config))
|
||||||
|
print("register ok", flush=True)
|
||||||
|
break
|
||||||
|
except error.HTTPError as exc:
|
||||||
|
print(f"register failed: HTTP {exc.code}", file=sys.stderr, flush=True)
|
||||||
|
except error.URLError as exc:
|
||||||
|
print(f"register failed: {exc.reason}", file=sys.stderr, flush=True)
|
||||||
|
time.sleep(config.heartbeat_interval_seconds)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
post_json(heartbeat_url, config.registration_token, build_heartbeat_payload(config))
|
||||||
|
print("heartbeat ok", flush=True)
|
||||||
|
except error.HTTPError as exc:
|
||||||
|
print(f"heartbeat failed: HTTP {exc.code}", file=sys.stderr, flush=True)
|
||||||
|
except error.URLError as exc:
|
||||||
|
print(f"heartbeat failed: {exc.reason}", file=sys.stderr, flush=True)
|
||||||
|
time.sleep(config.heartbeat_interval_seconds)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Remote client agent Phase 1 for WebManager MVP")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
default=str(Path(__file__).resolve().with_name("remote_client_agent.example.json")),
|
||||||
|
help="Path to remote client agent config JSON",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
args = parse_args()
|
||||||
|
try:
|
||||||
|
config = load_config(Path(args.config).resolve())
|
||||||
|
run(config)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
return 130
|
||||||
|
except Exception as exc:
|
||||||
|
print(str(exc), file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,39 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header
|
||||||
|
|
||||||
|
from backend.app.api.schemas import (
|
||||||
|
RemoteClientHeartbeatRequest,
|
||||||
|
RemoteClientItem,
|
||||||
|
RemoteClientListResponse,
|
||||||
|
RemoteClientRegisterRequest,
|
||||||
|
)
|
||||||
|
from backend.app.dependencies import get_remote_client_service
|
||||||
|
from backend.app.services.remote_client_service import RemoteClientService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/clients")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=RemoteClientListResponse)
|
||||||
|
async def list_clients(
|
||||||
|
service: RemoteClientService = Depends(get_remote_client_service),
|
||||||
|
) -> RemoteClientListResponse:
|
||||||
|
return service.list_clients()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=RemoteClientItem)
|
||||||
|
async def register_client(
|
||||||
|
request: RemoteClientRegisterRequest,
|
||||||
|
authorization: str | None = Header(default=None),
|
||||||
|
service: RemoteClientService = Depends(get_remote_client_service),
|
||||||
|
) -> RemoteClientItem:
|
||||||
|
return service.register_client(authorization=authorization, request=request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/heartbeat", response_model=RemoteClientItem)
|
||||||
|
async def heartbeat(
|
||||||
|
request: RemoteClientHeartbeatRequest,
|
||||||
|
authorization: str | None = Header(default=None),
|
||||||
|
service: RemoteClientService = Depends(get_remote_client_service),
|
||||||
|
) -> RemoteClientItem:
|
||||||
|
return service.record_heartbeat(authorization=authorization, request=request)
|
||||||
@@ -238,3 +238,41 @@ class SearchResultItem(BaseModel):
|
|||||||
class SearchResponse(BaseModel):
|
class SearchResponse(BaseModel):
|
||||||
items: list[SearchResultItem]
|
items: list[SearchResultItem]
|
||||||
truncated: bool
|
truncated: bool
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientShare(BaseModel):
|
||||||
|
key: str
|
||||||
|
label: str
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientRegisterRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
display_name: str
|
||||||
|
platform: str
|
||||||
|
agent_version: str
|
||||||
|
endpoint: str
|
||||||
|
shares: list[RemoteClientShare]
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientHeartbeatRequest(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
agent_version: str
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientItem(BaseModel):
|
||||||
|
client_id: str
|
||||||
|
display_name: str
|
||||||
|
platform: str
|
||||||
|
agent_version: str
|
||||||
|
endpoint: str
|
||||||
|
shares: list[RemoteClientShare]
|
||||||
|
last_seen: str | None = None
|
||||||
|
status: str
|
||||||
|
last_error: str | None = None
|
||||||
|
reachable_at: str | None = None
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientListResponse(BaseModel):
|
||||||
|
items: list[RemoteClientItem]
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ from pathlib import Path
|
|||||||
class Settings:
|
class Settings:
|
||||||
root_aliases: dict[str, str]
|
root_aliases: dict[str, str]
|
||||||
task_db_path: str
|
task_db_path: str
|
||||||
|
remote_client_registration_token: str
|
||||||
|
remote_client_offline_timeout_seconds: int
|
||||||
|
remote_client_agent_auth_header: str
|
||||||
|
remote_client_agent_auth_scheme: str
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_ROOT_ALIASES = {
|
DEFAULT_ROOT_ALIASES = {
|
||||||
@@ -40,4 +44,17 @@ def get_settings() -> Settings:
|
|||||||
task_db_path = os.getenv("WEBMANAGER_TASK_DB_PATH", default_task_db_path).strip()
|
task_db_path = os.getenv("WEBMANAGER_TASK_DB_PATH", default_task_db_path).strip()
|
||||||
if not task_db_path:
|
if not task_db_path:
|
||||||
task_db_path = default_task_db_path
|
task_db_path = default_task_db_path
|
||||||
return Settings(root_aliases=_load_root_aliases(), task_db_path=task_db_path)
|
raw_offline_timeout = os.getenv("WEBMANAGER_REMOTE_CLIENT_OFFLINE_TIMEOUT_SECONDS", "60").strip()
|
||||||
|
try:
|
||||||
|
remote_client_offline_timeout_seconds = max(1, int(raw_offline_timeout))
|
||||||
|
except ValueError:
|
||||||
|
remote_client_offline_timeout_seconds = 60
|
||||||
|
return Settings(
|
||||||
|
root_aliases=_load_root_aliases(),
|
||||||
|
task_db_path=task_db_path,
|
||||||
|
remote_client_registration_token=os.getenv("WEBMANAGER_REMOTE_CLIENT_REGISTRATION_TOKEN", "").strip(),
|
||||||
|
remote_client_offline_timeout_seconds=remote_client_offline_timeout_seconds,
|
||||||
|
remote_client_agent_auth_header=os.getenv("WEBMANAGER_REMOTE_CLIENT_AGENT_AUTH_HEADER", "Authorization").strip()
|
||||||
|
or "Authorization",
|
||||||
|
remote_client_agent_auth_scheme=os.getenv("WEBMANAGER_REMOTE_CLIENT_AGENT_AUTH_SCHEME", "Bearer").strip() or "Bearer",
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,187 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientRepository:
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
self._db_path = db_path
|
||||||
|
self._ensure_schema()
|
||||||
|
|
||||||
|
def upsert_client(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
client_id: str,
|
||||||
|
display_name: str,
|
||||||
|
platform: str,
|
||||||
|
agent_version: str,
|
||||||
|
endpoint: str,
|
||||||
|
shares: list[dict[str, str]],
|
||||||
|
now_iso: str,
|
||||||
|
) -> dict:
|
||||||
|
shares_json = self._encode_shares(shares)
|
||||||
|
with self._connection() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO remote_clients (
|
||||||
|
client_id, display_name, platform, agent_version, endpoint, shares_json,
|
||||||
|
last_seen, status, last_error, reachable_at, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(client_id) DO UPDATE SET
|
||||||
|
display_name = excluded.display_name,
|
||||||
|
platform = excluded.platform,
|
||||||
|
agent_version = excluded.agent_version,
|
||||||
|
endpoint = excluded.endpoint,
|
||||||
|
shares_json = excluded.shares_json,
|
||||||
|
last_seen = excluded.last_seen,
|
||||||
|
status = excluded.status,
|
||||||
|
last_error = NULL,
|
||||||
|
updated_at = excluded.updated_at
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
client_id,
|
||||||
|
display_name,
|
||||||
|
platform,
|
||||||
|
agent_version,
|
||||||
|
endpoint,
|
||||||
|
shares_json,
|
||||||
|
now_iso,
|
||||||
|
"online",
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
row = conn.execute("SELECT * FROM remote_clients WHERE client_id = ?", (client_id,)).fetchone()
|
||||||
|
return self._to_dict(row)
|
||||||
|
|
||||||
|
def record_heartbeat(self, *, client_id: str, agent_version: str, now_iso: str) -> dict | None:
|
||||||
|
with self._connection() as conn:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE remote_clients
|
||||||
|
SET agent_version = ?, last_seen = ?, status = ?, updated_at = ?
|
||||||
|
WHERE client_id = ?
|
||||||
|
""",
|
||||||
|
(agent_version, now_iso, "online", now_iso, client_id),
|
||||||
|
)
|
||||||
|
if cursor.rowcount <= 0:
|
||||||
|
return None
|
||||||
|
row = conn.execute("SELECT * FROM remote_clients WHERE client_id = ?", (client_id,)).fetchone()
|
||||||
|
return self._to_dict(row)
|
||||||
|
|
||||||
|
def mark_stale_clients_offline(self, *, cutoff_iso: str, now_iso: str) -> None:
|
||||||
|
with self._connection() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE remote_clients
|
||||||
|
SET status = ?, updated_at = ?
|
||||||
|
WHERE status != ? AND last_seen IS NOT NULL AND last_seen < ?
|
||||||
|
""",
|
||||||
|
("offline", now_iso, "offline", cutoff_iso),
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_clients(self) -> list[dict]:
|
||||||
|
with self._connection() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT *
|
||||||
|
FROM remote_clients
|
||||||
|
ORDER BY LOWER(display_name) ASC, client_id ASC
|
||||||
|
"""
|
||||||
|
).fetchall()
|
||||||
|
return [self._to_dict(row) for row in rows]
|
||||||
|
|
||||||
|
def _ensure_schema(self) -> None:
|
||||||
|
db_path = Path(self._db_path)
|
||||||
|
if db_path.parent and str(db_path.parent) not in {"", "."}:
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with self._connection() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS remote_clients (
|
||||||
|
client_id TEXT PRIMARY KEY,
|
||||||
|
display_name TEXT NOT NULL,
|
||||||
|
platform TEXT NOT NULL,
|
||||||
|
agent_version TEXT NOT NULL,
|
||||||
|
endpoint TEXT NOT NULL,
|
||||||
|
shares_json TEXT NOT NULL,
|
||||||
|
last_seen TEXT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
last_error TEXT NULL,
|
||||||
|
reachable_at TEXT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_remote_clients_display_name
|
||||||
|
ON remote_clients(display_name)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_remote_clients_last_seen
|
||||||
|
ON remote_clients(last_seen)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _connection(self):
|
||||||
|
conn = sqlite3.connect(self._db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
conn.commit()
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _to_dict(cls, row: sqlite3.Row) -> dict:
|
||||||
|
return {
|
||||||
|
"client_id": row["client_id"],
|
||||||
|
"display_name": row["display_name"],
|
||||||
|
"platform": row["platform"],
|
||||||
|
"agent_version": row["agent_version"],
|
||||||
|
"endpoint": row["endpoint"],
|
||||||
|
"shares": cls._decode_shares(row["shares_json"]),
|
||||||
|
"last_seen": row["last_seen"],
|
||||||
|
"status": row["status"],
|
||||||
|
"last_error": row["last_error"],
|
||||||
|
"reachable_at": row["reachable_at"],
|
||||||
|
"created_at": row["created_at"],
|
||||||
|
"updated_at": row["updated_at"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encode_shares(shares: list[dict[str, str]]) -> str:
|
||||||
|
return json.dumps(shares, separators=(",", ":"), sort_keys=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _decode_shares(raw: str) -> list[dict[str, str]]:
|
||||||
|
parsed = json.loads(raw or "[]")
|
||||||
|
if not isinstance(parsed, list):
|
||||||
|
return []
|
||||||
|
normalized: list[dict[str, str]] = []
|
||||||
|
for item in parsed:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
key = str(item.get("key", "")).strip()
|
||||||
|
label = str(item.get("label", "")).strip()
|
||||||
|
if key and label:
|
||||||
|
normalized.append({"key": key, "label": label})
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def now_iso() -> str:
|
||||||
|
return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")
|
||||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from backend.app.config import Settings, get_settings
|
from backend.app.config import Settings, get_settings
|
||||||
from backend.app.db.bookmark_repository import BookmarkRepository
|
from backend.app.db.bookmark_repository import BookmarkRepository
|
||||||
from backend.app.db.history_repository import HistoryRepository
|
from backend.app.db.history_repository import HistoryRepository
|
||||||
|
from backend.app.db.remote_client_repository import RemoteClientRepository
|
||||||
from backend.app.db.settings_repository import SettingsRepository
|
from backend.app.db.settings_repository import SettingsRepository
|
||||||
from backend.app.db.task_repository import TaskRepository
|
from backend.app.db.task_repository import TaskRepository
|
||||||
from backend.app.fs.filesystem_adapter import FilesystemAdapter
|
from backend.app.fs.filesystem_adapter import FilesystemAdapter
|
||||||
@@ -19,6 +20,7 @@ from backend.app.services.duplicate_task_service import DuplicateTaskService
|
|||||||
from backend.app.services.file_ops_service import FileOpsService
|
from backend.app.services.file_ops_service import FileOpsService
|
||||||
from backend.app.services.history_service import HistoryService
|
from backend.app.services.history_service import HistoryService
|
||||||
from backend.app.services.move_task_service import MoveTaskService
|
from backend.app.services.move_task_service import MoveTaskService
|
||||||
|
from backend.app.services.remote_client_service import RemoteClientService
|
||||||
from backend.app.services.search_service import SearchService
|
from backend.app.services.search_service import SearchService
|
||||||
from backend.app.services.settings_service import SettingsService
|
from backend.app.services.settings_service import SettingsService
|
||||||
from backend.app.services.task_service import TaskService
|
from backend.app.services.task_service import TaskService
|
||||||
@@ -59,6 +61,12 @@ def get_settings_repository() -> SettingsRepository:
|
|||||||
return SettingsRepository(db_path=settings.task_db_path)
|
return SettingsRepository(db_path=settings.task_db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_remote_client_repository() -> RemoteClientRepository:
|
||||||
|
settings: Settings = get_settings()
|
||||||
|
return RemoteClientRepository(db_path=settings.task_db_path)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def get_task_runner() -> TaskRunner:
|
def get_task_runner() -> TaskRunner:
|
||||||
return TaskRunner(
|
return TaskRunner(
|
||||||
@@ -155,3 +163,12 @@ async def get_search_service() -> SearchService:
|
|||||||
|
|
||||||
async def get_settings_service() -> SettingsService:
|
async def get_settings_service() -> SettingsService:
|
||||||
return SettingsService(repository=get_settings_repository(), path_guard=get_path_guard())
|
return SettingsService(repository=get_settings_repository(), path_guard=get_path_guard())
|
||||||
|
|
||||||
|
|
||||||
|
async def get_remote_client_service() -> RemoteClientService:
|
||||||
|
settings: Settings = get_settings()
|
||||||
|
return RemoteClientService(
|
||||||
|
repository=get_remote_client_repository(),
|
||||||
|
registration_token=settings.remote_client_registration_token,
|
||||||
|
offline_timeout_seconds=settings.remote_client_offline_timeout_seconds,
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from backend.app.api.errors import AppError
|
|||||||
from backend.app.api.routes_bookmarks import router as bookmarks_router
|
from backend.app.api.routes_bookmarks import router as bookmarks_router
|
||||||
from backend.app.api.routes_browse import router as browse_router
|
from backend.app.api.routes_browse import router as browse_router
|
||||||
from backend.app.api.routes_copy import router as copy_router
|
from backend.app.api.routes_copy import router as copy_router
|
||||||
|
from backend.app.api.routes_clients import router as clients_router
|
||||||
from backend.app.api.routes_duplicate import router as duplicate_router
|
from backend.app.api.routes_duplicate import router as duplicate_router
|
||||||
from backend.app.api.routes_files import router as files_router
|
from backend.app.api.routes_files import router as files_router
|
||||||
from backend.app.api.routes_history import router as history_router
|
from backend.app.api.routes_history import router as history_router
|
||||||
@@ -33,6 +34,7 @@ app.mount("/ui", StaticFiles(directory=str(UI_DIR), html=True), name="ui")
|
|||||||
app.include_router(browse_router, prefix="/api")
|
app.include_router(browse_router, prefix="/api")
|
||||||
app.include_router(files_router, prefix="/api")
|
app.include_router(files_router, prefix="/api")
|
||||||
app.include_router(copy_router, prefix="/api")
|
app.include_router(copy_router, prefix="/api")
|
||||||
|
app.include_router(clients_router, prefix="/api")
|
||||||
app.include_router(duplicate_router, prefix="/api")
|
app.include_router(duplicate_router, prefix="/api")
|
||||||
app.include_router(move_router, prefix="/api")
|
app.include_router(move_router, prefix="/api")
|
||||||
app.include_router(search_router, prefix="/api")
|
app.include_router(search_router, prefix="/api")
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from backend.app.api.errors import AppError
|
||||||
|
from backend.app.api.schemas import (
|
||||||
|
RemoteClientHeartbeatRequest,
|
||||||
|
RemoteClientItem,
|
||||||
|
RemoteClientListResponse,
|
||||||
|
RemoteClientRegisterRequest,
|
||||||
|
)
|
||||||
|
from backend.app.db.remote_client_repository import RemoteClientRepository
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientService:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
repository: RemoteClientRepository,
|
||||||
|
registration_token: str,
|
||||||
|
offline_timeout_seconds: int,
|
||||||
|
now: Callable[[], datetime] | None = None,
|
||||||
|
):
|
||||||
|
self._repository = repository
|
||||||
|
self._registration_token = registration_token.strip()
|
||||||
|
self._offline_timeout_seconds = max(1, int(offline_timeout_seconds))
|
||||||
|
self._now = now or (lambda: datetime.now(tz=timezone.utc))
|
||||||
|
|
||||||
|
def list_clients(self) -> RemoteClientListResponse:
|
||||||
|
now = self._now()
|
||||||
|
self._repository.mark_stale_clients_offline(
|
||||||
|
cutoff_iso=self._to_iso(now - timedelta(seconds=self._offline_timeout_seconds)),
|
||||||
|
now_iso=self._to_iso(now),
|
||||||
|
)
|
||||||
|
items = [RemoteClientItem(**row) for row in self._repository.list_clients()]
|
||||||
|
return RemoteClientListResponse(items=items)
|
||||||
|
|
||||||
|
def register_client(self, authorization: str | None, request: RemoteClientRegisterRequest) -> RemoteClientItem:
|
||||||
|
self._require_registration_auth(authorization)
|
||||||
|
payload = self._normalize_register_request(request)
|
||||||
|
now_iso = self._to_iso(self._now())
|
||||||
|
item = self._repository.upsert_client(now_iso=now_iso, **payload)
|
||||||
|
return RemoteClientItem(**item)
|
||||||
|
|
||||||
|
def record_heartbeat(self, authorization: str | None, request: RemoteClientHeartbeatRequest) -> RemoteClientItem:
|
||||||
|
self._require_registration_auth(authorization)
|
||||||
|
client_id = (request.client_id or "").strip()
|
||||||
|
agent_version = (request.agent_version or "").strip()
|
||||||
|
if not client_id:
|
||||||
|
raise AppError(
|
||||||
|
code="invalid_request",
|
||||||
|
message="client_id is required",
|
||||||
|
status_code=400,
|
||||||
|
details={"client_id": request.client_id},
|
||||||
|
)
|
||||||
|
if not agent_version:
|
||||||
|
raise AppError(
|
||||||
|
code="invalid_request",
|
||||||
|
message="agent_version is required",
|
||||||
|
status_code=400,
|
||||||
|
details={"agent_version": request.agent_version},
|
||||||
|
)
|
||||||
|
item = self._repository.record_heartbeat(
|
||||||
|
client_id=client_id,
|
||||||
|
agent_version=agent_version,
|
||||||
|
now_iso=self._to_iso(self._now()),
|
||||||
|
)
|
||||||
|
if item is None:
|
||||||
|
raise AppError(
|
||||||
|
code="path_not_found",
|
||||||
|
message="Remote client was not found",
|
||||||
|
status_code=404,
|
||||||
|
details={"client_id": client_id},
|
||||||
|
)
|
||||||
|
return RemoteClientItem(**item)
|
||||||
|
|
||||||
|
def _require_registration_auth(self, authorization: str | None) -> None:
|
||||||
|
if not self._registration_token:
|
||||||
|
raise AppError(
|
||||||
|
code="remote_client_registration_disabled",
|
||||||
|
message="Remote client registration is not configured",
|
||||||
|
status_code=503,
|
||||||
|
)
|
||||||
|
expected = f"Bearer {self._registration_token}"
|
||||||
|
if (authorization or "").strip() != expected:
|
||||||
|
raise AppError(
|
||||||
|
code="forbidden",
|
||||||
|
message="Invalid remote client registration token",
|
||||||
|
status_code=403,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _normalize_register_request(self, request: RemoteClientRegisterRequest) -> dict:
|
||||||
|
client_id = (request.client_id or "").strip()
|
||||||
|
display_name = (request.display_name or "").strip()
|
||||||
|
platform = (request.platform or "").strip()
|
||||||
|
agent_version = (request.agent_version or "").strip()
|
||||||
|
endpoint = (request.endpoint or "").strip()
|
||||||
|
shares = [
|
||||||
|
{"key": (item.key or "").strip(), "label": (item.label or "").strip()}
|
||||||
|
for item in request.shares
|
||||||
|
]
|
||||||
|
shares = [item for item in shares if item["key"] and item["label"]]
|
||||||
|
|
||||||
|
if not client_id:
|
||||||
|
raise AppError("invalid_request", "client_id is required", 400, {"client_id": request.client_id})
|
||||||
|
if not display_name:
|
||||||
|
raise AppError("invalid_request", "display_name is required", 400, {"display_name": request.display_name})
|
||||||
|
if not platform:
|
||||||
|
raise AppError("invalid_request", "platform is required", 400, {"platform": request.platform})
|
||||||
|
if not agent_version:
|
||||||
|
raise AppError("invalid_request", "agent_version is required", 400, {"agent_version": request.agent_version})
|
||||||
|
if not endpoint:
|
||||||
|
raise AppError("invalid_request", "endpoint is required", 400, {"endpoint": request.endpoint})
|
||||||
|
if not shares:
|
||||||
|
raise AppError("invalid_request", "at least one share is required", 400, {"shares": "[]"})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"client_id": client_id,
|
||||||
|
"display_name": display_name,
|
||||||
|
"platform": platform,
|
||||||
|
"agent_version": agent_version,
|
||||||
|
"endpoint": endpoint,
|
||||||
|
"shares": shares,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_iso(value: datetime) -> str:
|
||||||
|
return value.astimezone(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||||
Binary file not shown.
@@ -0,0 +1,139 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
|
||||||
|
|
||||||
|
from backend.app.dependencies import get_remote_client_service
|
||||||
|
from backend.app.db.remote_client_repository import RemoteClientRepository
|
||||||
|
from backend.app.main import app
|
||||||
|
from backend.app.services.remote_client_service import RemoteClientService
|
||||||
|
|
||||||
|
|
||||||
|
class _Clock:
|
||||||
|
def __init__(self, current: datetime):
|
||||||
|
self.current = current
|
||||||
|
|
||||||
|
def now(self) -> datetime:
|
||||||
|
return self.current
|
||||||
|
|
||||||
|
def advance(self, *, seconds: int) -> None:
|
||||||
|
self.current += timedelta(seconds=seconds)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteClientsApiGoldenTest(unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
self.clock = _Clock(datetime(2026, 3, 26, 12, 0, 0, tzinfo=timezone.utc))
|
||||||
|
repository = RemoteClientRepository(str(Path(self.temp_dir.name) / "remote-clients.db"))
|
||||||
|
service = RemoteClientService(
|
||||||
|
repository=repository,
|
||||||
|
registration_token="secret-token",
|
||||||
|
offline_timeout_seconds=60,
|
||||||
|
now=self.clock.now,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _override_remote_client_service() -> RemoteClientService:
|
||||||
|
return service
|
||||||
|
|
||||||
|
app.dependency_overrides[get_remote_client_service] = _override_remote_client_service
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
def _request(self, method: str, url: str, payload: dict | None = None, token: str | None = None) -> httpx.Response:
|
||||||
|
async def _run() -> httpx.Response:
|
||||||
|
transport = httpx.ASGITransport(app=app)
|
||||||
|
headers = {}
|
||||||
|
if token is not None:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||||
|
if method == "GET":
|
||||||
|
return await client.get(url, headers=headers)
|
||||||
|
return await client.post(url, json=payload, headers=headers)
|
||||||
|
|
||||||
|
return asyncio.run(_run())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _register_payload() -> dict:
|
||||||
|
return {
|
||||||
|
"client_id": "client-123",
|
||||||
|
"display_name": "Jan MacBook",
|
||||||
|
"platform": "macos",
|
||||||
|
"agent_version": "1.1.0",
|
||||||
|
"endpoint": "http://192.168.1.25:8765",
|
||||||
|
"shares": [{"key": "downloads", "label": "Downloads"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_list_is_empty_by_default(self) -> None:
|
||||||
|
response = self._request("GET", "/api/clients")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.json(), {"items": []})
|
||||||
|
|
||||||
|
def test_register_then_list_then_heartbeat_and_status_timeout(self) -> None:
|
||||||
|
register_response = self._request(
|
||||||
|
"POST",
|
||||||
|
"/api/clients/register",
|
||||||
|
self._register_payload(),
|
||||||
|
token="secret-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(register_response.status_code, 200)
|
||||||
|
register_body = register_response.json()
|
||||||
|
self.assertEqual(register_body["client_id"], "client-123")
|
||||||
|
self.assertEqual(register_body["display_name"], "Jan MacBook")
|
||||||
|
self.assertEqual(register_body["status"], "online")
|
||||||
|
self.assertEqual(register_body["last_seen"], "2026-03-26T12:00:00Z")
|
||||||
|
self.assertIsNone(register_body["last_error"])
|
||||||
|
self.assertIsNone(register_body["reachable_at"])
|
||||||
|
|
||||||
|
list_response = self._request("GET", "/api/clients")
|
||||||
|
self.assertEqual(list_response.status_code, 200)
|
||||||
|
self.assertEqual(len(list_response.json()["items"]), 1)
|
||||||
|
self.assertEqual(list_response.json()["items"][0]["status"], "online")
|
||||||
|
|
||||||
|
self.clock.advance(seconds=30)
|
||||||
|
heartbeat_response = self._request(
|
||||||
|
"POST",
|
||||||
|
"/api/clients/heartbeat",
|
||||||
|
{"client_id": "client-123", "agent_version": "1.1.1"},
|
||||||
|
token="secret-token",
|
||||||
|
)
|
||||||
|
self.assertEqual(heartbeat_response.status_code, 200)
|
||||||
|
heartbeat_body = heartbeat_response.json()
|
||||||
|
self.assertEqual(heartbeat_body["agent_version"], "1.1.1")
|
||||||
|
self.assertEqual(heartbeat_body["last_seen"], "2026-03-26T12:00:30Z")
|
||||||
|
self.assertEqual(heartbeat_body["status"], "online")
|
||||||
|
|
||||||
|
self.clock.advance(seconds=61)
|
||||||
|
timed_out_list = self._request("GET", "/api/clients")
|
||||||
|
self.assertEqual(timed_out_list.status_code, 200)
|
||||||
|
timed_out_item = timed_out_list.json()["items"][0]
|
||||||
|
self.assertEqual(timed_out_item["status"], "offline")
|
||||||
|
self.assertEqual(timed_out_item["last_seen"], "2026-03-26T12:00:30Z")
|
||||||
|
self.assertIsNone(timed_out_item["last_error"])
|
||||||
|
self.assertIsNone(timed_out_item["reachable_at"])
|
||||||
|
|
||||||
|
def test_register_rejects_invalid_token(self) -> None:
|
||||||
|
response = self._request(
|
||||||
|
"POST",
|
||||||
|
"/api/clients/register",
|
||||||
|
self._register_payload(),
|
||||||
|
token="wrong-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 403)
|
||||||
|
self.assertEqual(response.json()["error"]["code"], "forbidden")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user