"""sort_jobs persistence + phase derivation. Phase is *derived* on every GET (Spec B+F §4): never stored as the source of truth except for terminal states. The function `derive_phase` reads live counts from download_jobs and decides expanding/queued/draining/done. This makes the system restart-resilient by construction - there is no event log to replay. """ from __future__ import annotations import json from typing import Any, Dict, List, Optional from uuid import UUID import asyncpg # ── CRUD ──────────────────────────────────────────────────────────────────── async def create_job( conn: asyncpg.Connection, *, input_raw: str, collection_ids: List[str], wsids: Optional[List[str]], rules_raw: Optional[str], initial_phase: str, pz_build: Optional[str] = None, ) -> str: """Insert a sort_jobs row and return the job_id (UUID as string). initial_phase: 'expanding' if collections still need resolving, 'queued' if wsids are already resolved at submit time. pz_build: 'B41' / 'B42' captured at submit so the polling-path result regen can emit build-mismatch warnings against the user's chosen build. """ row = await conn.fetchrow( """ INSERT INTO sort_jobs (phase, input_raw, collection_ids, wsids, rules_raw, pz_build) VALUES ($1, $2, $3, $4, $5, $6) RETURNING job_id """, initial_phase, input_raw, collection_ids, wsids, rules_raw, pz_build, ) return str(row["job_id"]) async def get_job_row(conn: asyncpg.Connection, job_id: str) -> Optional[Dict[str, Any]]: """Fetch a sort_jobs row by id. Returns None if not found. job_id may be either a string UUID or asyncpg-native UUID. """ try: uid = UUID(job_id) if isinstance(job_id, str) else job_id except ValueError: return None row = await conn.fetchrow( "SELECT * FROM sort_jobs WHERE job_id = $1", uid, ) if row is None: return None out = dict(row) # asyncpg returns jsonb as raw text by default (no codec registered # in db.py). Decode result_json so callers always receive a dict. rj = out.get("result_json") if isinstance(rj, str): out["result_json"] = json.loads(rj) return out async def update_phase( conn: asyncpg.Connection, job_id: str, phase: str, *, wsids: Optional[List[str]] = None, result_json: Optional[Dict[str, Any]] = None, failure_reason: Optional[str] = None, ) -> None: """Advance a job's phase. wsids/result_json/failure_reason are optional column updates that pair with phase transitions.""" # Accept str OR UUID; mirrors get_job_row's input shape. uid = UUID(job_id) if isinstance(job_id, str) else job_id sets = ["phase = $2", "phase_started_at = now()"] # Convention: $1=job_id, $2=phase; optional fields start at $3. params: List[Any] = [uid, phase] idx = 3 if wsids is not None: sets.append(f"wsids = ${idx}::text[]") params.append(wsids) idx += 1 if result_json is not None: sets.append(f"result_json = ${idx}::jsonb") params.append(json.dumps(result_json)) idx += 1 if failure_reason is not None: sets.append(f"failure_reason = ${idx}") params.append(failure_reason) idx += 1 await conn.execute( f"UPDATE sort_jobs SET {', '.join(sets)} WHERE job_id = $1", *params, ) # ── live counts (Spec B+F §6) ─────────────────────────────────────────────── async def compute_counts(conn: asyncpg.Connection, wsids: List[str]) -> Dict[str, int]: """Compute live cached/queued/draining/terminal_failed counts. Empty wsids → all zeros. terminal_failed: wsids whose LATEST download_jobs row has status='failed'. These will not appear in mod_parsed and are not coming back; without this, derive_phase would loop forever when a job's wsids include non-mods or permanently-broken downloads. """ if not wsids: return {"cached": 0, "queued": 0, "draining": 0, "terminal_failed": 0} rows = await conn.fetch( """ SELECT (SELECT COUNT(DISTINCT mp.workshop_id) FROM mod_parsed mp JOIN workshop_meta wm ON wm.workshop_id = mp.workshop_id WHERE mp.workshop_id = ANY($1::text[]) AND mp.parsed_at_time_updated = wm.time_updated) AS cached, (SELECT COUNT(DISTINCT workshop_id) FROM download_jobs WHERE workshop_id = ANY($1::text[]) AND status = 'queued') AS queued, (SELECT COUNT(DISTINCT workshop_id) FROM download_jobs WHERE workshop_id = ANY($1::text[]) AND status = 'downloading') AS draining, (SELECT COUNT(*) FROM ( SELECT DISTINCT ON (workshop_id) workshop_id, status FROM download_jobs WHERE workshop_id = ANY($1::text[]) ORDER BY workshop_id, updated_at DESC ) latest WHERE status = 'failed') AS terminal_failed """, wsids, ) r = rows[0] return { "cached": int(r["cached"]), "queued": int(r["queued"]), "draining": int(r["draining"]), "terminal_failed": int(r["terminal_failed"]), } # ── phase derivation (Spec B+F §4) ────────────────────────────────────────── def derive_phase( stored_phase: str, wsids: Optional[List[str]], counts: Dict[str, int], ) -> str: """Decide the live phase from the row's stored phase + current counts. Terminal phases (done/failed) are never demoted. Non-terminal phases are recomputed from current state. """ if stored_phase in ("done", "failed"): return stored_phase if wsids is None: return "expanding" if counts["draining"] > 0: return "draining" if counts["queued"] > 0: return "queued" # Terminal: every wsid is either cached or permanently-failed. settled = counts["cached"] + counts.get("terminal_failed", 0) if settled >= len(wsids): return "done" # Transient gap: a row just left 'queued' and hasn't shown up in # mod_parsed yet. Most likely just-failed and not yet re-queued. return "queued" # ── stale-expansion sweep (Spec B+F §9) ───────────────────────────────────── STALE_EXPANSION_SQL = """ UPDATE sort_jobs SET phase = 'failed', failure_reason = 'expansion timed out', updated_at = now() WHERE phase = 'expanding' AND phase_started_at < now() - interval '10 minutes' RETURNING job_id; """ async def sweep_stale_expansions(conn: asyncpg.Connection) -> int: """Run on uvicorn lifespan startup. Returns the number of jobs reaped.""" rows = await conn.fetch(STALE_EXPANSION_SQL) return len(rows)