129 lines
5.2 KiB
Python
129 lines
5.2 KiB
Python
from __future__ import annotations

import hmac
from datetime import datetime, timedelta, timezone
from typing import Callable

from backend.app.api.errors import AppError
from backend.app.api.schemas import (
    RemoteClientHeartbeatRequest,
    RemoteClientItem,
    RemoteClientListResponse,
    RemoteClientRegisterRequest,
)
from backend.app.db.remote_client_repository import RemoteClientRepository
|
|
|
|
|
|
class RemoteClientService:
    """Service layer for registering remote clients and tracking their liveness.

    Registration and heartbeat calls are authenticated with a shared bearer
    token. Listing clients first asks the repository to mark clients stale
    relative to a cutoff of ``now - offline_timeout_seconds``.
    """

    def __init__(
        self,
        repository: RemoteClientRepository,
        registration_token: str | None,
        offline_timeout_seconds: int,
        now: Callable[[], datetime] | None = None,
    ):
        """Create the service.

        Args:
            repository: Persistence layer for remote clients.
            registration_token: Shared secret expected in the ``Authorization``
                header as ``Bearer <token>``. Surrounding whitespace is ignored.
                An empty or ``None`` value disables registration and heartbeats
                (they fail with 503).
            offline_timeout_seconds: Age used to compute the staleness cutoff
                when listing; clamped to at least 1 second.
            now: Clock returning the current time. Defaults to an aware UTC
                ``datetime.now``; injectable for tests.
        """
        self._repository = repository
        # A missing token (e.g. unset config) behaves exactly like an empty one.
        self._registration_token = (registration_token or "").strip()
        self._offline_timeout_seconds = max(1, int(offline_timeout_seconds))
        self._now = now or (lambda: datetime.now(tz=timezone.utc))

    def list_clients(self) -> RemoteClientListResponse:
        """Return all known clients after refreshing their offline state.

        Before listing, the repository is asked to mark clients offline using a
        cutoff of ``now - offline_timeout_seconds`` (presumably clients last
        seen before the cutoff; the exact rule lives in the repository).
        """
        now = self._now()
        cutoff = now - timedelta(seconds=self._offline_timeout_seconds)
        self._repository.mark_stale_clients_offline(
            cutoff_iso=self._to_iso(cutoff),
            now_iso=self._to_iso(now),
        )
        items = [RemoteClientItem(**row) for row in self._repository.list_clients()]
        return RemoteClientListResponse(items=items)

    def register_client(
        self, authorization: str | None, request: RemoteClientRegisterRequest
    ) -> RemoteClientItem:
        """Validate *request* and upsert the client via the repository.

        Raises:
            AppError: 503 if registration is not configured, 403 on a missing or
                wrong token, 400 if a required field is blank or no usable share
                is supplied.
        """
        self._require_registration_auth(authorization)
        payload = self._normalize_register_request(request)
        now_iso = self._to_iso(self._now())
        item = self._repository.upsert_client(now_iso=now_iso, **payload)
        return RemoteClientItem(**item)

    def record_heartbeat(
        self, authorization: str | None, request: RemoteClientHeartbeatRequest
    ) -> RemoteClientItem:
        """Record a heartbeat for a client the repository already knows.

        Raises:
            AppError: 503/403 on authentication problems, 400 if ``client_id``
                or ``agent_version`` is blank, 404 if the repository returns no
                client for ``client_id``.
        """
        self._require_registration_auth(authorization)
        client_id = self._require_text("client_id", request.client_id)
        agent_version = self._require_text("agent_version", request.agent_version)

        item = self._repository.record_heartbeat(
            client_id=client_id,
            agent_version=agent_version,
            now_iso=self._to_iso(self._now()),
        )
        if item is None:
            # NOTE(review): "path_not_found" looks borrowed from another API;
            # kept unchanged because callers may match on this code.
            raise AppError(
                code="path_not_found",
                message="Remote client was not found",
                status_code=404,
                details={"client_id": client_id},
            )
        return RemoteClientItem(**item)

    def _require_registration_auth(self, authorization: str | None) -> None:
        """Validate an ``Authorization: Bearer <token>`` header value.

        The scheme is matched case-insensitively (auth schemes are
        case-insensitive per RFC 7235). The token is compared in constant time,
        so response timing does not reveal how much of a guess was correct.

        Raises:
            AppError: 503 ``remote_client_registration_disabled`` when no token
                is configured; 403 ``forbidden`` when the header is missing,
                uses another scheme, or carries the wrong token.
        """
        if not self._registration_token:
            raise AppError(
                code="remote_client_registration_disabled",
                message="Remote client registration is not configured",
                status_code=503,
            )

        scheme, _, credentials = (authorization or "").strip().partition(" ")
        # Always run the constant-time comparison, even when the scheme is wrong,
        # so the rejection path does not depend on how the header failed.
        token_matches = hmac.compare_digest(
            credentials.strip().encode("utf-8"),
            self._registration_token.encode("utf-8"),
        )
        if scheme.lower() != "bearer" or not token_matches:
            raise AppError(
                code="forbidden",
                message="Invalid remote client registration token",
                status_code=403,
            )

    @staticmethod
    def _require_text(field: str, value: str | None) -> str:
        """Return *value* with surrounding whitespace removed, or raise if blank.

        Raises:
            AppError: 400 ``invalid_request`` with message ``"<field> is
                required"`` and the raw *value* in ``details``.
        """
        text = (value or "").strip()
        if not text:
            raise AppError(
                code="invalid_request",
                message=f"{field} is required",
                status_code=400,
                details={field: value},
            )
        return text

    def _normalize_register_request(self, request: RemoteClientRegisterRequest) -> dict:
        """Validate a registration request and build the repository payload.

        Text fields are stripped and checked in a fixed order, so the first
        blank one is the one reported. Shares whose key or label is blank are
        dropped.

        Returns:
            A dict with stripped ``client_id``, ``display_name``, ``platform``,
            ``agent_version`` and ``endpoint`` strings, plus ``shares`` as a
            list of ``{"key": ..., "label": ...}`` dicts.

        Raises:
            AppError: 400 ``invalid_request`` for the first blank required
                field, or when no share survives normalization.
        """
        # Dict displays evaluate left to right, so validation order is preserved.
        payload = {
            "client_id": self._require_text("client_id", request.client_id),
            "display_name": self._require_text("display_name", request.display_name),
            "platform": self._require_text("platform", request.platform),
            "agent_version": self._require_text("agent_version", request.agent_version),
            "endpoint": self._require_text("endpoint", request.endpoint),
        }

        shares = []
        for share in request.shares or ():
            key = (share.key or "").strip()
            label = (share.label or "").strip()
            if key and label:
                shares.append({"key": key, "label": label})
        if not shares:
            raise AppError(
                code="invalid_request",
                message="at least one share is required",
                status_code=400,
                details={"shares": "[]"},
            )

        payload["shares"] = shares
        return payload

    @staticmethod
    def _to_iso(value: datetime) -> str:
        """Format *value* as a fixed-width UTC ISO-8601 string ending in ``Z``.

        ``timespec="microseconds"`` forces the fractional part to be present.
        Plain ``isoformat()`` omits it when microsecond == 0, and then
        ``"...:05Z"`` sorts *after* ``"...:05.5Z"`` as a string. Fixed width
        keeps string order equal to chronological order, which matters if the
        repository compares these strings (e.g. against ``cutoff_iso``).

        Note: ``astimezone`` interprets a naive *value* as local time.
        """
        utc_text = value.astimezone(timezone.utc).isoformat(timespec="microseconds")
        return utc_text.replace("+00:00", "Z")