221 lines
7.4 KiB
Python
221 lines
7.4 KiB
Python
from __future__ import annotations

import argparse
import json
import sys
import threading
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from urllib import error, request
from urllib.parse import urlparse

import uvicorn
|
|
|
|
|
|
# Agent version string reported to WebManager in both the register and heartbeat payloads.
AGENT_VERSION = "1.1.0-phase1"
|
|
|
|
|
|
@dataclass
class AgentConfig:
    """Settings for the remote client agent, as built by ``load_config``.

    The two token fields are excluded from the generated ``repr`` so that
    printing or logging a config object does not leak credentials; they still
    take part in equality and are accessed normally.
    """

    config_path: Path  # file the settings were loaded from
    webmanager_base_url: str  # WebManager root URL; may carry trailing slashes
    registration_token: str = field(repr=False)  # bearer token sent to WebManager
    agent_access_token: str = field(repr=False)  # credential for the agent's own endpoints
    display_name: str
    endpoint: str  # public URL advertised to WebManager for this agent
    shares: dict[str, str]  # share key -> configured value (presumably a local path — confirm)
    heartbeat_interval_seconds: int  # delay between register retries / heartbeats
    client_id: str  # identity reported in register and heartbeat payloads
    platform: str = "macos"  # reported in the register payload

    @property
    def normalized_base_url(self) -> str:
        """Base URL with trailing slashes removed, ready for ``/api/...`` paths to be appended."""
        return self.webmanager_base_url.rstrip("/")
|
|
|
|
|
|
def _config_text(raw: dict[str, Any], key: str, default: str = "") -> str:
    """Return ``raw[key]`` as a stripped string.

    A missing key and an explicit JSON ``null`` both yield *default*; without this,
    ``str(None)`` would turn a null into the literal text ``"None"``.
    """
    value = raw.get(key)
    if value is None:
        return default
    return str(value).strip()


def _ensure_client_id(config_path: Path, raw: dict[str, Any]) -> str:
    """Return the configured ``client_id``, generating and saving one when it is missing.

    A freshly generated UUID is stored into *raw* and the whole config is written
    back to *config_path*, so later loads reuse the same id.
    """
    client_id = _config_text(raw, "client_id")
    if not client_id:
        client_id = str(uuid.uuid4())
        raw["client_id"] = client_id
        config_path.write_text(json.dumps(raw, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    return client_id


def _parse_shares(shares_raw: Any) -> dict[str, str]:
    """Normalize the ``shares`` mapping, keeping entries whose key and value are non-blank.

    Anything other than a JSON object yields an empty mapping; null values are skipped.
    """
    if not isinstance(shares_raw, dict):
        return {}
    shares: dict[str, str] = {}
    for key, value in shares_raw.items():
        if value is None:
            continue
        normalized_key = str(key).strip()
        normalized_value = str(value).strip()
        if normalized_key and normalized_value:
            shares[normalized_key] = normalized_value
    return shares


def _parse_heartbeat_interval(raw: dict[str, Any]) -> int:
    """Return ``heartbeat_interval_seconds`` as an int (default 20), clamped to at least 5.

    Raises:
        ValueError: the configured value cannot be converted to an integer.
    """
    value = raw.get("heartbeat_interval_seconds")
    if value is None:
        value = 20
    try:
        seconds = int(value)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(
            f"config field 'heartbeat_interval_seconds' must be an integer, got {value!r}"
        ) from exc
    return max(5, seconds)


def load_config(config_path: Path) -> AgentConfig:
    """Load and normalize the agent configuration from a JSON file.

    If the file has no usable ``client_id`` (missing, null, or blank), a UUID is
    generated and the file is rewritten with it before anything else is validated.
    String fields that are missing or null become ``""``; required-ness of those
    fields is not checked here.

    Args:
        config_path: Path of the JSON config file.

    Returns:
        The populated ``AgentConfig``.

    Raises:
        ValueError: the file is not a JSON object, has no usable share, or has a
            non-integer ``heartbeat_interval_seconds`` (``json.JSONDecodeError``,
            a ``ValueError`` subclass, for malformed JSON).
        OSError: the file cannot be read or written.
    """
    raw = json.loads(config_path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError(f"config file {config_path} must contain a JSON object")

    client_id = _ensure_client_id(config_path, raw)

    shares = _parse_shares(raw.get("shares"))
    if not shares:
        raise ValueError("config requires at least one share")

    return AgentConfig(
        config_path=config_path,
        webmanager_base_url=_config_text(raw, "webmanager_base_url"),
        registration_token=_config_text(raw, "registration_token"),
        agent_access_token=_config_text(raw, "agent_access_token"),
        display_name=_config_text(raw, "display_name"),
        # Prefer "public_endpoint"; fall back to "endpoint" when it is absent, null, or blank.
        endpoint=_config_text(raw, "public_endpoint") or _config_text(raw, "endpoint"),
        shares=shares,
        heartbeat_interval_seconds=_parse_heartbeat_interval(raw),
        client_id=client_id,
        platform=_config_text(raw, "platform") or "macos",
    )
|
|
|
|
|
|
def require_non_empty(value: str, field: str) -> str:
    """Return *value* with surrounding whitespace removed, insisting it is not blank.

    Args:
        value: The raw config string.
        field: Config key name, used only in the error message.

    Raises:
        ValueError: *value* is empty or whitespace-only.
    """
    cleaned = value.strip()
    if cleaned:
        return cleaned
    raise ValueError(f"config field '{field}' is required")
|
|
|
|
|
|
def build_register_payload(config: AgentConfig) -> dict[str, Any]:
    """Build the JSON body for ``/api/clients/register``.

    Shares are listed in sorted key order, each labelled with its capitalized key.
    """
    share_entries = [
        {"key": share_key, "label": share_key.capitalize()}
        for share_key in sorted(config.shares)
    ]
    payload: dict[str, Any] = {
        "client_id": config.client_id,
        "display_name": config.display_name,
        "platform": config.platform,
        "agent_version": AGENT_VERSION,
        "endpoint": config.endpoint,
        "shares": share_entries,
    }
    return payload
|
|
|
|
|
|
def build_heartbeat_payload(config: AgentConfig) -> dict[str, Any]:
    """Build the JSON body for ``/api/clients/heartbeat`` (client id and agent version)."""
    return dict(client_id=config.client_id, agent_version=AGENT_VERSION)
|
|
|
|
|
|
def post_json(url: str, token: str, payload: dict[str, Any], timeout: float = 10) -> dict[str, Any]:
    """POST *payload* as JSON to *url* with a bearer token and return the decoded reply.

    An empty (or whitespace-only) response body is returned as ``{}`` instead of
    raising, so endpoints that acknowledge without content do not break callers.

    Args:
        url: Target URL.
        token: Bearer token placed in the ``Authorization`` header.
        payload: JSON-serializable request body.
        timeout: Socket timeout in seconds (default 10, as before).

    Raises:
        urllib.error.HTTPError: the server answered with an HTTP error status.
        urllib.error.URLError / OSError: connection or socket-level failures.
        ValueError: the response body is not valid JSON.
    """
    req = request.Request(
        url,
        method="POST",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
    )
    with request.urlopen(req, timeout=timeout) as resp:
        body = resp.read()
    if not body.strip():
        return {}
    return json.loads(body.decode("utf-8"))
|
|
|
|
|
|
def _post_and_report(label: str, url: str, token: str, payload: dict[str, Any]) -> bool:
    """POST *payload* via ``post_json`` and print the outcome; return True on success.

    Network, HTTP and response-decoding failures are printed to stderr and
    swallowed: ``run_heartbeat_loop`` retries on its own schedule, and an escaped
    exception would end the background thread it runs in.
    """
    try:
        post_json(url, token, payload)
    except error.HTTPError as exc:
        print(f"{label} failed: HTTP {exc.code}", file=sys.stderr, flush=True)
    except error.URLError as exc:
        print(f"{label} failed: {exc.reason}", file=sys.stderr, flush=True)
    except (OSError, ValueError) as exc:
        # e.g. a timeout or connection reset while reading, or a non-JSON reply.
        print(f"{label} failed: {exc}", file=sys.stderr, flush=True)
    else:
        print(f"{label} ok", flush=True)
        return True
    return False


def run_heartbeat_loop(config: AgentConfig, stop_event: threading.Event) -> None:
    """Register with WebManager, then post heartbeats until *stop_event* is set.

    Registration is retried every ``heartbeat_interval_seconds`` until it succeeds;
    afterwards a heartbeat is posted on the same interval. Request failures are
    printed to stderr and never end the loop; setting *stop_event* ends it within
    one wait.

    Raises:
        ValueError: a required config field is empty (checked before any request).
    """
    require_non_empty(config.webmanager_base_url, "webmanager_base_url")
    require_non_empty(config.registration_token, "registration_token")
    require_non_empty(config.agent_access_token, "agent_access_token")
    require_non_empty(config.display_name, "display_name")
    require_non_empty(config.endpoint, "public_endpoint")

    register_url = f"{config.normalized_base_url}/api/clients/register"
    heartbeat_url = f"{config.normalized_base_url}/api/clients/heartbeat"

    print(f"Starting remote client agent for {config.display_name} ({config.client_id})", flush=True)
    print(f"Using config: {config.config_path}", flush=True)
    print("agent_access_token is configured for authenticated agent endpoints", flush=True)

    # Phase 1: register, retrying until it succeeds or we are asked to stop.
    while not stop_event.is_set():
        if _post_and_report("register", register_url, config.registration_token, build_register_payload(config)):
            break
        if stop_event.wait(config.heartbeat_interval_seconds):
            return

    # Phase 2: periodic heartbeats.
    # NOTE(review): heartbeats are authenticated with registration_token; agent_access_token is
    # only validated above and never sent from here — confirm this is the intended credential.
    while not stop_event.is_set():
        _post_and_report("heartbeat", heartbeat_url, config.registration_token, build_heartbeat_payload(config))
        if stop_event.wait(config.heartbeat_interval_seconds):
            return
|
|
|
|
|
|
def resolve_bind_host(config: AgentConfig, requested_host: str | None) -> str:
    """Return the address the HTTP agent should bind to.

    An explicitly requested host wins (surrounding whitespace removed); a missing or
    blank request falls back to ``0.0.0.0``. *config* is currently unused.
    """
    host = requested_host.strip() if requested_host else ""
    return host or "0.0.0.0"
|
|
|
|
|
|
def resolve_bind_port(config: AgentConfig, requested_port: int | None) -> int:
    """Return the TCP port the HTTP agent should listen on.

    Precedence: a positive *requested_port*; else the explicit port in
    ``config.endpoint``; else 443 for ``https`` / 80 for ``http`` endpoints;
    else 8765.

    Raises:
        ValueError: ``config.endpoint`` contains an invalid port (from ``urlparse``).
    """
    if requested_port is not None and requested_port > 0:
        return requested_port

    advertised = urlparse(config.endpoint)
    explicit_port = advertised.port
    if explicit_port:
        return explicit_port

    scheme_defaults = {"https": 443, "http": 80}
    return scheme_defaults.get(advertised.scheme, 8765)
|
|
|
|
|
|
def run(config: AgentConfig, requested_host: str | None, requested_port: int | None) -> None:
    """Start the heartbeat thread and serve the HTTP agent until uvicorn returns.

    Required connection settings and the bind address are resolved before anything
    starts, so a bad config aborts startup with ``ValueError``. Previously, bad
    settings only killed the background heartbeat thread, and the HTTP server kept
    running unregistered. On exit, the heartbeat thread is signalled to stop and
    joined briefly.

    Raises:
        ValueError: a required config field is empty or the endpoint port is invalid.
    """
    import os

    for field_name, value in (
        ("webmanager_base_url", config.webmanager_base_url),
        ("registration_token", config.registration_token),
        ("agent_access_token", config.agent_access_token),
        ("display_name", config.display_name),
        ("public_endpoint", config.endpoint),
    ):
        require_non_empty(value, field_name)

    bind_host = resolve_bind_host(config, requested_host)
    bind_port = resolve_bind_port(config, requested_port)

    stop_event = threading.Event()
    heartbeat_thread = threading.Thread(
        target=run_heartbeat_loop,
        args=(config, stop_event),
        daemon=True,
        name="remote-client-heartbeat",
    )
    heartbeat_thread.start()

    try:
        print(f"Starting HTTP agent on {bind_host}:{bind_port}", flush=True)
        print(f"Advertised endpoint: {config.endpoint}", flush=True)
        # Hands the config location to the served app through the environment
        # (presumably read by app.main — confirm); it is set before uvicorn loads the app.
        os.environ["FINDER_COMMANDER_REMOTE_AGENT_CONFIG"] = str(config.config_path)
        uvicorn.run("app.main:app", host=bind_host, port=bind_port)
    finally:
        stop_event.set()
        heartbeat_thread.join(timeout=2)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse ``--config``, ``--host`` and ``--port`` from the command line.

    ``--config`` defaults to ``remote_client_agent.example.json`` next to this file;
    ``--host`` defaults to ``""`` and ``--port`` to ``0``, meaning "derive it".
    """
    default_config = Path(__file__).resolve().with_name("remote_client_agent.example.json")
    cli = argparse.ArgumentParser(description="Remote client agent Phase 1 for WebManager MVP")
    cli.add_argument("--config", default=str(default_config), help="Path to remote client agent config JSON")
    cli.add_argument("--host", default="", help="Bind host for the HTTP agent, defaults to 0.0.0.0")
    cli.add_argument("--port", type=int, default=0, help="Bind port for the HTTP agent, defaults to endpoint port")
    return cli.parse_args()
|
|
|
|
|
|
def main() -> int:
    """CLI entry point: load the config, run the agent, and return an exit status.

    Returns 0 when the agent exits normally, 130 on ``KeyboardInterrupt``, and 1
    for any other exception (whose message is printed to stderr).
    """
    options = parse_args()
    try:
        config_file = Path(options.config).resolve()
        run(load_config(config_file), requested_host=options.host, requested_port=options.port)
    except KeyboardInterrupt:
        return 130
    except Exception as exc:
        print(str(exc), file=sys.stderr)
        return 1
    else:
        return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|