feat: roll out B2 for secure archive downloads

This commit is contained in:
kodi
2026-03-14 14:24:52 +01:00
parent 592b10acc2
commit d463b3977d
24 changed files with 754 additions and 195 deletions
+24 -2
View File
@@ -4,8 +4,9 @@ from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask
-from backend.app.api.schemas import DeleteRequest, DeleteResponse, FileInfoResponse, MkdirRequest, MkdirResponse, RenameRequest, RenameResponse, SaveRequest, SaveResponse, UploadResponse, ViewResponse
+from backend.app.api.schemas import ArchivePrepareRequest, DeleteRequest, DeleteResponse, FileInfoResponse, MkdirRequest, MkdirResponse, RenameRequest, RenameResponse, SaveRequest, SaveResponse, TaskCreateResponse, UploadResponse, ViewResponse
-from backend.app.dependencies import get_file_ops_service
+from backend.app.dependencies import get_archive_download_task_service, get_file_ops_service
+from backend.app.services.archive_download_task_service import ArchiveDownloadTaskService
from backend.app.services.file_ops_service import FileOpsService
router = APIRouter(prefix="/files")
@@ -78,6 +79,27 @@ async def download(
return response
@router.post("/download/archive-prepare", response_model=TaskCreateResponse, status_code=202)
async def archive_prepare(
request: ArchivePrepareRequest,
service: ArchiveDownloadTaskService = Depends(get_archive_download_task_service),
) -> TaskCreateResponse:
return service.create_archive_prepare_task(paths=request.paths)
@router.get("/download/archive/{task_id}")
async def archive_download(
task_id: str,
service: ArchiveDownloadTaskService = Depends(get_archive_download_task_service),
) -> StreamingResponse:
prepared = service.prepare_ready_archive_download(task_id=task_id)
return StreamingResponse(
prepared["content"],
headers=prepared["headers"],
media_type=prepared["content_type"],
)
@router.get("/video") @router.get("/video")
async def video( async def video(
path: str, path: str,
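The new flow splits a zip download into a prepare step and a fetch step. A minimal client sketch of that flow (Python with httpx, which the test suite also uses; the base URL, example paths, and polling interval are illustrative assumptions, not part of this change):

import time
import httpx

BASE = "http://localhost:8000"  # assumed dev server address

# 1. Ask the backend to prepare the archive; the endpoint answers 202 with a task id.
created = httpx.post(
    f"{BASE}/api/files/download/archive-prepare",
    json={"paths": ["storage1/docs", "storage1/readme.txt"]},
).json()
task_id = created["task_id"]

# 2. Poll the task endpoint until the task reports "ready" (or "failed").
while True:
    task = httpx.get(f"{BASE}/api/tasks/{task_id}").json()
    if task["status"] in {"ready", "failed"}:
        break
    time.sleep(0.25)

# 3. Stream the prepared zip; 409 means not ready yet, 410 means the artifact expired.
if task["status"] == "ready":
    with httpx.stream("GET", f"{BASE}/api/files/download/archive/{task_id}") as response:
        with open(task["destination"], "wb") as fh:
            for chunk in response.iter_bytes():
                fh.write(chunk)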
+4
View File
@@ -88,6 +88,10 @@ class SaveResponse(BaseModel):
modified: str
class ArchivePrepareRequest(BaseModel):
paths: list[str]
class FileInfoResponse(BaseModel):
name: str
path: str
+106 -4
View File
@@ -6,8 +6,8 @@ from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
-VALID_STATUSES = {"queued", "running", "completed", "failed"}
+VALID_STATUSES = {"queued", "running", "completed", "failed", "requested", "preparing", "ready"}
-VALID_OPERATIONS = {"copy", "move"}
+VALID_OPERATIONS = {"copy", "move", "download"}
TASK_MIGRATION_COLUMNS: dict[str, str] = {
"operation": "TEXT NOT NULL DEFAULT 'copy'",
"status": "TEXT NOT NULL DEFAULT 'queued'",
@@ -32,9 +32,18 @@ class TaskRepository:
self._db_path = db_path
self._ensure_schema()
-def create_task(self, operation: str, source: str, destination: str, task_id: str | None = None) -> dict:
+def create_task(
+self,
+operation: str,
+source: str,
+destination: str,
+task_id: str | None = None,
+status: str = "queued",
+) -> dict:
if operation not in VALID_OPERATIONS:
raise ValueError("invalid operation")
+if status not in VALID_STATUSES:
+raise ValueError("invalid status")
task_id = task_id or str(uuid.uuid4())
created_at = self._now_iso()
@@ -52,7 +61,7 @@ class TaskRepository:
(
task_id,
operation,
-"queued",
+status,
source,
destination,
None,
@@ -145,6 +154,24 @@ class TaskRepository:
("running", started_at, done_bytes, total_bytes, done_items, total_items, current_item, task_id), ("running", started_at, done_bytes, total_bytes, done_items, total_items, current_item, task_id),
) )
def mark_preparing(
self,
task_id: str,
done_items: int | None = None,
total_items: int | None = None,
current_item: str | None = None,
) -> None:
started_at = self._now_iso()
with self._connection() as conn:
conn.execute(
"""
UPDATE tasks
SET status = ?, started_at = COALESCE(started_at, ?), done_items = ?, total_items = ?, current_item = ?
WHERE id = ?
""",
("preparing", started_at, done_items, total_items, current_item, task_id),
)
def update_progress(
self,
task_id: str,
@@ -183,6 +210,23 @@ class TaskRepository:
("completed", finished_at, done_bytes, total_bytes, done_items, total_items, task_id), ("completed", finished_at, done_bytes, total_bytes, done_items, total_items, task_id),
) )
def mark_ready(
self,
task_id: str,
done_items: int | None = None,
total_items: int | None = None,
) -> None:
finished_at = self._now_iso()
with self._connection() as conn:
conn.execute(
"""
UPDATE tasks
SET status = ?, finished_at = ?, done_items = ?, total_items = ?, current_item = NULL
WHERE id = ?
""",
("ready", finished_at, done_items, total_items, task_id),
)
def mark_failed(
self,
task_id: str,
@@ -244,14 +288,62 @@ class TaskRepository:
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS task_artifacts (
task_id TEXT PRIMARY KEY,
file_path TEXT NOT NULL,
file_name TEXT NOT NULL,
expires_at TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_tasks_created_at_desc
ON tasks(created_at DESC)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_task_artifacts_expires_at
ON task_artifacts(expires_at ASC)
"""
)
self._migrate_tasks_columns(conn)
def upsert_artifact(self, *, task_id: str, file_path: str, file_name: str, expires_at: str) -> dict:
created_at = self._now_iso()
with self._connection() as conn:
conn.execute(
"""
INSERT INTO task_artifacts (task_id, file_path, file_name, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(task_id) DO UPDATE SET
file_path = excluded.file_path,
file_name = excluded.file_name,
expires_at = excluded.expires_at
""",
(task_id, file_path, file_name, expires_at, created_at),
)
row = conn.execute("SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)).fetchone()
return self._artifact_to_dict(row)
def get_artifact(self, task_id: str) -> dict | None:
with self._connection() as conn:
row = conn.execute("SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)).fetchone()
return self._artifact_to_dict(row) if row else None
def list_artifacts(self) -> list[dict]:
with self._connection() as conn:
rows = conn.execute("SELECT * FROM task_artifacts ORDER BY created_at ASC").fetchall()
return [self._artifact_to_dict(row) for row in rows]
def delete_artifact(self, task_id: str) -> None:
with self._connection() as conn:
conn.execute("DELETE FROM task_artifacts WHERE task_id = ?", (task_id,))
def _migrate_tasks_columns(self, conn: sqlite3.Connection) -> None:
rows = conn.execute("PRAGMA table_info(tasks)").fetchall()
existing_columns = {row["name"] for row in rows}
@@ -298,6 +390,16 @@ class TaskRepository:
"finished_at": row["finished_at"], "finished_at": row["finished_at"],
} }
@staticmethod
def _artifact_to_dict(row: sqlite3.Row) -> dict:
return {
"task_id": row["task_id"],
"file_path": row["file_path"],
"file_name": row["file_name"],
"expires_at": row["expires_at"],
"created_at": row["created_at"],
}
@staticmethod
def _now_iso() -> str:
return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")
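The artifact bookkeeping added above is driven entirely through these four repository methods. A small sketch of the intended call order (illustrative only; the database path, file path, and timestamp are made-up values):

from backend.app.db.task_repository import TaskRepository

repo = TaskRepository("/tmp/tasks.db")  # hypothetical database location

# A download task starts in the "requested" state before any work happens.
task = repo.create_task(operation="download", source="storage1/docs", destination="docs.zip", status="requested")

# Once the worker has written the zip, it registers the artifact together with its TTL.
repo.upsert_artifact(
    task_id=task["id"],
    file_path=f"/tmp/archive_tmp/{task['id']}.zip",
    file_name="docs.zip",
    expires_at="2026-03-14T15:00:00Z",
)

# The download endpoint later looks the artifact up; the sweeper eventually deletes it.
artifact = repo.get_artifact(task["id"])
assert artifact["file_name"] == "docs.zip"
repo.delete_artifact(task["id"])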
+23
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from backend.app.config import Settings, get_settings
from backend.app.db.bookmark_repository import BookmarkRepository
@@ -12,6 +13,7 @@ from backend.app.security.path_guard import PathGuard
from backend.app.services.bookmark_service import BookmarkService
from backend.app.services.browse_service import BrowseService
from backend.app.services.copy_task_service import CopyTaskService
from backend.app.services.archive_download_task_service import ArchiveDownloadTaskService
from backend.app.services.file_ops_service import FileOpsService
from backend.app.services.history_service import HistoryService
from backend.app.services.move_task_service import MoveTaskService
@@ -64,6 +66,12 @@ def get_task_runner() -> TaskRunner:
)
@lru_cache(maxsize=1)
def get_archive_artifact_root() -> str:
settings: Settings = get_settings()
return str(Path(settings.task_db_path).resolve().parent / "archive_tmp")
async def get_browse_service() -> BrowseService:
return BrowseService(path_guard=get_path_guard(), filesystem=get_filesystem_adapter())
@@ -76,6 +84,21 @@ async def get_file_ops_service() -> FileOpsService:
)
async def get_archive_download_task_service() -> ArchiveDownloadTaskService:
return ArchiveDownloadTaskService(
path_guard=get_path_guard(),
repository=get_task_repository(),
runner=get_task_runner(),
history_repository=get_history_repository(),
file_ops_service=FileOpsService(
path_guard=get_path_guard(),
filesystem=get_filesystem_adapter(),
history_repository=get_history_repository(),
),
artifact_root=Path(get_archive_artifact_root()),
)
async def get_task_service() -> TaskService:
return TaskService(repository=get_task_repository())
@@ -0,0 +1,266 @@
from __future__ import annotations
import os
import uuid
import zipfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from backend.app.api.errors import AppError
from backend.app.api.schemas import TaskCreateResponse
from backend.app.db.history_repository import HistoryRepository
from backend.app.db.task_repository import TaskRepository
from backend.app.security.path_guard import PathGuard
from backend.app.services.file_ops_service import FileOpsService
from backend.app.tasks_runner import TaskRunner
ARCHIVE_DOWNLOAD_TTL_SECONDS = 30 * 60
class ArchiveDownloadTaskService:
def __init__(
self,
path_guard: PathGuard,
repository: TaskRepository,
runner: TaskRunner,
history_repository: HistoryRepository | None,
file_ops_service: FileOpsService,
artifact_root: Path,
artifact_ttl_seconds: int = ARCHIVE_DOWNLOAD_TTL_SECONDS,
):
self._path_guard = path_guard
self._repository = repository
self._runner = runner
self._history_repository = history_repository
self._file_ops_service = file_ops_service
self._artifact_root = artifact_root
self._artifact_ttl_seconds = artifact_ttl_seconds
self._artifact_root.mkdir(parents=True, exist_ok=True)
self.sweep_artifacts()
def create_archive_prepare_task(self, paths: list[str]) -> TaskCreateResponse:
if not paths:
raise AppError(
code="invalid_request",
message="At least one path is required",
status_code=400,
)
self.sweep_artifacts()
resolved_targets = [self._path_guard.resolve_existing_path(path) for path in paths]
mode = self._file_ops_service._download_mode_from_resolved_targets(resolved_targets)
if mode == "single_file":
raise AppError(
code="invalid_request",
message="Single file downloads must use direct download",
status_code=400,
)
summary = self._file_ops_service._summarize_download_targets([target.relative for target in resolved_targets])
archive_name = self._file_ops_service._download_name_for_targets(resolved_targets)
task_id = str(uuid.uuid4())
task = self._repository.create_task(
operation="download",
source=summary,
destination=archive_name,
task_id=task_id,
status="requested",
)
self._record_history(
entry_id=task_id,
operation="download",
status="requested",
source=mode,
destination=archive_name,
path=summary,
)
target_paths = [target.relative for target in resolved_targets]
self._runner.enqueue_archive_prepare(
lambda: self._run_archive_prepare_task(
task_id=task_id,
target_paths=target_paths,
archive_name=archive_name,
history_mode=mode,
history_path=summary,
)
)
return TaskCreateResponse(task_id=task["id"], status=task["status"])
def prepare_ready_archive_download(self, task_id: str) -> dict:
self.sweep_artifacts()
task = self._repository.get_task(task_id)
if not task:
raise AppError(
code="task_not_found",
message="Task was not found",
status_code=404,
details={"task_id": task_id},
)
if task["operation"] != "download":
raise AppError(
code="invalid_request",
message="Task is not an archive download",
status_code=400,
details={"task_id": task_id},
)
if task["status"] != "ready":
raise AppError(
code="download_not_ready",
message="Archive download is not ready",
status_code=409,
details={"task_id": task_id, "status": task["status"]},
)
artifact = self._repository.get_artifact(task_id)
if not artifact:
raise AppError(
code="archive_not_found",
message="Prepared archive was not found",
status_code=404,
details={"task_id": task_id},
)
if self._is_expired(artifact["expires_at"]):
self._delete_artifact_record_and_file(task_id, artifact["file_path"])
raise AppError(
code="archive_expired",
message="Prepared archive expired",
status_code=410,
details={"task_id": task_id},
)
artifact_path = Path(artifact["file_path"])
if not artifact_path.exists():
self._repository.delete_artifact(task_id)
raise AppError(
code="archive_not_found",
message="Prepared archive was not found",
status_code=404,
details={"task_id": task_id},
)
return {
"content": self._file_ops_service._filesystem.stream_file(artifact_path),
"headers": {
"Content-Disposition": f'attachment; filename="{artifact["file_name"]}"',
"Content-Length": str(int(artifact_path.stat().st_size)),
},
"content_type": "application/zip",
}
def sweep_artifacts(self) -> None:
self._artifact_root.mkdir(parents=True, exist_ok=True)
referenced_paths: set[Path] = set()
for artifact in self._repository.list_artifacts():
artifact_path = Path(artifact["file_path"])
referenced_paths.add(artifact_path)
if self._is_expired(artifact["expires_at"]) or not artifact_path.exists():
self._delete_artifact_record_and_file(artifact["task_id"], artifact["file_path"])
for candidate in self._artifact_root.iterdir():
if candidate.is_file() and candidate not in referenced_paths:
try:
candidate.unlink()
except FileNotFoundError:
pass
def _run_archive_prepare_task(
self,
*,
task_id: str,
target_paths: list[str],
archive_name: str,
history_mode: str,
history_path: str,
) -> None:
partial_path = self._artifact_root / f"{task_id}.partial.zip"
final_path = self._artifact_root / f"{task_id}.zip"
total_items = len(target_paths)
try:
self._repository.mark_preparing(
task_id=task_id,
done_items=0,
total_items=total_items,
current_item=target_paths[0] if target_paths else None,
)
resolved_targets = [self._path_guard.resolve_existing_path(path) for path in target_paths]
self._file_ops_service._validate_zip_download_archive_names(resolved_targets)
self._file_ops_service._run_zip_download_preflight(resolved_targets)
with zipfile.ZipFile(partial_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
for resolved_target in resolved_targets:
self._file_ops_service._write_download_target_to_zip(archive, resolved_target)
os.replace(partial_path, final_path)
self._repository.upsert_artifact(
task_id=task_id,
file_path=str(final_path),
file_name=archive_name,
expires_at=self._expires_at_iso(),
)
self._repository.mark_ready(
task_id=task_id,
done_items=total_items,
total_items=total_items,
)
self._update_history_ready(task_id)
except AppError as exc:
self._delete_artifact_record_and_file(task_id, str(partial_path))
self._delete_artifact_record_and_file(task_id, str(final_path))
self._repository.mark_failed(
task_id=task_id,
error_code=exc.code,
error_message=exc.message,
failed_item=history_path,
done_bytes=None,
total_bytes=None,
done_items=0,
total_items=total_items,
)
self._update_history_failed(task_id, exc.code, exc.message)
except OSError as exc:
self._delete_artifact_record_and_file(task_id, str(partial_path))
self._delete_artifact_record_and_file(task_id, str(final_path))
self._repository.mark_failed(
task_id=task_id,
error_code="io_error",
error_message=str(exc),
failed_item=history_path,
done_bytes=None,
total_bytes=None,
done_items=0,
total_items=total_items,
)
self._update_history_failed(task_id, "io_error", str(exc))
def _delete_artifact_record_and_file(self, task_id: str, file_path: str) -> None:
self._repository.delete_artifact(task_id)
path = Path(file_path)
try:
path.unlink()
except FileNotFoundError:
pass
def _update_history_ready(self, task_id: str) -> None:
if self._history_repository:
self._history_repository.update_entry(entry_id=task_id, status="ready")
def _update_history_failed(self, task_id: str, error_code: str, error_message: str) -> None:
if self._history_repository:
self._history_repository.update_entry(
entry_id=task_id,
status="failed",
error_code=error_code,
error_message=error_message,
)
def _record_history(self, **kwargs) -> None:
if self._history_repository:
self._history_repository.create_entry(**kwargs)
def _expires_at_iso(self) -> str:
return (datetime.now(timezone.utc) + timedelta(seconds=self._artifact_ttl_seconds)).replace(microsecond=0).isoformat().replace("+00:00", "Z")
@staticmethod
def _is_expired(expires_at: str) -> bool:
return datetime.now(timezone.utc) >= datetime.fromisoformat(expires_at.replace("Z", "+00:00"))
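The expiry bookkeeping above depends on the Z-suffixed ISO-8601 strings stored with each artifact; datetime.fromisoformat only accepts a literal "Z" from Python 3.11 onward, which is why the service normalises the suffix before parsing. A standalone sketch of the same TTL arithmetic (the timestamps are illustrative):

from datetime import datetime, timedelta, timezone

TTL_SECONDS = 30 * 60  # same default as ARCHIVE_DOWNLOAD_TTL_SECONDS

def expires_at_iso(now: datetime) -> str:
    # Store UTC with a trailing "Z", truncated to whole seconds.
    return (now + timedelta(seconds=TTL_SECONDS)).replace(microsecond=0).isoformat().replace("+00:00", "Z")

def is_expired(expires_at: str, now: datetime) -> bool:
    # Normalise the "Z" suffix so fromisoformat also works on Python < 3.11.
    return now >= datetime.fromisoformat(expires_at.replace("Z", "+00:00"))

now = datetime(2026, 3, 14, 14, 0, tzinfo=timezone.utc)
stamp = expires_at_iso(now)                              # "2026-03-14T14:30:00Z"
print(is_expired(stamp, now))                            # False
print(is_expired(stamp, now + timedelta(minutes=31)))    # True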
+21 -13
View File
@@ -411,6 +411,14 @@ class FileOpsService:
history_mode = self._download_mode_from_resolved_targets(resolved_targets)
history_path = self._summarize_download_targets([target.relative for target in resolved_targets])
history_download_name = self._download_name_for_targets(resolved_targets)
if history_mode != "single_file":
raise AppError(
code="invalid_request",
message="Archive downloads must be prepared first",
status_code=400,
)
history_entry_id = self._record_download_status(
status="requested",
mode=history_mode, mode=history_mode,
@@ -418,10 +426,7 @@ class FileOpsService:
download_name=history_download_name, download_name=history_download_name,
) )
-if len(resolved_targets) == 1 and resolved_targets[0].absolute.is_file():
-prepared = self._prepare_single_file_download(resolved_targets[0])
-else:
-prepared = self._prepare_zip_download(resolved_targets, history_download_name)
+prepared = self._prepare_single_file_download(resolved_targets[0])
self._record_download_status(
status="ready",
@@ -757,16 +762,7 @@ class FileOpsService:
}
def _prepare_zip_download(self, resolved_targets: list, download_name: str) -> dict:
-archive_names: set[str] = set()
-for resolved_target in resolved_targets:
-archive_name = resolved_target.absolute.name
-if archive_name in archive_names:
-raise AppError(
-code="invalid_request",
-message="Selected items must have distinct top-level names",
-status_code=400,
-)
-archive_names.add(archive_name)
+self._validate_zip_download_archive_names(resolved_targets)
self._run_zip_download_preflight(resolved_targets)
buffer = BytesIO()
@@ -786,6 +782,18 @@ class FileOpsService:
"content_type": "application/zip", "content_type": "application/zip",
} }
def _validate_zip_download_archive_names(self, resolved_targets: list) -> None:
archive_names: set[str] = set()
for resolved_target in resolved_targets:
archive_name = resolved_target.absolute.name
if archive_name in archive_names:
raise AppError(
code="invalid_request",
message="Selected items must have distinct top-level names",
status_code=400,
)
archive_names.add(archive_name)
def _download_name_for_targets(self, resolved_targets: list) -> str:
if len(resolved_targets) == 1 and resolved_targets[0].absolute.is_file():
return resolved_targets[0].absolute.name
+7
View File
@@ -69,6 +69,13 @@ class TaskRunner:
)
thread.start()
def enqueue_archive_prepare(self, worker) -> None:
thread = threading.Thread(
target=worker,
daemon=True,
)
thread.start()
def _run_copy_file(self, task_id: str, source: str, destination: str, total_bytes: int) -> None:
self._repository.mark_running(
task_id=task_id,
@@ -3,6 +3,8 @@ from __future__ import annotations
import asyncio
import sys
import tempfile
import threading
import time
import unittest
import zipfile
from io import BytesIO
@@ -12,11 +14,32 @@ import httpx
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
-from backend.app.dependencies import get_file_ops_service
+from backend.app.dependencies import get_archive_download_task_service, get_file_ops_service, get_task_service
+from backend.app.db.history_repository import HistoryRepository
+from backend.app.db.task_repository import TaskRepository
from backend.app.fs.filesystem_adapter import FilesystemAdapter
from backend.app.main import app
from backend.app.security.path_guard import PathGuard
+from backend.app.services.archive_download_task_service import ArchiveDownloadTaskService
from backend.app.services.file_ops_service import FileOpsService, ZipDownloadPreflightLimits
+from backend.app.services.task_service import TaskService
+from backend.app.tasks_runner import TaskRunner
class BlockingArchiveFileOpsService(FileOpsService):
def __init__(self, *args, gate: threading.Event, **kwargs):
super().__init__(*args, **kwargs)
self._gate = gate
def _run_zip_download_preflight(self, resolved_targets: list) -> None:
super()._run_zip_download_preflight(resolved_targets)
self._gate.wait(timeout=2.0)
class FailingArchiveFileOpsService(FileOpsService):
def _write_download_target_to_zip(self, archive: zipfile.ZipFile, resolved_target) -> None:
archive.writestr("partial.txt", b"partial")
raise OSError("forced archive failure")
class DownloadApiGoldenTest(unittest.TestCase):
@@ -24,56 +47,122 @@ class DownloadApiGoldenTest(unittest.TestCase):
self.temp_dir = tempfile.TemporaryDirectory()
self.root = Path(self.temp_dir.name) / "root"
self.root.mkdir(parents=True, exist_ok=True)
+self.db_path = str(Path(self.temp_dir.name) / "tasks.db")
+self.artifact_root = Path(self.temp_dir.name) / "archive_tmp"
self.path_guard = PathGuard({"storage1": str(self.root), "storage2": str(self.root)})
self.filesystem = FilesystemAdapter()
-self._override_service()
+self.task_repo = TaskRepository(self.db_path)
+self.history_repo = HistoryRepository(self.db_path)
+self._override_services()
def tearDown(self) -> None:
app.dependency_overrides.clear()
self.temp_dir.cleanup()
-def _get(self, url: str) -> httpx.Response:
+def _override_services(
self,
*,
file_ops_service: FileOpsService | None = None,
artifact_ttl_seconds: int = 1800,
) -> None:
file_ops_service = file_ops_service or FileOpsService(
path_guard=self.path_guard,
filesystem=self.filesystem,
history_repository=self.history_repo,
zip_download_preflight_limits=ZipDownloadPreflightLimits(),
)
runner = TaskRunner(repository=self.task_repo, filesystem=self.filesystem, history_repository=self.history_repo)
archive_service = ArchiveDownloadTaskService(
path_guard=self.path_guard,
repository=self.task_repo,
runner=runner,
history_repository=self.history_repo,
file_ops_service=file_ops_service,
artifact_root=self.artifact_root,
artifact_ttl_seconds=artifact_ttl_seconds,
)
task_service = TaskService(repository=self.task_repo)
async def _override_file_ops_service() -> FileOpsService:
return file_ops_service
async def _override_archive_service() -> ArchiveDownloadTaskService:
return archive_service
async def _override_task_service() -> TaskService:
return task_service
app.dependency_overrides[get_file_ops_service] = _override_file_ops_service
app.dependency_overrides[get_archive_download_task_service] = _override_archive_service
app.dependency_overrides[get_task_service] = _override_task_service
def _request(self, method: str, url: str, payload: dict | None = None) -> httpx.Response:
async def _run() -> httpx.Response:
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
+if method == "GET":
return await client.get(url)
+return await client.post(url, json=payload)
return asyncio.run(_run())
-def _override_service(
-self,
-*,
-limits: ZipDownloadPreflightLimits | None = None,
-monotonic=None,
-) -> None:
-service = FileOpsService(
-path_guard=self.path_guard,
-filesystem=self.filesystem,
-zip_download_preflight_limits=limits or ZipDownloadPreflightLimits(),
-monotonic=monotonic,
-)
-async def _override_file_ops_service() -> FileOpsService:
-return service
-app.dependency_overrides[get_file_ops_service] = _override_file_ops_service
+def _wait_for_task_status(self, task_id: str, statuses: set[str], timeout_s: float = 2.0) -> dict:
+deadline = time.time() + timeout_s
+while time.time() < deadline:
+response = self._request("GET", f"/api/tasks/{task_id}")
+body = response.json()
+if body["status"] in statuses:
+return body
+time.sleep(0.02)
+self.fail("task did not reach expected status in time")
def test_download_success_for_allowed_file(self) -> None:
src = self.root / "report.txt"
src.write_text("hello download", encoding="utf-8")
-response = self._get("/api/files/download?path=storage1/report.txt")
+response = self._request("GET", "/api/files/download?path=storage1/report.txt")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"hello download")
self.assertIn('attachment; filename="report.txt"', response.headers.get("content-disposition", ""))
self.assertEqual(response.headers.get("content-type"), "text/plain; charset=utf-8")
-def test_download_single_directory_as_zip(self) -> None:
+def test_archive_prepare_single_directory_ends_ready(self) -> None:
(self.root / "docs").mkdir()
(self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
-response = self._get("/api/files/download?path=storage1/docs")
+created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
self.assertEqual(created.status_code, 202)
task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
self.assertEqual(task["operation"], "download")
self.assertEqual(task["status"], "ready")
self.assertEqual(task["destination"], "docs.zip")
def test_archive_prepare_multi_mixed_selection_ends_ready(self) -> None:
(self.root / "readme.txt").write_text("R", encoding="utf-8")
(self.root / "photos").mkdir()
(self.root / "photos" / "img.txt").write_text("P", encoding="utf-8")
created = self._request(
"POST",
"/api/files/download/archive-prepare",
{"paths": ["storage1/readme.txt", "storage1/photos"]},
)
self.assertEqual(created.status_code, 202)
task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
self.assertEqual(task["status"], "ready")
self.assertEqual(task["source"], "storage1/readme.txt, storage1/photos")
self.assertRegex(task["destination"], r'^kodidownload-\d{8}-\d{6}\.zip$')
def test_archive_retrieval_from_ready_task_works(self) -> None:
(self.root / "docs").mkdir()
(self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
response = self._request("GET", f"/api/files/download/archive/{task['id']}")
self.assertEqual(response.status_code, 200)
self.assertIn('attachment; filename="docs.zip"', response.headers.get("content-disposition", ""))
@@ -82,167 +171,93 @@ class DownloadApiGoldenTest(unittest.TestCase):
self.assertIn("docs/a.txt", archive.namelist()) self.assertIn("docs/a.txt", archive.namelist())
self.assertEqual(archive.read("docs/a.txt"), b"a") self.assertEqual(archive.read("docs/a.txt"), b"a")
-def test_download_multi_file_selection_as_zip(self) -> None:
-(self.root / "a.txt").write_text("A", encoding="utf-8")
-(self.root / "b.txt").write_text("B", encoding="utf-8")
-response = self._get("/api/files/download?path=storage1/a.txt&path=storage1/b.txt")
-self.assertEqual(response.status_code, 200)
-self.assertRegex(
-response.headers.get("content-disposition", ""),
-r'attachment; filename="kodidownload-\d{8}-\d{6}\.zip"',
-)
-with zipfile.ZipFile(BytesIO(response.content)) as archive:
-self.assertIn("a.txt", archive.namelist())
-self.assertIn("b.txt", archive.namelist())
-self.assertEqual(archive.read("a.txt"), b"A")
-self.assertEqual(archive.read("b.txt"), b"B")
-def test_download_multi_directory_selection_as_zip(self) -> None:
-(self.root / "dir1" / "sub").mkdir(parents=True)
-(self.root / "dir2").mkdir()
-(self.root / "dir1" / "sub" / "a.txt").write_text("A", encoding="utf-8")
-(self.root / "dir2" / "b.txt").write_text("B", encoding="utf-8")
-response = self._get("/api/files/download?path=storage1/dir1&path=storage1/dir2")
-self.assertEqual(response.status_code, 200)
-self.assertRegex(
-response.headers.get("content-disposition", ""),
-r'attachment; filename="kodidownload-\d{8}-\d{6}\.zip"',
-)
-with zipfile.ZipFile(BytesIO(response.content)) as archive:
-self.assertIn("dir1/", archive.namelist())
-self.assertIn("dir1/sub/", archive.namelist())
-self.assertIn("dir1/sub/a.txt", archive.namelist())
-self.assertIn("dir2/b.txt", archive.namelist())
-def test_download_mixed_file_and_directory_selection_as_zip(self) -> None:
-(self.root / "readme.txt").write_text("R", encoding="utf-8")
-(self.root / "photos" / "nested").mkdir(parents=True)
-(self.root / "photos" / "nested" / "img.txt").write_text("P", encoding="utf-8")
-response = self._get("/api/files/download?path=storage1/readme.txt&path=storage1/photos")
-self.assertEqual(response.status_code, 200)
-self.assertRegex(
-response.headers.get("content-disposition", ""),
-r'attachment; filename="kodidownload-\d{8}-\d{6}\.zip"',
-)
-with zipfile.ZipFile(BytesIO(response.content)) as archive:
-self.assertIn("readme.txt", archive.namelist())
-self.assertIn("photos/", archive.namelist())
-self.assertIn("photos/nested/img.txt", archive.namelist())
-def test_download_zip_rejected_when_max_items_exceeded(self) -> None:
-(self.root / "docs").mkdir()
-(self.root / "docs" / "a.txt").write_text("A", encoding="utf-8")
-(self.root / "docs" / "b.txt").write_text("B", encoding="utf-8")
-(self.root / "docs" / "c.txt").write_text("C", encoding="utf-8")
-self._override_service(
-limits=ZipDownloadPreflightLimits(
-max_items=3,
-max_total_input_bytes=1024,
-max_individual_file_bytes=1024,
-scan_timeout_seconds=10.0,
-)
-)
-response = self._get("/api/files/download?path=storage1/docs")
-self.assertEqual(response.status_code, 409)
-self.assertEqual(response.json()["error"]["code"], "download_preflight_failed")
-self.assertEqual(response.json()["error"]["message"], "Zip download preflight failed")
-self.assertEqual(response.json()["error"]["details"]["reason"], "max_items_exceeded")
-def test_download_zip_rejected_when_max_total_input_bytes_exceeded(self) -> None:
-(self.root / "a.txt").write_text("AAAA", encoding="utf-8")
-(self.root / "b.txt").write_text("BBBB", encoding="utf-8")
-self._override_service(
-limits=ZipDownloadPreflightLimits(
-max_items=10,
-max_total_input_bytes=7,
-max_individual_file_bytes=1024,
-scan_timeout_seconds=10.0,
-)
-)
-response = self._get("/api/files/download?path=storage1/a.txt&path=storage1/b.txt")
-self.assertEqual(response.status_code, 409)
-self.assertEqual(response.json()["error"]["code"], "download_preflight_failed")
-self.assertEqual(response.json()["error"]["details"]["reason"], "max_total_input_bytes_exceeded")
-def test_download_zip_rejected_when_individual_file_too_large(self) -> None:
-(self.root / "docs").mkdir()
-(self.root / "docs" / "large.bin").write_bytes(b"123456")
-self._override_service(
-limits=ZipDownloadPreflightLimits(
-max_items=10,
-max_total_input_bytes=1024,
-max_individual_file_bytes=5,
-scan_timeout_seconds=10.0,
-)
-)
-response = self._get("/api/files/download?path=storage1/docs")
-self.assertEqual(response.status_code, 409)
-self.assertEqual(response.json()["error"]["code"], "download_preflight_failed")
-self.assertEqual(response.json()["error"]["details"]["reason"], "max_individual_file_size_exceeded")
-self.assertEqual(response.json()["error"]["details"]["path"], "storage1/docs/large.bin")
-def test_download_directory_with_symlink_rejected(self) -> None:
-target = self.root / "real.txt"
-target.write_text("x", encoding="utf-8")
-(self.root / "docs").mkdir()
-(self.root / "docs" / "link.txt").symlink_to(target)
-response = self._get("/api/files/download?path=storage1/docs")
-self.assertEqual(response.status_code, 409)
-self.assertEqual(response.json()["error"]["code"], "download_preflight_failed")
-self.assertEqual(response.json()["error"]["details"]["reason"], "symlink_detected")
-self.assertEqual(response.json()["error"]["details"]["path"], "storage1/docs/link.txt")
-def test_download_zip_preflight_timeout_rejected_cleanly(self) -> None:
-(self.root / "a.txt").write_text("A", encoding="utf-8")
-(self.root / "b.txt").write_text("B", encoding="utf-8")
-ticks = iter([0.0, 11.0, 11.0, 11.0])
-self._override_service(
-limits=ZipDownloadPreflightLimits(
-max_items=10,
-max_total_input_bytes=1024,
-max_individual_file_bytes=1024,
-scan_timeout_seconds=10.0,
-),
-monotonic=lambda: next(ticks),
-)
-response = self._get("/api/files/download?path=storage1/a.txt&path=storage1/b.txt")
-self.assertEqual(response.status_code, 409)
-self.assertEqual(response.json()["error"]["code"], "download_preflight_failed")
-self.assertEqual(response.json()["error"]["details"]["reason"], "preflight_timeout")
-def test_download_path_not_found(self) -> None:
-response = self._get("/api/files/download?path=storage1/missing.txt")
-self.assertEqual(response.status_code, 404)
-self.assertEqual(response.json()["error"]["code"], "path_not_found")
-def test_download_invalid_root_alias(self) -> None:
-response = self._get("/api/files/download?path=unknown/file.txt")
-self.assertEqual(response.status_code, 403)
-self.assertEqual(response.json()["error"]["code"], "invalid_root_alias")
-def test_download_traversal_blocked(self) -> None:
-response = self._get("/api/files/download?path=storage1/../etc/passwd")
-self.assertEqual(response.status_code, 403)
-self.assertEqual(response.json()["error"]["code"], "path_traversal_detected")
+def test_archive_retrieval_before_ready_rejected(self) -> None:
+gate = threading.Event()
+file_ops_service = BlockingArchiveFileOpsService(
+path_guard=self.path_guard,
+filesystem=self.filesystem,
+history_repository=self.history_repo,
+zip_download_preflight_limits=ZipDownloadPreflightLimits(),
+gate=gate,
+)
+self._override_services(file_ops_service=file_ops_service)
+(self.root / "docs").mkdir()
+(self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
+created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
+task = self._wait_for_task_status(created.json()["task_id"], {"requested", "preparing"})
+response = self._request("GET", f"/api/files/download/archive/{task['id']}")
+gate.set()
+self.assertEqual(response.status_code, 409)
+self.assertEqual(response.json()["error"]["code"], "download_not_ready")
+def test_archive_preflight_failure_sets_failed_and_error_code(self) -> None:
+target = self.root / "real.txt"
+target.write_text("x", encoding="utf-8")
+(self.root / "docs").mkdir()
+(self.root / "docs" / "link.txt").symlink_to(target)
+created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
+task = self._wait_for_task_status(created.json()["task_id"], {"failed"})
+self.assertEqual(task["status"], "failed")
+self.assertEqual(task["error_code"], "download_preflight_failed")
+def test_archive_failure_removes_partial_artifact(self) -> None:
+file_ops_service = FailingArchiveFileOpsService(
+path_guard=self.path_guard,
+filesystem=self.filesystem,
+history_repository=self.history_repo,
+zip_download_preflight_limits=ZipDownloadPreflightLimits(),
+)
+self._override_services(file_ops_service=file_ops_service)
+(self.root / "docs").mkdir()
+(self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
+created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
+task = self._wait_for_task_status(created.json()["task_id"], {"failed"})
+self.assertEqual(task["error_code"], "io_error")
+self.assertEqual(list(self.artifact_root.glob("*")), [])
+def test_expired_artifact_rejected_and_removed(self) -> None:
+(self.root / "docs").mkdir()
+(self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
+self._override_services(artifact_ttl_seconds=1)
+created = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/docs"]})
+task = self._wait_for_task_status(created.json()["task_id"], {"ready"})
+artifact = self.task_repo.get_artifact(task["id"])
+self.task_repo.upsert_artifact(
+task_id=task["id"],
+file_path=artifact["file_path"],
+file_name=artifact["file_name"],
+expires_at="2000-01-01T00:00:00Z",
+)
+response = self._request("GET", f"/api/files/download/archive/{task['id']}")
+self.assertEqual(response.status_code, 410)
+self.assertEqual(response.json()["error"]["code"], "archive_expired")
+self.assertIsNone(self.task_repo.get_artifact(task["id"]))
+self.assertFalse(Path(artifact["file_path"]).exists())
+def test_archive_prepare_rejects_single_file(self) -> None:
+(self.root / "report.txt").write_text("hello download", encoding="utf-8")
+response = self._request("POST", "/api/files/download/archive-prepare", {"paths": ["storage1/report.txt"]})
+self.assertEqual(response.status_code, 400)
+self.assertEqual(response.json()["error"]["code"], "invalid_request")
+def test_direct_archive_download_route_rejected(self) -> None:
+(self.root / "docs").mkdir()
+(self.root / "docs" / "a.txt").write_text("a", encoding="utf-8")
+response = self._request("GET", "/api/files/download?path=storage1/docs")
+self.assertEqual(response.status_code, 400)
+self.assertEqual(response.json()["error"]["code"], "invalid_request")
if __name__ == "__main__": if __name__ == "__main__":
@@ -91,7 +91,7 @@ class HistoryApiGoldenTest(unittest.TestCase):
while time.time() < deadline:
response = self._request('GET', f'/api/tasks/{task_id}')
body = response.json()
-if body['status'] in {'completed', 'failed'}:
+if body['status'] in {'completed', 'failed', 'ready'}:
return body
time.sleep(0.02)
self.fail('task did not reach terminal state in time')
@@ -198,9 +198,10 @@ class HistoryApiGoldenTest(unittest.TestCase):
(self.root1 / 'docs').mkdir()
(self.root1 / 'docs' / 'a.txt').write_text('A', encoding='utf-8')
-response = self._request('GET', '/api/files/download?path=storage1/docs')
+response = self._request('POST', '/api/files/download/archive-prepare', {'paths': ['storage1/docs']})
-self.assertEqual(response.status_code, 200)
+self.assertEqual(response.status_code, 202)
+self._wait_task(response.json()['task_id'])
history = self._request('GET', '/api/history').json()['items']
self.assertEqual(history[0]['operation'], 'download')
self.assertEqual(history[0]['status'], 'ready')
@@ -213,9 +214,10 @@ class HistoryApiGoldenTest(unittest.TestCase):
(self.root1 / 'photos').mkdir()
(self.root1 / 'photos' / 'img.txt').write_text('P', encoding='utf-8')
-response = self._request('GET', '/api/files/download?path=storage1/readme.txt&path=storage1/photos')
+response = self._request('POST', '/api/files/download/archive-prepare', {'paths': ['storage1/readme.txt', 'storage1/photos']})
-self.assertEqual(response.status_code, 200)
+self.assertEqual(response.status_code, 202)
+self._wait_task(response.json()['task_id'])
history = self._request('GET', '/api/history').json()['items']
self.assertEqual(history[0]['operation'], 'download')
self.assertEqual(history[0]['status'], 'ready')
@@ -229,12 +231,13 @@ class HistoryApiGoldenTest(unittest.TestCase):
(self.root1 / 'docs').mkdir()
(self.root1 / 'docs' / 'link.txt').symlink_to(target)
-response = self._request('GET', '/api/files/download?path=storage1/docs')
+response = self._request('POST', '/api/files/download/archive-prepare', {'paths': ['storage1/docs']})
-self.assertEqual(response.status_code, 409)
+self.assertEqual(response.status_code, 202)
+self._wait_task(response.json()['task_id'])
history = self._request('GET', '/api/history').json()['items']
self.assertEqual(history[0]['operation'], 'download')
-self.assertEqual(history[0]['status'], 'preflight_failed')
+self.assertEqual(history[0]['status'], 'failed')
self.assertEqual(history[0]['source'], 'single_directory_zip')
self.assertEqual(history[0]['path'], 'storage1/docs')
self.assertEqual(history[0]['destination'], 'docs.zip')
@@ -241,6 +241,28 @@ class TasksApiGoldenTest(unittest.TestCase):
self.assertEqual(body["error_code"], "io_error") self.assertEqual(body["error_code"], "io_error")
self.assertEqual(body["error_message"], "write failed") self.assertEqual(body["error_message"], "write failed")
def test_get_task_detail_ready_archive_download(self) -> None:
self._insert_task(
task_id="task-download-ready",
operation="download",
status="ready",
source="storage1/docs",
destination="docs.zip",
created_at="2026-03-10T10:00:00Z",
started_at="2026-03-10T10:00:01Z",
finished_at="2026-03-10T10:00:05Z",
done_items=1,
total_items=1,
)
response = self._get("/api/tasks/task-download-ready")
self.assertEqual(response.status_code, 200)
body = response.json()
self.assertEqual(body["operation"], "download")
self.assertEqual(body["status"], "ready")
self.assertEqual(body["destination"], "docs.zip")
def test_get_task_not_found(self) -> None:
response = self._get("/api/tasks/task-missing")
@@ -233,6 +233,10 @@ class UiSmokeGoldenTest(unittest.TestCase):
self.assertIn('function markZipDownloadFailed(err)', app_js)
self.assertIn('function closeDownloadModal()', app_js)
self.assertIn('function zipDownloadRequestKey(paths)', app_js)
self.assertIn('async function createArchiveDownloadTask(paths)', app_js)
self.assertIn('async function getTaskRequest(taskId)', app_js)
self.assertIn('function startArchiveDownload(taskId, fileName)', app_js)
self.assertIn('async function waitForArchiveDownloadReady(taskId)', app_js)
self.assertIn('function contextMenuElements()', app_js)
self.assertIn('function openContextMenu(pane, entry, event)', app_js)
self.assertIn('function closeContextMenu()', app_js)
@@ -250,6 +254,9 @@ class UiSmokeGoldenTest(unittest.TestCase):
self.assertIn('statusText: err.message || "Download failed"', app_js)
self.assertIn('downloadProgressState.requestKey === requestKey', app_js)
self.assertIn('setStatus("Preparing download...");', app_js)
self.assertIn('"/api/files/download/archive-prepare"', app_js)
self.assertIn('`/api/tasks/${encodeURIComponent(taskId)}`', app_js)
self.assertIn('`/api/files/download/archive/${encodeURIComponent(taskId)}`', app_js)
self.assertIn('function applyContextMenuSelection()', app_js)
self.assertIn('function startContextMenuOpen()', app_js)
self.assertIn('function startContextMenuEdit()', app_js)
@@ -284,7 +291,10 @@ class UiSmokeGoldenTest(unittest.TestCase):
self.assertIn('elements.propertiesButton.disabled = items.length === 0;', app_js)
self.assertIn('openCurrentDirectory();', app_js)
self.assertIn('openEditor();', app_js)
-self.assertIn('downloadFileRequest(selectedItems.map((item) => item.path));', app_js)
+self.assertIn('const created = await createArchiveDownloadTask(selectedPaths);', app_js)
+self.assertIn('const task = await waitForArchiveDownloadReady(created.task_id);', app_js)
+self.assertIn('startArchiveDownload(task.id, task.destination);', app_js)
+self.assertIn('const { blob, fileName } = await downloadFileRequest(selectedPaths);', app_js)
self.assertIn('anchor.download = fileName || selected.name;', app_js)
self.assertIn('openRenamePopup();', app_js)
self.assertIn('startCopySelected();', app_js)
@@ -59,6 +59,27 @@ class TaskRepositoryTest(unittest.TestCase):
}
)
def test_create_download_task_with_requested_status_and_artifact(self) -> None:
created = self.repo.create_task(
operation="download",
source="storage1/docs",
destination="docs.zip",
status="requested",
)
self.repo.upsert_artifact(
task_id=created["id"],
file_path="/tmp/archive.zip",
file_name="docs.zip",
expires_at="2026-03-10T10:30:00Z",
)
task = self.repo.get_task(created["id"])
artifact = self.repo.get_artifact(created["id"])
self.assertEqual(task["operation"], "download")
self.assertEqual(task["status"], "requested")
self.assertEqual(artifact["file_name"], "docs.zip")
def test_migrates_legacy_tasks_schema_missing_source_destination(self) -> None:
legacy_db_path = Path(self.temp_dir.name) / "legacy.db"
conn = sqlite3.connect(legacy_db_path)
+59 -3
View File
@@ -460,6 +460,40 @@ function markZipDownloadFailed(err) {
});
}
function updateZipDownloadTaskProgress(task) {
if (!downloadProgressState.active) {
return;
}
updateDownloadModalDisplay({
active: true,
targetText: "Preparing download...",
currentFileText: task.current_item ? `Current: ${task.current_item}` : `Selection: ${selectedItemCountLabel(downloadProgressState.totalItems)}`,
countText: task.total_items ? `${task.done_items || 0}/${task.total_items} top-level items` : "Preparing zip download",
statusText: task.status === "ready" ? "Download started" : "Preparing download...",
percent: task.status === "ready" ? 100 : 55,
});
}
function sleep(ms) {
return new Promise((resolve) => window.setTimeout(resolve, ms));
}
async function waitForArchiveDownloadReady(taskId) {
while (true) {
const task = await getTaskRequest(taskId);
if (task.status === "ready") {
return task;
}
if (task.status === "failed") {
const err = new Error(task.error_message || "Archive download failed");
err.code = task.error_code || null;
throw err;
}
updateZipDownloadTaskProgress(task);
await sleep(250);
}
}
function closeDownloadModal() {
if (downloadProgressState.active) {
return;
@@ -643,15 +677,20 @@ async function startDownloadSelected() {
}
try {
const selected = selectedItems[0];
if (zipDownload) {
const created = await createArchiveDownloadTask(selectedPaths);
const task = await waitForArchiveDownloadReady(created.task_id);
startArchiveDownload(task.id, task.destination);
markZipDownloadReady(task.destination);
setStatus(`Download started: ${task.destination}`);
return;
}
const { blob, fileName } = await downloadFileRequest(selectedPaths);
const url = URL.createObjectURL(blob);
const anchor = document.createElement("a");
anchor.href = url;
anchor.download = fileName || selected.name;
document.body.append(anchor);
-if (zipDownload) {
-markZipDownloadReady(anchor.download);
-}
anchor.click();
anchor.remove();
URL.revokeObjectURL(url);
@@ -957,6 +996,23 @@ async function downloadFileRequest(paths) {
};
}
async function createArchiveDownloadTask(paths) {
return apiRequest("POST", "/api/files/download/archive-prepare", { paths });
}
async function getTaskRequest(taskId) {
return apiRequest("GET", `/api/tasks/${encodeURIComponent(taskId)}`);
}
function startArchiveDownload(taskId, fileName) {
const anchor = document.createElement("a");
anchor.href = `/api/files/download/archive/${encodeURIComponent(taskId)}`;
anchor.download = fileName || "";
document.body.append(anchor);
anchor.click();
anchor.remove();
}
async function uploadFileRequest(targetPath, file, overwrite = false) {
const formData = new FormData();
formData.append("target_path", targetPath);