#!/usr/bin/env python3
"""GSM8K 50-question subset benchmark (seed=42).

Samples a fixed, seeded subset of the GSM8K test split, sends each question
to an OpenAI-compatible chat endpoint, scores the integer answer, and prints
one JSON record per question on stdout.  Raw responses are stored under
``results/MODEL/gsm8k/`` so an interrupted run can be resumed.
"""

import json
import os
import random
import re
import sys
import time
from pathlib import Path

from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm

ENDPOINT = os.environ.get("LLAMA_SWAP_URL", "http://100.101.41.16:8401/v1")
RESULTS_DIR = Path(__file__).parent / "results"
MAX_TOKENS = 512
SEED = 42
TEMPERATURE = 0
N_QUESTIONS = 50
MAX_ATTEMPTS = 2  # one try plus one retry
RETRY_DELAY_S = 5
DEFAULT_MODEL = "qwen3.6-35b-a3b-mxfp4"

# A number with optional sign, thousands separators and decimal part.  The
# decimal part needs at least one digit, so a sentence-ending period is never
# swallowed into the token.
_NUMBER = r"-?\d[\d,]*(?:\.\d+)?"
_NUMBER_RE = re.compile(_NUMBER)
# Tolerates currency/markdown decoration between the tag and the number,
# e.g. "ANSWER: $18" or "ANSWER: **18**".
_ANSWER_RE = re.compile(r"ANSWER:[\s*$]*(" + _NUMBER + ")", re.IGNORECASE)
_GOLD_RE = re.compile(r"####\s*([0-9,.-]+)")
_STRICT_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def _to_int(token: str) -> int | None:
    """Convert a numeric token to an int, or return None if it is not integral.

    Thousands separators and a trailing period are ignored, and an all-zero
    fraction is accepted ("1,234.00" -> 1234).  A real fraction ("2.5") or a
    malformed token yields None.
    """
    cleaned = token.replace(",", "").rstrip(".")
    if not _STRICT_NUMBER_RE.fullmatch(cleaned):
        return None
    whole, _, fraction = cleaned.partition(".")
    if fraction.strip("0"):
        return None
    return int(whole)


def load_questions() -> list[dict]:
    """Return the seeded N_QUESTIONS-question subset of the GSM8K test split.

    Returns:
        Dicts with keys "id" ("gsm8k_" + dataset index), "question" and
        "expected" (the gold answer as an int).

    Raises:
        ValueError: if a gold answer has no parseable "#### N" marker.
    """
    rng = random.Random(SEED)
    # NOTE(review): trust_remote_code=True was dropped.  openai/gsm8k appears
    # to be served as plain data files, so no repository code needs to run,
    # and recent `datasets` releases deprecate the flag - confirm against the
    # pinned `datasets` version.
    ds = load_dataset("openai/gsm8k", "main", split="test")

    # Shuffle-then-slice (not rng.sample) so the subset stays identical to
    # earlier runs made with the same seed.
    indices = list(range(len(ds)))
    rng.shuffle(indices)

    questions = []
    for idx in indices[:N_QUESTIONS]:
        row = ds[idx]
        # GSM8K answer format: "#### N" at the end of the worked solution.
        match = _GOLD_RE.search(row["answer"])
        expected = _to_int(match.group(1)) if match else None
        if expected is None:
            # Scoring against a made-up gold value would silently corrupt
            # the accuracy figure, so fail loudly instead.
            raise ValueError(f"gsm8k_{idx}: cannot parse gold answer")
        questions.append({
            "id": f"gsm8k_{idx}",
            "question": row["question"],
            "expected": expected,
        })
    return questions


def format_prompt(q: dict) -> str:
    """Build the user prompt for one question dict (reads q["question"])."""
    return (
        "Solve this problem step by step, then on the final line write "
        "'ANSWER: '.\n\n" + q["question"]
    )


def parse_answer(text: str) -> int | None:
    """Extract the model's final integer answer from its response text.

    The last "ANSWER: N" occurrence wins; if there is none, the last number
    anywhere in the text is used.  Returns None when no number is found or
    the chosen number is not an integer.
    """
    for pattern in (_ANSWER_RE, _NUMBER_RE):
        tokens = pattern.findall(text)
        if tokens:
            return _to_int(tokens[-1])
    return None


def _first_choice(payload: object) -> dict | None:
    """Return payload["choices"][0] if the payload has that shape, else None."""
    if not isinstance(payload, dict):
        return None
    choices = payload.get("choices")
    if isinstance(choices, list) and choices and isinstance(choices[0], dict):
        return choices[0]
    return None


def _response_text(payload: object) -> str:
    """Return the assistant text of a chat-completion dict ("" if absent)."""
    choice = _first_choice(payload)
    message = choice.get("message") if choice else None
    if not isinstance(message, dict):
        return ""
    return message.get("content") or message.get("reasoning_content") or ""


def _load_cached(path: Path) -> dict | None:
    """Return a previously stored completion, or None if it must be re-run.

    Missing, unreadable or corrupt files are ignored, and so are stored error
    payloads (no "choices"), so a transient failure is retried on resume
    instead of being counted as a wrong answer forever.
    """
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # ValueError covers JSON and decode errors
        return None
    return payload if _first_choice(payload) is not None else None


def _query_model(client: OpenAI, model: str, prompt: str) -> tuple[dict, float]:
    """Send one prompt, retrying once; return (response dict, latency in ms).

    The latency covers only the final attempt, so a retry's back-off sleep
    does not inflate it.  If every attempt fails, the response dict is
    {"error": message}.
    """
    for attempt in range(1, MAX_ATTEMPTS + 1):
        start = time.perf_counter()
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=MAX_TOKENS,
                temperature=TEMPERATURE,
                seed=SEED,
            )
            payload = resp.model_dump()
        except Exception as exc:  # boundary: any client/transport failure
            elapsed_ms = (time.perf_counter() - start) * 1000
            if attempt == MAX_ATTEMPTS:
                print(f" [{model}] request failed: {exc}", file=sys.stderr)
                return {"error": str(exc)}, elapsed_ms
            time.sleep(RETRY_DELAY_S)
        else:
            return payload, (time.perf_counter() - start) * 1000
    return {"error": "no attempts made"}, 0.0


def _build_record(model: str, q: dict, raw: str, latency_ms: float) -> dict:
    """Score one response and return its result record."""
    expected = q["expected"]
    parsed = parse_answer(raw)
    return {
        "model": model,
        "benchmark": "gsm8k",
        "question_id": q["id"],
        "correct": parsed is not None and parsed == expected,
        "raw_answer": raw[:200],
        "parsed_answer": str(parsed) if parsed is not None else "",
        "expected": str(expected),
        "latency_ms": latency_ms,
    }


def _percent(correct: int, total: int) -> float:
    """Return correct/total as a percentage, or 0.0 when total is 0."""
    return correct / total * 100 if total else 0.0


def run_gsm8k(model: str, client: OpenAI, questions: list[dict]) -> list[dict]:
    """Run (or resume) the benchmark for one model.

    Args:
        model: Model name sent to the endpoint; also the results sub-directory.
        client: OpenAI-compatible client.
        questions: Question dicts as produced by load_questions().

    Returns:
        One result record per question, in input order.  Records served from
        the on-disk cache have latency_ms == 0.
    """
    model_dir = RESULTS_DIR / model / "gsm8k"
    model_dir.mkdir(parents=True, exist_ok=True)

    results: list[dict] = []
    correct = 0
    skipped = 0

    progress = tqdm(questions, desc=f" GSM8K {model}", file=sys.stderr)
    for done, q in enumerate(progress, start=1):
        out_path = model_dir / f"{q['id']}.json"

        cached = _load_cached(out_path)
        if cached is not None:
            record = _build_record(model, q, _response_text(cached), 0)
            results.append(record)
            correct += record["correct"]
            skipped += 1
            continue

        payload, latency_ms = _query_model(client, model, format_prompt(q))
        # Error payloads are stored too (useful for debugging); _load_cached
        # skips them on the next run.
        out_path.write_text(
            json.dumps(payload, indent=2, default=str), encoding="utf-8"
        )
        record = _build_record(
            model, q, _response_text(payload), round(latency_ms, 1)
        )
        results.append(record)
        correct += record["correct"]

        if done % 10 == 0:
            total = len(results)
            print(f" [{model}] GSM8K {done}/{len(questions)} — {correct}/{total} ({_percent(correct, total):.0f}%)", file=sys.stderr)

    total = len(results)
    if skipped:
        print(f" [{model}] GSM8K resumed: {skipped} cached, {total-skipped} new", file=sys.stderr)
    print(f" [{model}] GSM8K FINAL: {correct}/{total} ({_percent(correct, total):.1f}%)", file=sys.stderr)
    return results


def main() -> None:
    """Run the benchmark for the model named in argv[1] (or the default)."""
    model = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_MODEL
    client = OpenAI(base_url=ENDPOINT, api_key="dummy")
    questions = load_questions()
    for record in run_gsm8k(model, client, questions):
        print(json.dumps(record))


if __name__ == "__main__":
    main()