Files
sortof/api/jobs.py

200 lines
7.1 KiB
Python

"""sort_jobs persistence + phase derivation.
Phase is *derived* on every GET (Spec B+F §4): never stored as the source
of truth except for terminal states. `compute_counts` reads live counts from
mod_parsed/workshop_meta and download_jobs, and `derive_phase` turns those
counts into expanding/queued/draining/done.
This makes the system restart-resilient by construction - there is no
event log to replay.
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from uuid import UUID
import asyncpg
# ── CRUD ────────────────────────────────────────────────────────────────────
async def create_job(
    conn: asyncpg.Connection,
    *,
    input_raw: str,
    collection_ids: List[str],
    wsids: Optional[List[str]],
    rules_raw: Optional[str],
    initial_phase: str,
    pz_build: Optional[str] = None,
) -> str:
    """Persist a new sort_jobs row and hand back its job_id as a string.

    Args:
        conn: connection the INSERT is executed on.
        input_raw: value stored in the ``input_raw`` column.
        collection_ids: value stored in the ``collection_ids`` column.
        wsids: value stored in the ``wsids`` column; may be None.
        rules_raw: value stored in the ``rules_raw`` column; may be None.
        initial_phase: value stored in ``phase``. 'expanding' if collections
            still need resolving, 'queued' if wsids are already resolved at
            submit time.
        pz_build: 'B41' / 'B42' captured at submit so the polling-path
            result regen can emit build-mismatch warnings against the
            user's chosen build.

    Returns:
        The ``job_id`` returned by the INSERT, converted with ``str()``.
    """
    insert_sql = """
INSERT INTO sort_jobs (phase, input_raw, collection_ids, wsids, rules_raw, pz_build)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING job_id
"""
    # Order matters: these bind positionally to $1..$6 in the statement.
    bound_values = (
        initial_phase,
        input_raw,
        collection_ids,
        wsids,
        rules_raw,
        pz_build,
    )
    inserted = await conn.fetchrow(insert_sql, *bound_values)
    return str(inserted["job_id"])
async def get_job_row(conn: asyncpg.Connection, job_id: str) -> Optional[Dict[str, Any]]:
    """Look up a single sort_jobs row by id.

    Args:
        conn: connection the SELECT is executed on.
        job_id: either a UUID string or an already-constructed UUID value.

    Returns:
        The row as a plain dict, or None when ``job_id`` is a string that
        does not parse as a UUID or when no row matches. A ``result_json``
        value that comes back as a string is decoded with ``json.loads``.
    """
    if isinstance(job_id, str):
        try:
            key = UUID(job_id)
        except ValueError:
            # A malformed id cannot match any row; report "not found".
            return None
    else:
        key = job_id

    record = await conn.fetchrow("SELECT * FROM sort_jobs WHERE job_id = $1", key)
    if record is None:
        return None

    job = dict(record)
    # jsonb comes back as raw text when no codec is registered for it, so
    # decode here to let callers always work with parsed JSON.
    payload = job.get("result_json")
    if isinstance(payload, str):
        job["result_json"] = json.loads(payload)
    return job
async def update_phase(
    conn: asyncpg.Connection,
    job_id: str,
    phase: str,
    *,
    wsids: Optional[List[str]] = None,
    result_json: Optional[Dict[str, Any]] = None,
    failure_reason: Optional[str] = None,
) -> None:
    """Move a job to ``phase`` and stamp ``phase_started_at`` with now().

    Args:
        conn: connection the UPDATE is executed on.
        job_id: UUID string or UUID value identifying the row.
        phase: new value for the ``phase`` column.
        wsids: when not None, also written to ``wsids`` (cast to text[]).
        result_json: when not None, serialized with ``json.dumps`` and
            written to ``result_json`` (cast to jsonb).
        failure_reason: when not None, also written to ``failure_reason``.

    Raises:
        ValueError: if ``job_id`` is a string that is not a valid UUID.
    """
    # Same str-or-UUID input shape that get_job_row accepts.
    key = UUID(job_id) if isinstance(job_id, str) else job_id

    # (SET-clause template, bound value) per optional column; a None value
    # means "leave this column alone".
    optional_columns = [
        ("wsids = ${}::text[]", wsids),
        (
            "result_json = ${}::jsonb",
            None if result_json is None else json.dumps(result_json),
        ),
        ("failure_reason = ${}", failure_reason),
    ]

    # $1 is always the job id and $2 the phase; optional values follow.
    assignments = ["phase = $2", "phase_started_at = now()"]
    args: List[Any] = [key, phase]
    for template, value in optional_columns:
        if value is None:
            continue
        args.append(value)
        # The placeholder number is the 1-based position of the value
        # that was just appended.
        assignments.append(template.format(len(args)))

    statement = f"UPDATE sort_jobs SET {', '.join(assignments)} WHERE job_id = $1"
    await conn.execute(statement, *args)
# ── live counts (Spec B+F §6) ───────────────────────────────────────────────
async def compute_counts(conn: asyncpg.Connection, wsids: List[str]) -> Dict[str, int]:
    """Return live cached/queued/draining/terminal_failed counts for ``wsids``.

    An empty (or None) ``wsids`` short-circuits to all zeros without
    touching the database.

    Each count is the number of distinct workshop ids among ``wsids`` that:
      * cached: have a mod_parsed row whose ``parsed_at_time_updated``
        equals the ``time_updated`` of the matching workshop_meta row;
      * queued: have at least one download_jobs row with status 'queued';
      * draining: have at least one download_jobs row with status
        'downloading';
      * terminal_failed: have status 'failed' on their most recent
        download_jobs row (ordered by ``updated_at``). Without this count,
        derive_phase would never settle for jobs whose wsids include
        non-mods or permanently-broken downloads.
    """
    fields = ("cached", "queued", "draining", "terminal_failed")
    if not wsids:
        return dict.fromkeys(fields, 0)

    # NOTE(review): nothing in this query stops one id from counting as both
    # cached and terminal_failed (e.g. a failed re-download of an already
    # parsed mod) — confirm whether that can happen in practice.
    counts_sql = """
SELECT
(SELECT COUNT(DISTINCT mp.workshop_id)
FROM mod_parsed mp
JOIN workshop_meta wm ON wm.workshop_id = mp.workshop_id
WHERE mp.workshop_id = ANY($1::text[])
AND mp.parsed_at_time_updated = wm.time_updated) AS cached,
(SELECT COUNT(DISTINCT workshop_id)
FROM download_jobs
WHERE workshop_id = ANY($1::text[]) AND status = 'queued') AS queued,
(SELECT COUNT(DISTINCT workshop_id)
FROM download_jobs
WHERE workshop_id = ANY($1::text[]) AND status = 'downloading') AS draining,
(SELECT COUNT(*) FROM (
SELECT DISTINCT ON (workshop_id) workshop_id, status
FROM download_jobs
WHERE workshop_id = ANY($1::text[])
ORDER BY workshop_id, updated_at DESC
) latest WHERE status = 'failed') AS terminal_failed
"""
    records = await conn.fetch(counts_sql, wsids)
    # The statement is a single scalar-subquery SELECT, so it yields one row.
    summary = records[0]
    return {name: int(summary[name]) for name in fields}
# ── phase derivation (Spec B+F §4) ──────────────────────────────────────────
def derive_phase(
    stored_phase: str,
    wsids: Optional[List[str]],
    counts: Dict[str, int],
) -> str:
    """Decide the live phase from the row's stored phase + current counts.

    Terminal phases (done/failed) are never demoted. Non-terminal phases
    are recomputed from current state.

    Args:
        stored_phase: phase currently persisted on the sort_jobs row.
        wsids: resolved workshop ids, or None while they are still unknown.
        counts: mapping with 'cached', 'queued' and 'draining' keys and an
            optional 'terminal_failed' key (treated as 0 when absent).

    Returns:
        The stored phase unchanged when it is 'done' or 'failed'; otherwise
        one of 'expanding', 'draining', 'queued' or 'done'.
    """
    if stored_phase in ("done", "failed"):
        return stored_phase
    if wsids is None:
        return "expanding"
    if counts["draining"] > 0:
        return "draining"
    if counts["queued"] > 0:
        return "queued"
    # Terminal: every wsid is either cached or permanently-failed.
    # compute_counts reports counts of DISTINCT workshop ids, so the target
    # must be the number of distinct wsids too; comparing against the raw
    # list length would leave a job with a duplicated id stuck forever.
    settled = counts["cached"] + counts.get("terminal_failed", 0)
    if settled >= len(set(wsids)):
        return "done"
    # Transient gap: a row just left 'queued' and hasn't shown up in
    # mod_parsed yet. Most likely just-failed and not yet re-queued.
    return "queued"
# ── stale-expansion sweep (Spec B+F §9) ─────────────────────────────────────
# Marks every sort_jobs row that is still in phase 'expanding' and whose
# phase_started_at is more than 10 minutes old as 'failed' with reason
# 'expansion timed out', bumps updated_at, and returns the affected job_ids.
STALE_EXPANSION_SQL = """
UPDATE sort_jobs
SET phase = 'failed',
failure_reason = 'expansion timed out',
updated_at = now()
WHERE phase = 'expanding'
AND phase_started_at < now() - interval '10 minutes'
RETURNING job_id;
"""
async def sweep_stale_expansions(conn: asyncpg.Connection) -> int:
    """Fail jobs stuck in 'expanding' and report how many were reaped.

    Meant to be run on uvicorn lifespan startup. Executes
    ``STALE_EXPANSION_SQL`` on ``conn`` and returns the number of rows
    that statement returned (one per job it updated).
    """
    reaped = await conn.fetch(STALE_EXPANSION_SQL)
    return len(reaped)