495 lines
17 KiB
Python
495 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
import uuid
|
|
from contextlib import contextmanager
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
# Every status a task row may hold.
VALID_STATUSES = {"queued", "running", "completed", "failed", "requested", "preparing", "ready", "cancelled"}

# Operations a task may be created for.
VALID_OPERATIONS = {"copy", "move", "download", "duplicate"}

# Statuses of tasks that have not reached an end state.
# TaskRepository.reconcile_incomplete_tasks marks every task still in one of
# these statuses as failed ("interrupted").
NON_TERMINAL_STATUSES = ("queued", "running", "requested", "preparing")

# Columns added to a pre-existing ``tasks`` table that lacks them, via
# ``ALTER TABLE tasks ADD COLUMN <name> <ddl>``.
# SQLite refuses to ADD a NOT NULL column whose default is NULL (recent
# versions only when the table already holds rows), so every NOT NULL entry
# here must carry a non-NULL DEFAULT.
TASK_MIGRATION_COLUMNS: dict[str, str] = {
    "operation": "TEXT NOT NULL DEFAULT 'copy'",
    "status": "TEXT NOT NULL DEFAULT 'queued'",
    "source": "TEXT NOT NULL DEFAULT ''",
    "destination": "TEXT NOT NULL DEFAULT ''",
    "done_bytes": "INTEGER NULL",
    "total_bytes": "INTEGER NULL",
    "done_items": "INTEGER NULL",
    "total_items": "INTEGER NULL",
    "current_item": "TEXT NULL",
    "failed_item": "TEXT NULL",
    "error_code": "TEXT NULL",
    "error_message": "TEXT NULL",
    # Placeholder timestamp for legacy rows: ISO-8601 UTC with a "Z" suffix,
    # the same shape as the timestamps the repository writes for new rows.
    "created_at": "TEXT NOT NULL DEFAULT '1970-01-01T00:00:00Z'",
    "started_at": "TEXT NULL",
    "finished_at": "TEXT NULL",
}
|
|
|
|
|
|
class TaskRepository:
    """SQLite-backed store for task rows and their downloadable artifacts.

    Every public method opens a short-lived connection through
    ``_connection``, which commits on success, rolls back on error and always
    closes the connection; no connection is kept open between calls.
    """

    # Column order of the ``tasks`` table; also the key order of ``_to_dict``.
    _TASK_COLUMNS = (
        "id",
        "operation",
        "status",
        "source",
        "destination",
        "done_bytes",
        "total_bytes",
        "done_items",
        "total_items",
        "current_item",
        "failed_item",
        "error_code",
        "error_message",
        "created_at",
        "started_at",
        "finished_at",
    )
    # Column order of the ``task_artifacts`` table; also the key order of
    # ``_artifact_to_dict``.
    _ARTIFACT_COLUMNS = ("task_id", "file_path", "file_name", "expires_at", "created_at")

    def __init__(self, db_path: str):
        """Bind the repository to *db_path* and create or migrate its schema.

        Args:
            db_path: Path of the SQLite database file; missing parent
                directories are created.
        """
        self._db_path = db_path
        self._ensure_schema()

    def create_task(
        self,
        operation: str,
        source: str,
        destination: str,
        task_id: str | None = None,
        status: str = "queued",
    ) -> dict:
        """Insert a new task and return the stored row.

        Progress, error and start/finish fields start out as NULL;
        ``created_at`` is the current UTC time.

        Args:
            operation: One of ``VALID_OPERATIONS``.
            source: Source of the task.
            destination: Destination of the task.
            task_id: Explicit id; a random UUID4 string is used when omitted
                or empty.
            status: Initial status, one of ``VALID_STATUSES``.

        Returns:
            The inserted row as a dict (see ``_to_dict``).

        Raises:
            ValueError: If *operation* or *status* is not recognised.
            sqlite3.IntegrityError: If a task with *task_id* already exists.
        """
        if operation not in VALID_OPERATIONS:
            raise ValueError("invalid operation")
        if status not in VALID_STATUSES:
            raise ValueError("invalid status")

        task_id = task_id or str(uuid.uuid4())
        new_row = {
            "id": task_id,
            "operation": operation,
            "status": status,
            "source": source,
            "destination": destination,
            "created_at": self._now_iso(),
        }
        with self._connection() as conn:
            self._insert_task_row(conn, new_row)
            stored = conn.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone()
        return self._to_dict(stored)

    def insert_task_for_testing(self, task: dict) -> None:
        """Insert a fully specified task row verbatim (test helper).

        Args:
            task: Mapping keyed by column name. ``id``, ``operation``,
                ``status``, ``source``, ``destination`` and ``created_at`` are
                required; all other columns are optional and stored as NULL
                when absent.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``status`` or ``operation`` is not recognised.
        """
        status = task["status"]
        operation = task["operation"]
        if status not in VALID_STATUSES:
            raise ValueError("invalid status")
        if operation not in VALID_OPERATIONS:
            raise ValueError("invalid operation")

        row = {column: task.get(column) for column in self._TASK_COLUMNS}
        # Mandatory columns: index directly so a missing key raises KeyError.
        for column in ("id", "source", "destination", "created_at"):
            row[column] = task[column]
        with self._connection() as conn:
            self._insert_task_row(conn, row)

    def get_task(self, task_id: str) -> dict | None:
        """Return the task with *task_id* as a dict, or ``None`` if absent."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM tasks WHERE id = ?", (task_id,)).fetchone()
        return self._to_dict(row) if row else None

    def list_tasks(self) -> list[dict]:
        """Return all tasks, newest ``created_at`` first."""
        with self._connection() as conn:
            rows = conn.execute(
                """
                SELECT * FROM tasks
                ORDER BY created_at DESC
                """
            ).fetchall()
        return [self._to_dict(row) for row in rows]

    def mark_running(
        self,
        task_id: str,
        done_bytes: int | None = None,
        total_bytes: int | None = None,
        done_items: int | None = None,
        total_items: int | None = None,
        current_item: str | None = None,
    ) -> None:
        """Set the task to ``running`` and stamp ``started_at`` with now.

        The update is unconditional: it applies whatever the current status
        is and overwrites any earlier ``started_at``. All progress fields are
        overwritten, so arguments left as ``None`` clear them. An unknown
        *task_id* is silently ignored.
        """
        started_at = self._now_iso()
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, started_at = ?, done_bytes = ?, total_bytes = ?, done_items = ?, total_items = ?, current_item = ?
                WHERE id = ?
                """,
                ("running", started_at, done_bytes, total_bytes, done_items, total_items, current_item, task_id),
            )

    def mark_preparing(
        self,
        task_id: str,
        done_items: int | None = None,
        total_items: int | None = None,
        current_item: str | None = None,
    ) -> bool:
        """Move a ``requested`` task to ``preparing``.

        ``started_at`` is only set if it is still NULL; the item counters and
        ``current_item`` are overwritten.

        Returns:
            ``True`` if the task was updated, ``False`` if it does not exist
            or is not currently ``requested``.
        """
        started_at = self._now_iso()
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, started_at = COALESCE(started_at, ?), done_items = ?, total_items = ?, current_item = ?
                WHERE id = ? AND status = ?
                """,
                ("preparing", started_at, done_items, total_items, current_item, task_id, "requested"),
            )
            return cursor.rowcount > 0

    def update_progress(
        self,
        task_id: str,
        done_bytes: int | None = None,
        total_bytes: int | None = None,
        done_items: int | None = None,
        total_items: int | None = None,
        current_item: str | None = None,
    ) -> None:
        """Overwrite the task's progress fields without touching its status.

        Every field is written, so arguments left as ``None`` clear the stored
        value. An unknown *task_id* is silently ignored.
        """
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET done_bytes = ?, total_bytes = ?, done_items = ?, total_items = ?, current_item = ?
                WHERE id = ?
                """,
                (done_bytes, total_bytes, done_items, total_items, current_item, task_id),
            )

    def mark_completed(
        self,
        task_id: str,
        done_bytes: int | None = None,
        total_bytes: int | None = None,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> None:
        """Set the task to ``completed`` and stamp ``finished_at`` with now.

        Unconditional (any current status). The byte/item counters are
        overwritten; ``current_item`` is left unchanged.
        """
        finished_at = self._now_iso()
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?, done_bytes = ?, total_bytes = ?, done_items = ?, total_items = ?
                WHERE id = ?
                """,
                ("completed", finished_at, done_bytes, total_bytes, done_items, total_items, task_id),
            )

    def mark_ready(
        self,
        task_id: str,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> bool:
        """Move a ``preparing`` task to ``ready``, stamping ``finished_at``.

        The item counters are overwritten and ``current_item`` is cleared.

        Returns:
            ``True`` if the task was updated, ``False`` if it does not exist
            or is not currently ``preparing``.
        """
        finished_at = self._now_iso()
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?, done_items = ?, total_items = ?, current_item = NULL
                WHERE id = ? AND status = ?
                """,
                ("ready", finished_at, done_items, total_items, task_id, "preparing"),
            )
            return cursor.rowcount > 0

    def mark_failed(
        self,
        task_id: str,
        error_code: str,
        error_message: str,
        failed_item: str | None,
        done_bytes: int | None,
        total_bytes: int | None,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> None:
        """Set the task to ``failed`` with error details; stamp ``finished_at``.

        Unconditional (any current status, including ``cancelled``). Unlike
        ``mark_failed_if_not_cancelled``, ``current_item`` is left unchanged.
        """
        finished_at = self._now_iso()
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?, error_code = ?, error_message = ?, failed_item = ?, done_bytes = ?, total_bytes = ?, done_items = ?, total_items = ?
                WHERE id = ?
                """,
                (
                    "failed",
                    finished_at,
                    error_code,
                    error_message,
                    failed_item,
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    task_id,
                ),
            )

    def mark_failed_if_not_cancelled(
        self,
        task_id: str,
        error_code: str,
        error_message: str,
        failed_item: str | None,
        done_bytes: int | None,
        total_bytes: int | None,
        done_items: int | None = None,
        total_items: int | None = None,
    ) -> bool:
        """Set the task to ``failed`` unless it has been ``cancelled``.

        Records the error details, stamps ``finished_at`` and clears
        ``current_item``. Any status other than ``cancelled`` is overwritten.

        Returns:
            ``True`` if the task was updated, ``False`` if it does not exist
            or is ``cancelled``.
        """
        finished_at = self._now_iso()
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?, error_code = ?, error_message = ?, failed_item = ?, done_bytes = ?, total_bytes = ?, done_items = ?, total_items = ?, current_item = NULL
                WHERE id = ? AND status != ?
                """,
                (
                    "failed",
                    finished_at,
                    error_code,
                    error_message,
                    failed_item,
                    done_bytes,
                    total_bytes,
                    done_items,
                    total_items,
                    task_id,
                    "cancelled",
                ),
            )
            return cursor.rowcount > 0

    def mark_cancelled(self, task_id: str) -> bool:
        """Cancel a task that is still ``requested`` or ``preparing``.

        Stamps ``finished_at`` and clears ``current_item``.

        Returns:
            ``True`` if the task was cancelled, ``False`` if it does not exist
            or is in any other status.
        """
        finished_at = self._now_iso()
        with self._connection() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = ?, finished_at = ?, current_item = NULL
                WHERE id = ? AND status IN (?, ?)
                """,
                ("cancelled", finished_at, task_id, "requested", "preparing"),
            )
            return cursor.rowcount > 0

    def _ensure_schema(self) -> None:
        """Create tables and indexes if missing and migrate older ``tasks`` tables.

        Missing parent directories of the database file are created first.
        """
        parent = Path(self._db_path).parent
        # A bare file name has parent "."; there is nothing to create then.
        if str(parent) not in {"", "."}:
            parent.mkdir(parents=True, exist_ok=True)

        with self._connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS tasks (
                    id TEXT PRIMARY KEY,
                    operation TEXT NOT NULL,
                    status TEXT NOT NULL,
                    source TEXT NOT NULL,
                    destination TEXT NOT NULL,
                    done_bytes INTEGER NULL,
                    total_bytes INTEGER NULL,
                    done_items INTEGER NULL,
                    total_items INTEGER NULL,
                    current_item TEXT NULL,
                    failed_item TEXT NULL,
                    error_code TEXT NULL,
                    error_message TEXT NULL,
                    created_at TEXT NOT NULL,
                    started_at TEXT NULL,
                    finished_at TEXT NULL
                )
                """
            )
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS task_artifacts (
                    task_id TEXT PRIMARY KEY,
                    file_path TEXT NOT NULL,
                    file_name TEXT NOT NULL,
                    expires_at TEXT NOT NULL,
                    created_at TEXT NOT NULL
                )
                """
            )
            # Add columns a pre-existing ``tasks`` table lacks BEFORE indexing:
            # the created_at index below fails on a legacy table that does not
            # have that column yet.
            self._migrate_tasks_columns(conn)
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_tasks_created_at_desc
                ON tasks(created_at DESC)
                """
            )
            conn.execute(
                """
                CREATE INDEX IF NOT EXISTS idx_task_artifacts_expires_at
                ON task_artifacts(expires_at ASC)
                """
            )

    def upsert_artifact(self, *, task_id: str, file_path: str, file_name: str, expires_at: str) -> dict:
        """Insert or replace the artifact for *task_id* and return the stored row.

        On conflict, ``file_path``, ``file_name`` and ``expires_at`` are
        replaced while the original ``created_at`` is kept.
        """
        created_at = self._now_iso()
        with self._connection() as conn:
            conn.execute(
                """
                INSERT INTO task_artifacts (task_id, file_path, file_name, expires_at, created_at)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(task_id) DO UPDATE SET
                    file_path = excluded.file_path,
                    file_name = excluded.file_name,
                    expires_at = excluded.expires_at
                """,
                (task_id, file_path, file_name, expires_at, created_at),
            )
            row = conn.execute("SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)).fetchone()
        return self._artifact_to_dict(row)

    def get_artifact(self, task_id: str) -> dict | None:
        """Return the artifact for *task_id* as a dict, or ``None`` if absent."""
        with self._connection() as conn:
            row = conn.execute("SELECT * FROM task_artifacts WHERE task_id = ?", (task_id,)).fetchone()
        return self._artifact_to_dict(row) if row else None

    def list_artifacts(self) -> list[dict]:
        """Return all artifacts, oldest ``created_at`` first."""
        with self._connection() as conn:
            rows = conn.execute("SELECT * FROM task_artifacts ORDER BY created_at ASC").fetchall()
        return [self._artifact_to_dict(row) for row in rows]

    def delete_artifact(self, task_id: str) -> None:
        """Delete the artifact row for *task_id*; a no-op if there is none."""
        with self._connection() as conn:
            conn.execute("DELETE FROM task_artifacts WHERE task_id = ?", (task_id,))

    def reconcile_incomplete_tasks(
        self,
        *,
        error_code: str = "task_interrupted",
        error_message: str = "Task was interrupted before completion",
    ) -> list[str]:
        """Fail every task still in a ``NON_TERMINAL_STATUSES`` status.

        Matching tasks get status ``failed``, ``finished_at`` = now, the given
        error code/message and a cleared ``current_item``. Their artifact rows
        are deleted.

        Returns:
            Ids of the tasks that were failed (empty if none).
        """
        finished_at = self._now_iso()
        status_placeholders = ", ".join("?" for _ in NON_TERMINAL_STATUSES)
        with self._connection() as conn:
            # Take the write lock before reading, so the ids returned are
            # exactly the rows updated below even with concurrent writers.
            conn.execute("BEGIN IMMEDIATE")
            rows = conn.execute(
                f"""
                SELECT id
                FROM tasks
                WHERE status IN ({status_placeholders})
                """,
                NON_TERMINAL_STATUSES,
            ).fetchall()
            task_ids = [row["id"] for row in rows]
            if not task_ids:
                return []

            # Filter by status (a fixed handful of parameters) instead of
            # binding one "?" per task id, which can exceed SQLite's
            # host-parameter limit when many tasks are pending. Artifacts go
            # first, while the tasks still carry their non-terminal status.
            conn.execute(
                f"""
                DELETE FROM task_artifacts
                WHERE task_id IN (
                    SELECT id FROM tasks WHERE status IN ({status_placeholders})
                )
                """,
                NON_TERMINAL_STATUSES,
            )
            conn.execute(
                f"""
                UPDATE tasks
                SET status = ?, finished_at = ?, error_code = ?, error_message = ?, current_item = NULL
                WHERE status IN ({status_placeholders})
                """,
                ("failed", finished_at, error_code, error_message, *NON_TERMINAL_STATUSES),
            )
        return task_ids

    def _migrate_tasks_columns(self, conn: sqlite3.Connection) -> None:
        """Add every ``TASK_MIGRATION_COLUMNS`` column the ``tasks`` table lacks."""
        rows = conn.execute("PRAGMA table_info(tasks)").fetchall()
        existing_columns = {row["name"] for row in rows}
        for column, ddl in TASK_MIGRATION_COLUMNS.items():
            if column in existing_columns:
                continue
            # Names and DDL come from the module constant, not from callers,
            # so interpolating them into the statement is safe.
            conn.execute(f"ALTER TABLE tasks ADD COLUMN {column} {ddl}")

    @classmethod
    def _insert_task_row(cls, conn: sqlite3.Connection, task: dict) -> None:
        """INSERT one ``tasks`` row; columns absent from *task* are stored as NULL."""
        columns = ", ".join(cls._TASK_COLUMNS)
        placeholders = ", ".join(["?"] * len(cls._TASK_COLUMNS))
        conn.execute(
            f"INSERT INTO tasks ({columns}) VALUES ({placeholders})",
            tuple(task.get(column) for column in cls._TASK_COLUMNS),
        )

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows support access by column name."""
        conn = sqlite3.connect(self._db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self):
        """Yield a fresh connection: commit on success, roll back on error, always close."""
        conn = self._connect()
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @staticmethod
    def _to_dict(row: sqlite3.Row) -> dict:
        """Convert a ``tasks`` row into a plain dict keyed by column name."""
        return {column: row[column] for column in TaskRepository._TASK_COLUMNS}

    @staticmethod
    def _artifact_to_dict(row: sqlite3.Row) -> dict:
        """Convert a ``task_artifacts`` row into a plain dict keyed by column name."""
        return {column: row[column] for column in TaskRepository._ARTIFACT_COLUMNS}

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time in ISO-8601 form with a ``Z`` suffix."""
        return datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z")
|