from __future__ import annotations

import sqlite3
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path

# Closed sets of values accepted for tasks.status / tasks.operation.
VALID_STATUSES = {
    "queued",
    "running",
    "completed",
    "failed",
    "requested",
    "preparing",
    "ready",
    "cancelled",
}
VALID_OPERATIONS = {"copy", "move", "download"}

# Placeholder written into ``created_at`` for rows that pre-date the column.
# SQLite refuses ``ALTER TABLE ... ADD COLUMN`` for a NOT NULL column unless it
# has a constant, non-NULL default, so the migration needs some value here.
_LEGACY_CREATED_AT = "1970-01-01T00:00:00Z"

# Column name -> DDL fragment used by ``ALTER TABLE tasks ADD COLUMN`` when an
# existing database is missing that column.
TASK_MIGRATION_COLUMNS: dict[str, str] = {
    "operation": "TEXT NOT NULL DEFAULT 'copy'",
    "status": "TEXT NOT NULL DEFAULT 'queued'",
    "source": "TEXT NOT NULL DEFAULT ''",
    "destination": "TEXT NOT NULL DEFAULT ''",
    "done_bytes": "INTEGER NULL",
    "total_bytes": "INTEGER NULL",
    "done_items": "INTEGER NULL",
    "total_items": "INTEGER NULL",
    "current_item": "TEXT NULL",
    "failed_item": "TEXT NULL",
    "error_code": "TEXT NULL",
    "error_message": "TEXT NULL",
    "created_at": f"TEXT NOT NULL DEFAULT '{_LEGACY_CREATED_AT}'",
    "started_at": "TEXT NULL",
    "finished_at": "TEXT NULL",
}

# Single source of truth for the column order used by INSERTs and row -> dict
# conversion.
_TASK_COLUMNS: tuple[str, ...] = (
    "id",
    "operation",
    "status",
    "source",
    "destination",
    "done_bytes",
    "total_bytes",
    "done_items",
    "total_items",
    "current_item",
    "failed_item",
    "error_code",
    "error_message",
    "created_at",
    "started_at",
    "finished_at",
)
_ARTIFACT_COLUMNS: tuple[str, ...] = (
    "task_id",
    "file_path",
    "file_name",
    "expires_at",
    "created_at",
)

_INSERT_TASK_SQL = (
    f"INSERT INTO tasks ({', '.join(_TASK_COLUMNS)}) "
    f"VALUES ({', '.join(['?'] * len(_TASK_COLUMNS))})"
)


class TaskRepository:
    """SQLite-backed storage for tasks and their artifacts.

    Every public method opens its own connection, commits on success, rolls
    back on error and closes the connection afterwards. The schema is created
    (and missing ``tasks`` columns are added) when the repository is built.
    """

    def __init__(self, db_path: str):
        """Remember *db_path* and make sure the schema exists there."""
        self._db_path = db_path
        self._ensure_schema()

    # ------------------------------------------------------------------ tasks

    def create_task(
        self,
        operation: str,
        source: str,
        destination: str,
        task_id: str | None = None,
        status: str = "queued",
    ) -> dict:
        """Insert a new task and return it as a dict.

        Args:
            operation: One of ``VALID_OPERATIONS``.
            source: Stored verbatim in the ``source`` column.
            destination: Stored verbatim in the ``destination`` column.
            task_id: Explicit id; a random UUID4 string is used when falsy.
            status: Initial status, one of ``VALID_STATUSES``.

        Raises:
            ValueError: If *operation* or *status* is not an accepted value.
        """
        if operation not in VALID_OPERATIONS:
            raise ValueError("invalid operation")
        if status not in VALID_STATUSES:
            raise ValueError("invalid status")

        task_id = task_id or str(uuid.uuid4())
        values = (
            task_id,
            operation,
            status,
            source,
            destination,
            None,  # done_bytes
            None,  # total_bytes
            None,  # done_items
            None,  # total_items
            None,  # current_item
            None,  # failed_item
            None,  # error_code
            None,  # error_message
            self._now_iso(),  # created_at
            None,  # started_at
            None,  # finished_at
        )
        with self._connection() as conn:
            conn.execute(_INSERT_TASK_SQL, values)
            row = conn.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ).fetchone()
        return self._to_dict(row)

    def insert_task_for_testing(self, task: dict) -> None:
        """Insert a fully specified task row.

        *task* must contain ``id``, ``operation``, ``status``, ``source``,
        ``destination`` and ``created_at``; every other column is optional and
        stored as NULL when absent.

        Raises:
            ValueError: If the status or operation is not an accepted value.
            KeyError: If a required key is missing.
        """
        status = task["status"]
        operation = task["operation"]
        if status not in VALID_STATUSES:
            raise ValueError("invalid status")
        if operation not in VALID_OPERATIONS:
            raise ValueError("invalid operation")

        values = (
            task["id"],
            operation,
            status,
            task["source"],
            task["destination"],
            task.get("done_bytes"),
            task.get("total_bytes"),
            task.get("done_items"),
            task.get("total_items"),
            task.get("current_item"),
            task.get("failed_item"),
            task.get("error_code"),
            task.get("error_message"),
            task["created_at"],
            task.get("started_at"),
            task.get("finished_at"),
        )
        with self._connection() as conn:
            conn.execute(_INSERT_TASK_SQL, values)

    def get_task(self, task_id: str) -> dict | None:
        """Return the task with *task_id* as a dict, or ``None`` if absent."""
        with self._connection() as conn:
            row = conn.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ).fetchone()
        return self._to_dict(row) if row else None

    def list_tasks(self) -> list[dict]:
        """Return every task, ordered by ``created_at`` descending."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM tasks ORDER BY created_at DESC"
            ).fetchall()
        return [self._to_dict(row) for row in rows]

    def mark_running(
        self,
        task_id: str,
        done_bytes: int | None = None,
        total_bytes: int | None = None,
        done_items: int | None = None,
        total_items: int | None = None,
        current_item: str | None = None,
    ) -> None:
        """Set status to ``running``, stamp ``started_at`` and store progress.

        The update is unconditional: it applies whatever the current status
        is, and overwrites any previous ``started_at``.
        """
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, started_at = ?,
                    done_bytes = ?, total_bytes = ?,
                    done_items = ?, total_items = ?,
                    current_item = ?
                WHERE id = ?
                """,
                (
                    "running",
                    self._now_iso(),
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    current_item,
                    task_id,
                ),
            )

    def mark_preparing(
        self,
        task_id: str,
        done_items: int | None = None,
        total_items: int | None = None,
        current_item: str | None = None,
    ) -> bool:
        """Move a ``requested`` task to ``preparing``.

        ``started_at`` is only set if it is still NULL.

        Returns:
            ``True`` if a row was updated, ``False`` if the task does not
            exist or is not currently ``requested``.
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, started_at = COALESCE(started_at, ?),
                    done_items = ?, total_items = ?, current_item = ?
                WHERE id = ? AND status = ?
                """,
                (
                    "preparing",
                    self._now_iso(),
                    done_items,
                    total_items,
                    current_item,
                    task_id,
                    "requested",
                ),
            )
        return cursor.rowcount > 0

    def update_progress(
        self,
        task_id: str,
        done_bytes: int | None = None,
        total_bytes: int | None = None,
        done_items: int | None = None,
        total_items: int | None = None,
        current_item: str | None = None,
    ) -> None:
        """Overwrite all progress columns of a task; status is left untouched.

        Arguments left as ``None`` are written as NULL, not skipped.
        """
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET done_bytes = ?, total_bytes = ?,
                    done_items = ?, total_items = ?,
                    current_item = ?
                WHERE id = ?
                """,
                (
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    current_item,
                    task_id,
                ),
            )

    def mark_completed(
        self,
        task_id: str,
        done_bytes: int | None = None,
        total_bytes: int | None = None,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> None:
        """Set status to ``completed``, stamp ``finished_at``, store totals."""
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?,
                    done_bytes = ?, total_bytes = ?,
                    done_items = ?, total_items = ?
                WHERE id = ?
                """,
                (
                    "completed",
                    self._now_iso(),
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    task_id,
                ),
            )

    def mark_ready(
        self,
        task_id: str,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> bool:
        """Move a ``preparing`` task to ``ready`` and clear ``current_item``.

        Returns:
            ``True`` if a row was updated, ``False`` if the task does not
            exist or is not currently ``preparing``.
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?,
                    done_items = ?, total_items = ?,
                    current_item = NULL
                WHERE id = ? AND status = ?
                """,
                (
                    "ready",
                    self._now_iso(),
                    done_items,
                    total_items,
                    task_id,
                    "preparing",
                ),
            )
        return cursor.rowcount > 0

    def mark_failed(
        self,
        task_id: str,
        error_code: str,
        error_message: str,
        failed_item: str | None,
        done_bytes: int | None,
        total_bytes: int | None,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> None:
        """Unconditionally set status to ``failed`` and record the error.

        ``current_item`` is left as it was.
        """
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?,
                    error_code = ?, error_message = ?, failed_item = ?,
                    done_bytes = ?, total_bytes = ?,
                    done_items = ?, total_items = ?
                WHERE id = ?
                """,
                (
                    "failed",
                    self._now_iso(),
                    error_code,
                    error_message,
                    failed_item,
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    task_id,
                ),
            )

    def mark_failed_if_not_cancelled(
        self,
        task_id: str,
        error_code: str,
        error_message: str,
        failed_item: str | None,
        done_bytes: int | None,
        total_bytes: int | None,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> bool:
        """Record a failure unless the task is already ``cancelled``.

        Also clears ``current_item``.

        Returns:
            ``True`` if a row was updated, ``False`` if the task does not
            exist or is ``cancelled``.
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?,
                    error_code = ?, error_message = ?, failed_item = ?,
                    done_bytes = ?, total_bytes = ?,
                    done_items = ?, total_items = ?,
                    current_item = NULL
                WHERE id = ? AND status != ?
                """,
                (
                    "failed",
                    self._now_iso(),
                    error_code,
                    error_message,
                    failed_item,
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    task_id,
                    "cancelled",
                ),
            )
        return cursor.rowcount > 0

    def mark_cancelled(self, task_id: str) -> bool:
        """Cancel a task that is ``requested`` or ``preparing``.

        Returns:
            ``True`` if a row was updated, ``False`` if the task does not
            exist or is in any other status.
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?, current_item = NULL
                WHERE id = ? AND status IN (?, ?)
                """,
                ("cancelled", self._now_iso(), task_id, "requested", "preparing"),
            )
        return cursor.rowcount > 0

    # ----------------------------------------------------------------- schema

    def _ensure_schema(self) -> None:
        """Create tables and indexes, and add any missing ``tasks`` columns."""
        parent = Path(self._db_path).parent
        if str(parent) not in {"", "."}:
            parent.mkdir(parents=True, exist_ok=True)

        with self._connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS tasks (
                    id TEXT PRIMARY KEY,
                    operation TEXT NOT NULL,
                    status TEXT NOT NULL,
                    source TEXT NOT NULL,
                    destination TEXT NOT NULL,
                    done_bytes INTEGER NULL,
                    total_bytes INTEGER NULL,
                    done_items INTEGER NULL,
                    total_items INTEGER NULL,
                    current_item TEXT NULL,
                    failed_item TEXT NULL,
                    error_code TEXT NULL,
                    error_message TEXT NULL,
                    created_at TEXT NOT NULL,
                    started_at TEXT NULL,
                    finished_at TEXT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS task_artifacts (
                    task_id TEXT PRIMARY KEY,
                    file_path TEXT NOT NULL,
                    file_name TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    created_at TEXT NOT NULL
                )
                """
            )
            # Migrate before indexing: the index below references created_at,
            # which a legacy tasks table may not have until the migration
            # has added it.
            self._migrate_tasks_columns(conn)
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_tasks_created_at_desc
                ON tasks(created_at DESC)
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_task_artifacts_expires_at
                ON task_artifacts(expires_at ASC)
                """
            )

    # -------------------------------------------------------------- artifacts

    def upsert_artifact(
        self, *, task_id: str, file_path: str, file_name: str, expires_at: str
    ) -> dict:
        """Insert or replace the artifact of *task_id* and return it.

        On conflict the path, name and expiry are overwritten while the
        original ``created_at`` is kept.
        """
        with self._connection() as conn:
            conn.execute(
                """
                INSERT INTO task_artifacts
                    (task_id, file_path, file_name, expires_at, created_at)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(task_id) DO UPDATE SET
                    file_path = excluded.file_path,
                    file_name = excluded.file_name,
                    expires_at = excluded.expires_at
                """,
                (task_id, file_path, file_name, expires_at, self._now_iso()),
            )
            row = conn.execute(
                "SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)
            ).fetchone()
        return self._artifact_to_dict(row)

    def get_artifact(self, task_id: str) -> dict | None:
        """Return the artifact of *task_id* as a dict, or ``None`` if absent."""
        with self._connection() as conn:
            row = conn.execute(
                "SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)
            ).fetchone()
        return self._artifact_to_dict(row) if row else None

    def list_artifacts(self) -> list[dict]:
        """Return every artifact, ordered by ``created_at`` ascending."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM task_artifacts ORDER BY created_at ASC"
            ).fetchall()
        return [self._artifact_to_dict(row) for row in rows]

    def delete_artifact(self, task_id: str) -> None:
        """Delete the artifact row of *task_id*; a no-op if there is none."""
        with self._connection() as conn:
            conn.execute(
                "DELETE FROM task_artifacts WHERE task_id = ?", (task_id,)
            )

    # ---------------------------------------------------------------- helpers

    def _migrate_tasks_columns(self, conn: sqlite3.Connection) -> None:
        """Add every ``TASK_MIGRATION_COLUMNS`` entry missing from ``tasks``."""
        table_info = conn.execute("PRAGMA table_info(tasks)").fetchall()
        existing_columns = {info["name"] for info in table_info}
        for column, ddl in TASK_MIGRATION_COLUMNS.items():
            if column not in existing_columns:
                # Identifiers cannot be bound as parameters; both values come
                # from the module-level constant, never from caller input.
                conn.execute(f"ALTER TABLE tasks ADD COLUMN {column} {ddl}")

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name."""
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection; commit on success, roll back on error.

        The connection is always closed on exit.
        """
        conn = self._connect()
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _to_dict(row: sqlite3.Row) -> dict:
        """Convert a ``tasks`` row into a plain dict keyed by column name."""
        return {column: row[column] for column in _TASK_COLUMNS}

    @staticmethod
    def _artifact_to_dict(row: sqlite3.Row) -> dict:
        """Convert a ``task_artifacts`` row into a plain dict."""
        return {column: row[column] for column in _ARTIFACT_COLUMNS}

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as ISO-8601 with a ``Z`` suffix."""
        return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")