feat: werk B2 uit voor veilige archive-downloads

This commit is contained in:
kodi
2026-03-14 14:24:52 +01:00
parent 592b10acc2
commit d463b3977d
24 changed files with 754 additions and 195 deletions
+24 -2
View File
@@ -4,8 +4,9 @@ from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask
from backend.app.api.schemas import DeleteRequest, DeleteResponse, FileInfoResponse, MkdirRequest, MkdirResponse, RenameRequest, RenameResponse, SaveRequest, SaveResponse, UploadResponse, ViewResponse
from backend.app.dependencies import get_file_ops_service
from backend.app.api.schemas import ArchivePrepareRequest, DeleteRequest, DeleteResponse, FileInfoResponse, MkdirRequest, MkdirResponse, RenameRequest, RenameResponse, SaveRequest, SaveResponse, TaskCreateResponse, UploadResponse, ViewResponse
from backend.app.dependencies import get_archive_download_task_service, get_file_ops_service
from backend.app.services.archive_download_task_service import ArchiveDownloadTaskService
from backend.app.services.file_ops_service import FileOpsService
router = APIRouter(prefix="/files")
@@ -78,6 +79,27 @@ async def download(
return response
@router.post("/download/archive-prepare", response_model=TaskCreateResponse, status_code=202)
async def archive_prepare(
    request: ArchivePrepareRequest,
    service: ArchiveDownloadTaskService = Depends(get_archive_download_task_service),
) -> TaskCreateResponse:
    """Queue background preparation of a zip archive for the requested paths.

    Responds 202 with the new task id; the archive itself is fetched later
    from GET /files/download/archive/{task_id} once the task is ready.
    """
    created = service.create_archive_prepare_task(paths=request.paths)
    return created
@router.get("/download/archive/{task_id}")
async def archive_download(
    task_id: str,
    service: ArchiveDownloadTaskService = Depends(get_archive_download_task_service),
) -> StreamingResponse:
    """Stream a previously prepared archive as an attachment download."""
    prepared = service.prepare_ready_archive_download(task_id=task_id)
    body = prepared["content"]
    response_headers = prepared["headers"]
    return StreamingResponse(body, headers=response_headers, media_type=prepared["content_type"])
@router.get("/video")
async def video(
path: str,
+4
View File
@@ -88,6 +88,10 @@ class SaveResponse(BaseModel):
modified: str
class ArchivePrepareRequest(BaseModel):
    """Request body for POST /files/download/archive-prepare."""

    # Browse-relative paths to include in the prepared zip archive.
    paths: list[str]
class FileInfoResponse(BaseModel):
name: str
path: str
+106 -4
View File
@@ -6,8 +6,8 @@ from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
VALID_STATUSES = {"queued", "running", "completed", "failed"}
VALID_OPERATIONS = {"copy", "move"}
VALID_STATUSES = {"queued", "running", "completed", "failed", "requested", "preparing", "ready"}
VALID_OPERATIONS = {"copy", "move", "download"}
TASK_MIGRATION_COLUMNS: dict[str, str] = {
"operation": "TEXT NOT NULL DEFAULT 'copy'",
"status": "TEXT NOT NULL DEFAULT 'queued'",
@@ -32,9 +32,18 @@ class TaskRepository:
self._db_path = db_path
self._ensure_schema()
def create_task(self, operation: str, source: str, destination: str, task_id: str | None = None) -> dict:
def create_task(
self,
operation: str,
source: str,
destination: str,
task_id: str | None = None,
status: str = "queued",
) -> dict:
if operation not in VALID_OPERATIONS:
raise ValueError("invalid operation")
if status not in VALID_STATUSES:
raise ValueError("invalid status")
task_id = task_id or str(uuid.uuid4())
created_at = self._now_iso()
@@ -52,7 +61,7 @@ class TaskRepository:
(
task_id,
operation,
"queued",
status,
source,
destination,
None,
@@ -145,6 +154,24 @@ class TaskRepository:
("running", started_at, done_bytes, total_bytes, done_items, total_items, current_item, task_id),
)
def mark_preparing(
    self,
    task_id: str,
    done_items: int | None = None,
    total_items: int | None = None,
    current_item: str | None = None,
) -> None:
    """Move a task into the 'preparing' state.

    started_at is set through COALESCE, so a task that already has a
    start timestamp keeps it; progress counters are overwritten with
    whatever was passed (including None).
    """
    timestamp = self._now_iso()
    params = ("preparing", timestamp, done_items, total_items, current_item, task_id)
    statement = """
        UPDATE tasks
        SET status = ?, started_at = COALESCE(started_at, ?), done_items = ?, total_items = ?, current_item = ?
        WHERE id = ?
        """
    with self._connection() as conn:
        conn.execute(statement, params)
def update_progress(
self,
task_id: str,
@@ -183,6 +210,23 @@ class TaskRepository:
("completed", finished_at, done_bytes, total_bytes, done_items, total_items, task_id),
)
def mark_ready(
    self,
    task_id: str,
    done_items: int | None = None,
    total_items: int | None = None,
) -> None:
    """Mark a task 'ready': stamp finished_at, store final item counts,
    and clear current_item."""
    timestamp = self._now_iso()
    params = ("ready", timestamp, done_items, total_items, task_id)
    statement = """
        UPDATE tasks
        SET status = ?, finished_at = ?, done_items = ?, total_items = ?, current_item = NULL
        WHERE id = ?
        """
    with self._connection() as conn:
        conn.execute(statement, params)
def mark_failed(
self,
task_id: str,
@@ -244,14 +288,62 @@ class TaskRepository:
)
"""
)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS task_artifacts (
task_id TEXT PRIMARY KEY,
file_path TEXT NOT NULL,
file_name TEXT NOT NULL,
expires_at TEXT NOT NULL,
created_at TEXT NOT NULL
)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_tasks_created_at_desc
ON tasks(created_at DESC)
"""
)
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_task_artifacts_expires_at
ON task_artifacts(expires_at ASC)
"""
)
self._migrate_tasks_columns(conn)
def upsert_artifact(self, *, task_id: str, file_path: str, file_name: str, expires_at: str) -> dict:
    """Insert or refresh the artifact row for a task and return it as a dict.

    On conflict only path/name/expiry are refreshed; the original
    created_at is kept (the INSERT's created_at is ignored then).
    """
    statement = """
        INSERT INTO task_artifacts (task_id, file_path, file_name, expires_at, created_at)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(task_id) DO UPDATE SET
            file_path = excluded.file_path,
            file_name = excluded.file_name,
            expires_at = excluded.expires_at
        """
    now = self._now_iso()
    with self._connection() as conn:
        conn.execute(statement, (task_id, file_path, file_name, expires_at, now))
        stored = conn.execute("SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)).fetchone()
    return self._artifact_to_dict(stored)
def get_artifact(self, task_id: str) -> dict | None:
    """Fetch the artifact row for *task_id*, or None when absent."""
    with self._connection() as conn:
        row = conn.execute("SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)).fetchone()
    if row is None:
        return None
    return self._artifact_to_dict(row)
def list_artifacts(self) -> list[dict]:
    """Return every artifact row as a dict, oldest first."""
    with self._connection() as conn:
        rows = conn.execute("SELECT * FROM task_artifacts ORDER BY created_at ASC").fetchall()
    return [self._artifact_to_dict(entry) for entry in rows]
def delete_artifact(self, task_id: str) -> None:
    """Remove the artifact row for *task_id*; a no-op when it does not exist."""
    statement = "DELETE FROM task_artifacts WHERE task_id = ?"
    with self._connection() as conn:
        conn.execute(statement, (task_id,))
def _migrate_tasks_columns(self, conn: sqlite3.Connection) -> None:
rows = conn.execute("PRAGMA table_info(tasks)").fetchall()
existing_columns = {row["name"] for row in rows}
@@ -298,6 +390,16 @@ class TaskRepository:
"finished_at": row["finished_at"],
}
@staticmethod
def _artifact_to_dict(row: sqlite3.Row) -> dict:
return {
"task_id": row["task_id"],
"file_path": row["file_path"],
"file_name": row["file_name"],
"expires_at": row["expires_at"],
"created_at": row["created_at"],
}
@staticmethod
def _now_iso() -> str:
return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")
+23
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from backend.app.config import Settings, get_settings
from backend.app.db.bookmark_repository import BookmarkRepository
@@ -12,6 +13,7 @@ from backend.app.security.path_guard import PathGuard
from backend.app.services.bookmark_service import BookmarkService
from backend.app.services.browse_service import BrowseService
from backend.app.services.copy_task_service import CopyTaskService
from backend.app.services.archive_download_task_service import ArchiveDownloadTaskService
from backend.app.services.file_ops_service import FileOpsService
from backend.app.services.history_service import HistoryService
from backend.app.services.move_task_service import MoveTaskService
@@ -64,6 +66,12 @@ def get_task_runner() -> TaskRunner:
)
@lru_cache(maxsize=1)
def get_archive_artifact_root() -> str:
    """Directory (next to the task DB) where prepared archives are staged.

    Cached so the path is resolved only once per process.
    """
    db_dir = Path(get_settings().task_db_path).resolve().parent
    return str(db_dir / "archive_tmp")
async def get_browse_service() -> BrowseService:
    """Build a BrowseService wired with the shared path guard and filesystem adapter."""
    guard = get_path_guard()
    fs = get_filesystem_adapter()
    return BrowseService(path_guard=guard, filesystem=fs)
@@ -76,6 +84,21 @@ async def get_file_ops_service() -> FileOpsService:
)
async def get_archive_download_task_service() -> ArchiveDownloadTaskService:
    """Assemble the archive-download service from the shared dependency getters.

    NOTE(review): a fresh FileOpsService is constructed here rather than
    reusing get_file_ops_service(); confirm that is intentional.
    """
    inner_file_ops = FileOpsService(
        path_guard=get_path_guard(),
        filesystem=get_filesystem_adapter(),
        history_repository=get_history_repository(),
    )
    return ArchiveDownloadTaskService(
        path_guard=get_path_guard(),
        repository=get_task_repository(),
        runner=get_task_runner(),
        history_repository=get_history_repository(),
        file_ops_service=inner_file_ops,
        artifact_root=Path(get_archive_artifact_root()),
    )
async def get_task_service() -> TaskService:
    """Build a TaskService over the shared task repository."""
    repo = get_task_repository()
    return TaskService(repository=repo)
@@ -0,0 +1,266 @@
from __future__ import annotations
import os
import uuid
import zipfile
from datetime import datetime, timedelta, timezone
from pathlib import Path
from backend.app.api.errors import AppError
from backend.app.api.schemas import TaskCreateResponse
from backend.app.db.history_repository import HistoryRepository
from backend.app.db.task_repository import TaskRepository
from backend.app.security.path_guard import PathGuard
from backend.app.services.file_ops_service import FileOpsService
from backend.app.tasks_runner import TaskRunner
# Default lifetime of a prepared archive: 30 minutes after preparation finishes.
ARCHIVE_DOWNLOAD_TTL_SECONDS = 30 * 60
class ArchiveDownloadTaskService:
    """Background preparation of zip archives for multi-item downloads.

    Flow: ``create_archive_prepare_task`` records a 'requested' task and
    enqueues a worker; the worker writes ``<task_id>.partial.zip`` under the
    artifact root, renames it to ``<task_id>.zip`` on success, records the
    artifact with an expiry, and marks the task 'ready'.
    ``prepare_ready_archive_download`` then serves the finished file.

    NOTE(review): this class calls several underscore-prefixed helpers on
    FileOpsService (e.g. ``_write_download_target_to_zip``) — consider
    promoting those to a public interface.
    """

    def __init__(
        self,
        path_guard: PathGuard,
        repository: TaskRepository,
        runner: TaskRunner,
        history_repository: HistoryRepository | None,
        file_ops_service: FileOpsService,
        artifact_root: Path,
        artifact_ttl_seconds: int = ARCHIVE_DOWNLOAD_TTL_SECONDS,
    ):
        """Store collaborators, ensure the artifact directory exists, and
        immediately sweep any stale artifacts left over from a prior run."""
        self._path_guard = path_guard
        self._repository = repository
        self._runner = runner
        # Optional: when None, all history recording becomes a no-op.
        self._history_repository = history_repository
        self._file_ops_service = file_ops_service
        self._artifact_root = artifact_root
        self._artifact_ttl_seconds = artifact_ttl_seconds
        self._artifact_root.mkdir(parents=True, exist_ok=True)
        self.sweep_artifacts()

    def create_archive_prepare_task(self, paths: list[str]) -> TaskCreateResponse:
        """Validate *paths*, record a 'requested' download task, and enqueue
        the background zip preparation.

        Raises AppError(400) for an empty selection or a single-file
        selection (which must use the direct download endpoint instead).
        """
        if not paths:
            raise AppError(
                code="invalid_request",
                message="At least one path is required",
                status_code=400,
            )
        # Opportunistic cleanup of expired/orphaned artifacts before adding work.
        self.sweep_artifacts()
        resolved_targets = [self._path_guard.resolve_existing_path(path) for path in paths]
        mode = self._file_ops_service._download_mode_from_resolved_targets(resolved_targets)
        if mode == "single_file":
            raise AppError(
                code="invalid_request",
                message="Single file downloads must use direct download",
                status_code=400,
            )
        summary = self._file_ops_service._summarize_download_targets([target.relative for target in resolved_targets])
        archive_name = self._file_ops_service._download_name_for_targets(resolved_targets)
        task_id = str(uuid.uuid4())
        # The task row reuses 'source' for the path summary and
        # 'destination' for the archive file name.
        task = self._repository.create_task(
            operation="download",
            source=summary,
            destination=archive_name,
            task_id=task_id,
            status="requested",
        )
        # History entry shares the task id; 'source' here carries the
        # download mode, 'path' the summarized selection.
        self._record_history(
            entry_id=task_id,
            operation="download",
            status="requested",
            source=mode,
            destination=archive_name,
            path=summary,
        )
        # Capture relative paths now; the worker re-resolves them so a
        # path removed in the meantime fails inside the task, not here.
        target_paths = [target.relative for target in resolved_targets]
        self._runner.enqueue_archive_prepare(
            lambda: self._run_archive_prepare_task(
                task_id=task_id,
                target_paths=target_paths,
                archive_name=archive_name,
                history_mode=mode,
                history_path=summary,
            )
        )
        return TaskCreateResponse(task_id=task["id"], status=task["status"])

    def prepare_ready_archive_download(self, task_id: str) -> dict:
        """Return stream content, headers, and content type for a ready archive.

        Raises AppError 404 (unknown task / missing artifact or file),
        400 (task is not a download), 409 (not ready yet), or
        410 (artifact expired — the record and file are removed).
        """
        self.sweep_artifacts()
        task = self._repository.get_task(task_id)
        if not task:
            raise AppError(
                code="task_not_found",
                message="Task was not found",
                status_code=404,
                details={"task_id": task_id},
            )
        if task["operation"] != "download":
            raise AppError(
                code="invalid_request",
                message="Task is not an archive download",
                status_code=400,
                details={"task_id": task_id},
            )
        if task["status"] != "ready":
            raise AppError(
                code="download_not_ready",
                message="Archive download is not ready",
                status_code=409,
                details={"task_id": task_id, "status": task["status"]},
            )
        artifact = self._repository.get_artifact(task_id)
        if not artifact:
            raise AppError(
                code="archive_not_found",
                message="Prepared archive was not found",
                status_code=404,
                details={"task_id": task_id},
            )
        if self._is_expired(artifact["expires_at"]):
            self._delete_artifact_record_and_file(task_id, artifact["file_path"])
            raise AppError(
                code="archive_expired",
                message="Prepared archive expired",
                status_code=410,
                details={"task_id": task_id},
            )
        artifact_path = Path(artifact["file_path"])
        if not artifact_path.exists():
            # DB row exists but the file vanished; drop the stale row.
            self._repository.delete_artifact(task_id)
            raise AppError(
                code="archive_not_found",
                message="Prepared archive was not found",
                status_code=404,
                details={"task_id": task_id},
            )
        return {
            "content": self._file_ops_service._filesystem.stream_file(artifact_path),
            "headers": {
                "Content-Disposition": f'attachment; filename="{artifact["file_name"]}"',
                "Content-Length": str(int(artifact_path.stat().st_size)),
            },
            "content_type": "application/zip",
        }

    def sweep_artifacts(self) -> None:
        """Delete expired or file-less artifact records, then remove any files
        in the artifact root that no record references (e.g. leftover
        .partial.zip files from a crashed worker)."""
        self._artifact_root.mkdir(parents=True, exist_ok=True)
        referenced_paths: set[Path] = set()
        for artifact in self._repository.list_artifacts():
            artifact_path = Path(artifact["file_path"])
            referenced_paths.add(artifact_path)
            if self._is_expired(artifact["expires_at"]) or not artifact_path.exists():
                self._delete_artifact_record_and_file(artifact["task_id"], artifact["file_path"])
        for candidate in self._artifact_root.iterdir():
            if candidate.is_file() and candidate not in referenced_paths:
                try:
                    candidate.unlink()
                except FileNotFoundError:
                    # Another sweep/worker may have removed it concurrently.
                    pass

    def _run_archive_prepare_task(
        self,
        *,
        task_id: str,
        target_paths: list[str],
        archive_name: str,
        history_mode: str,
        history_path: str,
    ) -> None:
        """Worker body: build the zip, register the artifact, mark ready.

        Writes to ``<task_id>.partial.zip`` and renames to ``<task_id>.zip``
        only on success, so readers never see a half-written archive.
        On AppError/OSError the task is marked failed and both candidate
        files are cleaned up.

        NOTE(review): ``history_mode`` is currently unused in the body.
        """
        partial_path = self._artifact_root / f"{task_id}.partial.zip"
        final_path = self._artifact_root / f"{task_id}.zip"
        total_items = len(target_paths)
        try:
            self._repository.mark_preparing(
                task_id=task_id,
                done_items=0,
                total_items=total_items,
                current_item=target_paths[0] if target_paths else None,
            )
            # Re-resolve at run time: the selection may have changed since enqueue.
            resolved_targets = [self._path_guard.resolve_existing_path(path) for path in target_paths]
            self._file_ops_service._validate_zip_download_archive_names(resolved_targets)
            self._file_ops_service._run_zip_download_preflight(resolved_targets)
            with zipfile.ZipFile(partial_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
                for resolved_target in resolved_targets:
                    self._file_ops_service._write_download_target_to_zip(archive, resolved_target)
            # os.replace is atomic when source and target share a filesystem.
            os.replace(partial_path, final_path)
            self._repository.upsert_artifact(
                task_id=task_id,
                file_path=str(final_path),
                file_name=archive_name,
                expires_at=self._expires_at_iso(),
            )
            self._repository.mark_ready(
                task_id=task_id,
                done_items=total_items,
                total_items=total_items,
            )
            self._update_history_ready(task_id)
        except AppError as exc:
            # Two calls so both the partial and the final file are removed;
            # the artifact record deletion in the first call is a harmless repeat.
            self._delete_artifact_record_and_file(task_id, str(partial_path))
            self._delete_artifact_record_and_file(task_id, str(final_path))
            self._repository.mark_failed(
                task_id=task_id,
                error_code=exc.code,
                error_message=exc.message,
                failed_item=history_path,
                done_bytes=None,
                total_bytes=None,
                done_items=0,
                total_items=total_items,
            )
            self._update_history_failed(task_id, exc.code, exc.message)
        except OSError as exc:
            self._delete_artifact_record_and_file(task_id, str(partial_path))
            self._delete_artifact_record_and_file(task_id, str(final_path))
            self._repository.mark_failed(
                task_id=task_id,
                error_code="io_error",
                error_message=str(exc),
                failed_item=history_path,
                done_bytes=None,
                total_bytes=None,
                done_items=0,
                total_items=total_items,
            )
            self._update_history_failed(task_id, "io_error", str(exc))

    def _delete_artifact_record_and_file(self, task_id: str, file_path: str) -> None:
        """Delete the artifact DB row and best-effort unlink its file."""
        self._repository.delete_artifact(task_id)
        path = Path(file_path)
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    def _update_history_ready(self, task_id: str) -> None:
        """Flip the history entry to 'ready' when history is enabled."""
        if self._history_repository:
            self._history_repository.update_entry(entry_id=task_id, status="ready")

    def _update_history_failed(self, task_id: str, error_code: str, error_message: str) -> None:
        """Record the failure on the history entry when history is enabled."""
        if self._history_repository:
            self._history_repository.update_entry(
                entry_id=task_id,
                status="failed",
                error_code=error_code,
                error_message=error_message,
            )

    def _record_history(self, **kwargs) -> None:
        """Create a history entry when history is enabled; no-op otherwise."""
        if self._history_repository:
            self._history_repository.create_entry(**kwargs)

    def _expires_at_iso(self) -> str:
        """UTC expiry timestamp (now + TTL), second precision, 'Z' suffix."""
        return (datetime.now(timezone.utc) + timedelta(seconds=self._artifact_ttl_seconds)).replace(microsecond=0).isoformat().replace("+00:00", "Z")

    @staticmethod
    def _is_expired(expires_at: str) -> bool:
        """True when the 'Z'-suffixed ISO timestamp is at or before now (UTC)."""
        return datetime.now(timezone.utc) >= datetime.fromisoformat(expires_at.replace("Z", "+00:00"))
+22 -14
View File
@@ -411,6 +411,14 @@ class FileOpsService:
history_mode = self._download_mode_from_resolved_targets(resolved_targets)
history_path = self._summarize_download_targets([target.relative for target in resolved_targets])
history_download_name = self._download_name_for_targets(resolved_targets)
if history_mode != "single_file":
raise AppError(
code="invalid_request",
message="Archive downloads must be prepared first",
status_code=400,
)
history_entry_id = self._record_download_status(
status="requested",
mode=history_mode,
@@ -418,10 +426,7 @@ class FileOpsService:
download_name=history_download_name,
)
if len(resolved_targets) == 1 and resolved_targets[0].absolute.is_file():
prepared = self._prepare_single_file_download(resolved_targets[0])
else:
prepared = self._prepare_zip_download(resolved_targets, history_download_name)
prepared = self._prepare_single_file_download(resolved_targets[0])
self._record_download_status(
status="ready",
@@ -757,16 +762,7 @@ class FileOpsService:
}
def _prepare_zip_download(self, resolved_targets: list, download_name: str) -> dict:
archive_names: set[str] = set()
for resolved_target in resolved_targets:
archive_name = resolved_target.absolute.name
if archive_name in archive_names:
raise AppError(
code="invalid_request",
message="Selected items must have distinct top-level names",
status_code=400,
)
archive_names.add(archive_name)
self._validate_zip_download_archive_names(resolved_targets)
self._run_zip_download_preflight(resolved_targets)
buffer = BytesIO()
@@ -786,6 +782,18 @@ class FileOpsService:
"content_type": "application/zip",
}
def _validate_zip_download_archive_names(self, resolved_targets: list) -> None:
archive_names: set[str] = set()
for resolved_target in resolved_targets:
archive_name = resolved_target.absolute.name
if archive_name in archive_names:
raise AppError(
code="invalid_request",
message="Selected items must have distinct top-level names",
status_code=400,
)
archive_names.add(archive_name)
def _download_name_for_targets(self, resolved_targets: list) -> str:
if len(resolved_targets) == 1 and resolved_targets[0].absolute.is_file():
return resolved_targets[0].absolute.name
+7
View File
@@ -69,6 +69,13 @@ class TaskRunner:
)
thread.start()
def enqueue_archive_prepare(self, worker) -> None:
    """Run the archive-preparation *worker* on a detached daemon thread.

    Fire-and-forget: the thread is not tracked or joined, and worker
    errors are expected to be handled inside the worker itself.
    """
    prepare_thread = threading.Thread(target=worker, daemon=True)
    prepare_thread.start()
def _run_copy_file(self, task_id: str, source: str, destination: str, total_bytes: int) -> None:
self._repository.mark_running(
task_id=task_id,