from __future__ import annotations

import asyncio
import sys
import tempfile
import threading
import time
import unittest
import zipfile
from io import BytesIO
from pathlib import Path

import httpx

sys.path.insert(0, str(Path(__file__).resolve().parents[3]))

from backend.app.dependencies import get_archive_download_task_service, get_file_ops_service, get_task_service
from backend.app.db.history_repository import HistoryRepository
from backend.app.db.task_repository import TaskRepository
from backend.app.fs.filesystem_adapter import FilesystemAdapter
from backend.app.main import app
from backend.app.security.path_guard import PathGuard
from backend.app.services.archive_download_task_service import ArchiveDownloadTaskService
from backend.app.services.file_ops_service import FileOpsService, ZipDownloadPreflightLimits
from backend.app.services.task_service import TaskService
from backend.app.tasks_runner import TaskRunner


class BlockingArchiveFileOpsService(FileOpsService):
    """FileOpsService whose zip preflight pauses on *gate* after running.

    Lets a test observe the archive task while it is still in progress.
    The wait is bounded by a 2 second timeout, so the task still proceeds
    if the gate is never set.
    """

    def __init__(self, *args, gate: threading.Event, **kwargs):
        super().__init__(*args, **kwargs)
        self._gate = gate

    def _run_zip_download_preflight(self, resolved_targets: list) -> None:
        super()._run_zip_download_preflight(resolved_targets)
        self._gate.wait(timeout=2.0)


class FailingArchiveFileOpsService(FileOpsService):
    """FileOpsService that writes one partial zip entry, then raises ``OSError``."""

    def _write_download_target_to_zip(self, archive: zipfile.ZipFile, resolved_target) -> None:
        archive.writestr("partial.txt", b"partial")
        raise OSError("forced archive failure")


class DownloadApiGoldenTest(unittest.TestCase):
    """End-to-end tests for the direct download and archive-download API routes."""

    def setUp(self) -> None:
        self.temp_dir = tempfile.TemporaryDirectory()
        self.root = Path(self.temp_dir.name) / "root"
        self.root.mkdir(parents=True, exist_ok=True)
        self.db_path = str(Path(self.temp_dir.name) / "tasks.db")
        self.artifact_root = Path(self.temp_dir.name) / "archive_tmp"
        # Both storage aliases intentionally point at the same directory.
        self.path_guard = PathGuard({"storage1": str(self.root), "storage2": str(self.root)})
        self.filesystem = FilesystemAdapter()
        self.task_repo = TaskRepository(self.db_path)
        self.history_repo = HistoryRepository(self.db_path)
        self._override_services()

    def tearDown(self) -> None:
        app.dependency_overrides.clear()
        self.temp_dir.cleanup()

    def _override_services(
        self,
        *,
        file_ops_service: FileOpsService | None = None,
        artifact_ttl_seconds: int = 1800,
    ) -> None:
        """Wire fresh service instances into the app's FastAPI dependency overrides.

        Args:
            file_ops_service: Service to inject. Defaults to a plain
                ``FileOpsService`` built on this test's guard, filesystem and
                history repository.
            artifact_ttl_seconds: TTL passed to ``ArchiveDownloadTaskService``.
        """
        file_ops_service = file_ops_service or FileOpsService(
            path_guard=self.path_guard,
            filesystem=self.filesystem,
            history_repository=self.history_repo,
            zip_download_preflight_limits=ZipDownloadPreflightLimits(),
        )
        runner = TaskRunner(repository=self.task_repo, filesystem=self.filesystem, history_repository=self.history_repo)
        archive_service = ArchiveDownloadTaskService(
            path_guard=self.path_guard,
            repository=self.task_repo,
            runner=runner,
            history_repository=self.history_repo,
            file_ops_service=file_ops_service,
            artifact_root=self.artifact_root,
            artifact_ttl_seconds=artifact_ttl_seconds,
        )
        task_service = TaskService(repository=self.task_repo)

        async def _override_file_ops_service() -> FileOpsService:
            return file_ops_service

        async def _override_archive_service() -> ArchiveDownloadTaskService:
            return archive_service

        async def _override_task_service() -> TaskService:
            return task_service

        app.dependency_overrides[get_file_ops_service] = _override_file_ops_service
        app.dependency_overrides[get_archive_download_task_service] = _override_archive_service
        app.dependency_overrides[get_task_service] = _override_task_service

    def _request(self, method: str, url: str, payload: dict | None = None) -> httpx.Response:
        """Send one request to the ASGI app in-process and return the response.

        ``method`` is forwarded as given. Previously every non-GET method was
        silently sent as POST. ``payload`` is JSON-encoded as the body; when it
        is ``None`` no body is sent.
        """

        async def _run() -> httpx.Response:
            transport = httpx.ASGITransport(app=app)
            async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
                return await client.request(method, url, json=payload)

        return asyncio.run(_run())

    def _wait_for_task_status(self, task_id: str, statuses: set[str], timeout_s: float = 2.0) -> dict:
        """Poll ``GET /api/tasks/{task_id}`` until the task's status is in *statuses*.

        Returns the task body once it matches. Fails the test, reporting the
        last status seen, if *timeout_s* elapses first.
        """
        # monotonic() is immune to wall-clock adjustments, unlike time.time().
        deadline = time.monotonic() + timeout_s
        last_status = None
        while time.monotonic() < deadline:
            response = self._request("GET", f"/api/tasks/{task_id}")
            body = response.json()
            last_status = body["status"]
            if last_status in statuses:
                return body
            time.sleep(0.02)
        self.fail(
            f"task did not reach expected status in time "
            f"(expected one of {sorted(statuses)}, last status: {last_status!r})"
        )

    def test_download_success_for_allowed_file(self) -> None:
        src = self.root / "report.txt"
        src.write_text("hello download", encoding="utf-8")

        response = self._request("GET", "/api/files/download?path=storage1/report.txt")

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b"hello download")
        self.assertIn('attachment; filename="report.txt"', response.headers.get("content-disposition", ""))
        self.assertEqual(response.headers.get("content-type"), "text/plain; charset=utf-8")

    def test_archive_prepare_single_directory_ends_ready(self) -> None:
        (self.root / "docs").mkdir()
        (self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")

        created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})

        self.assertEqual(created.status_code, 202)
        task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
        self.assertEqual(task["operation"], "download")
        self.assertEqual(task["status"], "ready")
        self.assertEqual(task["destination"], "docs.zip")

    def test_archive_prepare_multi_mixed_selection_ends_ready(self) -> None:
        (self.root / "readme.txt").write_text("R", encoding="utf-8")
        (self.root / "photos").mkdir()
        (self.root / "photos" / "img.txt").write_text("P", encoding="utf-8")

        created = self._request(
            "POST",
            "/api/files/download/archive-prepare",
            {"paths": ["storage1/readme.txt", "storage1/photos"]},
        )

        self.assertEqual(created.status_code, 202)
        task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
        self.assertEqual(task["status"], "ready")
        self.assertEqual(task["source"], "storage1/readme.txt, storage1/photos")
        self.assertRegex(task["destination"], r'^kodidownload-\d{8}-\d{6}\.zip$')

    def test_archive_retrieval_from_ready_task_works(self) -> None:
        (self.root / "docs").mkdir()
        (self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
        created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
        task = self._wait_for_task_status(created.json()["task_id"], {"ready"})

        response = self._request("GET", f"/api/files/download/archive/{task['id']}")

        self.assertEqual(response.status_code, 200)
        self.assertIn('attachment; filename="docs.zip"', response.headers.get("content-disposition", ""))
        with zipfile.ZipFile(BytesIO(response.content)) as archive:
            self.assertIn("docs/", archive.namelist())
            self.assertIn("docs/a.txt", archive.namelist())
            self.assertEqual(archive.read("docs/a.txt"), b"a")

    def test_archive_retrieval_before_ready_rejected(self) -> None:
        gate = threading.Event()
        # Release the blocked preflight even if a step below fails, so the
        # runner thread is not left waiting out its full gate timeout.
        self.addCleanup(gate.set)
        file_ops_service = BlockingArchiveFileOpsService(
            path_guard=self.path_guard,
            filesystem=self.filesystem,
            history_repository=self.history_repo,
            zip_download_preflight_limits=ZipDownloadPreflightLimits(),
            gate=gate,
        )
        self._override_services(file_ops_service=file_ops_service)
        (self.root / "docs").mkdir()
        (self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")

        created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
        task = self._wait_for_task_status(created.json()["task_id"], {"requested", "preparing"})
        response = self._request("GET", f"/api/files/download/archive/{task['id']}")
        gate.set()
        # Let the background task finish before tearDown deletes the temp dir
        # that holds its DB and artifact; otherwise the two race.
        self._wait_for_task_status(task["id"], {"ready", "failed"})

        self.assertEqual(response.status_code, 409)
        self.assertEqual(response.json()["error"]["code"], "download_not_ready")

    def test_archive_preflight_failure_sets_failed_and_error_code(self) -> None:
        target = self.root / "real.txt"
        target.write_text("x", encoding="utf-8")
        (self.root / "docs").mkdir()
        try:
            (self.root / "docs" / "link.txt").symlink_to(target)
        except (OSError, NotImplementedError) as exc:
            # e.g. Windows without symlink privilege: the scenario cannot be built.
            self.skipTest(f"symlinks unavailable on this platform: {exc}")

        created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})

        task = self._wait_for_task_status(created.json()["task_id"], {"failed"})
        self.assertEqual(task["status"], "failed")
        self.assertEqual(task["error_code"], "download_preflight_failed")

    def test_archive_failure_removes_partial_artifact(self) -> None:
        file_ops_service = FailingArchiveFileOpsService(
            path_guard=self.path_guard,
            filesystem=self.filesystem,
            history_repository=self.history_repo,
            zip_download_preflight_limits=ZipDownloadPreflightLimits(),
        )
        self._override_services(file_ops_service=file_ops_service)
        (self.root / "docs").mkdir()
        (self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")

        created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})

        task = self._wait_for_task_status(created.json()["task_id"], {"failed"})
        self.assertEqual(task["error_code"], "io_error")
        self.assertEqual(list(self.artifact_root.glob("*")), [])

    def test_expired_artifact_rejected_and_removed(self) -> None:
        (self.root / "docs").mkdir()
        (self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
        self._override_services(artifact_ttl_seconds=1)
        created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
        task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
        artifact = self.task_repo.get_artifact(task["id"])
        # Backdate the expiry so the retrieval below sees an expired artifact
        # without the test having to sleep.
        self.task_repo.upsert_artifact(
            task_id=task["id"],
            file_path=artifact["file_path"],
            file_name=artifact["file_name"],
            expires_at="2000-01-01T00:00:00Z",
        )

        response = self._request("GET", f"/api/files/download/archive/{task['id']}")

        self.assertEqual(response.status_code, 410)
        self.assertEqual(response.json()["error"]["code"], "archive_expired")
        self.assertIsNone(self.task_repo.get_artifact(task["id"]))
        self.assertFalse(Path(artifact["file_path"]).exists())

    def test_archive_prepare_rejects_single_file(self) -> None:
        (self.root / "report.txt").write_text("hello download", encoding="utf-8")

        response = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/report.txt"]})

        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "invalid_request")

    def test_direct_archive_download_route_rejected(self) -> None:
        (self.root / "docs").mkdir()
        (self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")

        response = self._request("GET", "/api/files/download?path=storage1/docs")

        self.assertEqual(response.status_code, 400)
        self.assertEqual(response.json()["error"]["code"], "invalid_request")


if __name__ == "__main__":
    unittest.main()