Files
sortof/api/expansion.py

141 lines
4.9 KiB
Python

"""Background async task: take a freshly-created sort_jobs row in 'expanding'
phase, resolve its collection_ids via Steam, populate wsids[], advance phase
to 'queued' (and drop wsids into download_jobs as needed)."""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict, List, Tuple
import asyncpg
import httpx
from jobs import update_phase
from steam import fetch_collection_details
log = logging.getLogger("sortof.expansion")
async def _resolve_collections(
    conn: asyncpg.Connection,
    http: httpx.AsyncClient,
    collection_ids: List[str],
) -> Tuple[Dict[str, List[str]], List[str]]:
    """Resolve collection ids to their child workshop ids via cache, then Steam.

    Collections that have a ``collections`` row fetched within the last
    6 hours are served from that row. The remaining ids are requested from
    ``fetch_collection_details`` in one batched call (retried once, after a
    2 s pause, on ``httpx.HTTPError``), and every successfully fetched
    collection is upserted back into the ``collections`` table.

    Args:
        conn: Connection used for the cache read and the upserts.
        http: HTTP client handed through to ``fetch_collection_details``.
        collection_ids: Collection ids to resolve. Repeated ids are tolerated
            and handled once.

    Returns:
        ``(resolved, unresolvable)``. ``resolved`` maps collection_id ->
        [child_wsids]. ``unresolvable`` lists, once each and in first-seen
        input order, the collection ids that could not be fetched (both
        attempts raised, the id was absent from the response, or its
        ``result`` field was not 1).
    """
    if not collection_ids:
        return ({}, [])

    # Cache lookup (TTL = 6h via last_fetched_at).
    cache_rows = await conn.fetch(
        """
        SELECT collection_id, child_workshop_ids
        FROM collections
        WHERE collection_id = ANY($1::text[])
          AND last_fetched_at > now() - interval '6 hours'
        """,
        collection_ids,
    )
    resolved: Dict[str, List[str]] = {
        r["collection_id"]: list(r["child_workshop_ids"])
        for r in cache_rows
    }

    # dict.fromkeys dedupes while keeping first-seen order, so an id repeated
    # in the input is requested, upserted and reported at most once.
    miss = list(dict.fromkeys(cid for cid in collection_ids if cid not in resolved))
    unresolvable: List[str] = []
    if not miss:
        return (resolved, unresolvable)

    # Spec §5.4: 1 retry with 2s backoff on HTTPError. If both attempts
    # raise, api_out stays {} and the per-cid pass below uniformly marks
    # every miss as unresolvable (rec is None branch).
    api_out: Dict[str, Any] = {}
    for attempt in (1, 2):
        try:
            api_out = await fetch_collection_details(http, miss)
            break
        except httpx.HTTPError as e:
            log.warning("GetCollectionDetails attempt %d failed: %s", attempt, e)
            if attempt == 1:
                await asyncio.sleep(2.0)

    for cid in miss:
        rec = api_out.get(cid)
        if rec is None or rec.get("result") != 1:
            unresolvable.append(cid)
            continue
        # A missing or null "children" entry is treated as an empty collection.
        children = rec.get("children") or []
        resolved[cid] = list(children)
        await conn.execute(
            """
            INSERT INTO collections (collection_id, child_workshop_ids, last_fetched_at)
            VALUES ($1, $2, now())
            ON CONFLICT (collection_id) DO UPDATE
            SET child_workshop_ids = EXCLUDED.child_workshop_ids,
                last_fetched_at = now()
            """,
            cid, children,
        )
    return (resolved, unresolvable)
async def run_expansion(
    pool: asyncpg.Pool,
    http: httpx.AsyncClient,
    job_id: str,
    bare_wsids: List[str],
    collection_ids: List[str],
) -> None:
    """Run the expansion step for one sort job and persist the outcome.

    Resolves ``collection_ids`` to child workshop ids and merges them with
    ``bare_wsids``: collection children first (in ``collection_ids`` order),
    then bare ids, keeping the first occurrence of each id and dropping
    empty ids from both sources. The job is then moved to ``queued`` with
    the merged list, or to ``failed`` when the merged list is empty.
    Collections that could not be fetched are recorded, once each, as
    ``collection-partial`` warnings in the seeded ``result_json``.

    Any ``Exception`` is logged and converted into a best-effort ``failed``
    update on a freshly acquired connection, so no ``Exception`` escapes.
    ``asyncio.CancelledError`` derives from ``BaseException`` and is not
    intercepted here; a cancelled task leaves the job row as it was.

    Args:
        pool: Connection pool; one connection is held for the whole
            resolve-and-update sequence.
        http: HTTP client passed through to ``_resolve_collections``.
        job_id: Identifier of the sort job being expanded.
        bare_wsids: Workshop ids supplied directly by the caller.
        collection_ids: Collection ids to expand into workshop ids.
    """
    try:
        async with pool.acquire() as conn:
            resolved, unresolvable = await _resolve_collections(conn, http, collection_ids)

            # Compose wsids: collections (in input order) + bare wsids, deduped.
            candidates = [w for cid in collection_ids for w in resolved.get(cid, [])]
            candidates.extend(bare_wsids)
            # dict.fromkeys keeps first-seen order; empty ids are dropped from
            # both sources so a blank bare id can never be queued.
            wsids: List[str] = list(dict.fromkeys(w for w in candidates if w))

            if not wsids:
                # All collections unresolvable AND no bare wsids. Job dies.
                await update_phase(
                    conn, job_id, "failed",
                    failure_reason="all input collections unresolvable",
                )
                log.info("expansion %s: failed - all collections unresolvable", job_id)
                return

            # One warning per distinct collection, even if an id was repeated.
            partial_warnings = [
                {
                    "tag": "collection-partial",
                    "level": "warning",
                    "msg": f"collection {cid} could not be fetched",
                }
                for cid in dict.fromkeys(unresolvable)
            ]
            seed_result = {"WARNINGS": partial_warnings} if partial_warnings else None
            await update_phase(
                conn, job_id, "queued",
                wsids=wsids,
                result_json=seed_result,
            )
            log.info(
                "expansion %s: queued (wsids=%d unresolvable=%d)",
                job_id, len(wsids), len(partial_warnings),
            )
    except Exception:
        log.exception("expansion %s: crashed", job_id)
        # The connection from the try block has already been released here,
        # so the failure is recorded on a new one. This is best-effort: a
        # second failure is only logged.
        try:
            async with pool.acquire() as conn:
                await update_phase(conn, job_id, "failed", failure_reason="expansion crashed")
        except Exception:
            log.exception("expansion %s: cleanup failed", job_id)