From 689a9a470e5d96a853731b2accd463475e5e9a19 Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Fri, 29 May 2026 23:09:47 +0200 Subject: [PATCH] server-bench : add speed-bench for speculative decoding benchmarking (#23869) * spec: add speed-bench support for benchmarking * speed-bench : add trailing newline to requirements.txt * speed-bench : bump datasets to 4.8.0 to fix ty check * server-bench : remove now-unused type: ignore after datasets bump --- docs/speculative.md | 5 + requirements/requirements-server-bench.txt | 2 +- scripts/server-bench.py | 2 +- tools/server/bench/speed-bench/README.md | 117 +++++ .../server/bench/speed-bench/requirements.txt | 3 + tools/server/bench/speed-bench/speed_bench.py | 432 ++++++++++++++++++ .../bench/speed-bench/speed_bench_compare.py | 84 ++++ 7 files changed, 643 insertions(+), 2 deletions(-) create mode 100644 tools/server/bench/speed-bench/README.md create mode 100644 tools/server/bench/speed-bench/requirements.txt create mode 100644 tools/server/bench/speed-bench/speed_bench.py create mode 100644 tools/server/bench/speed-bench/speed_bench_compare.py diff --git a/docs/speculative.md b/docs/speculative.md index 041ff58038..43d1818589 100644 --- a/docs/speculative.md +++ b/docs/speculative.md @@ -323,3 +323,8 @@ statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts - `#gen tokens`: number of tokens generated by this implementation (including rejected tokens) - `#acc tokens`: number of tokens accepted by the main model - `dur(b,g,a): durations of begin (new prompt), generation and accumulation (process acceptance). + +## Benchmarking + +To measure the end-to-end effect of speculative decoding (throughput, latency, and draft acceptance) across diverse prompts, see the SPEED-Bench client in [tools/server/bench/speed-bench](../tools/server/bench/speed-bench/README.md). +It runs against a running `llama-server` and can compare a baseline run against a speculative-decoding run. 
diff --git a/requirements/requirements-server-bench.txt b/requirements/requirements-server-bench.txt index ea5849fa10..fb3b0d2664 100644 --- a/requirements/requirements-server-bench.txt +++ b/requirements/requirements-server-bench.txt @@ -1,4 +1,4 @@ -datasets~=3.2.0 +datasets~=4.8.0 matplotlib~=3.10.0 numpy~=1.26.4 requests~=2.32.3 diff --git a/scripts/server-bench.py b/scripts/server-bench.py index 1b557a495a..2eabb3bce8 100755 --- a/scripts/server-bench.py +++ b/scripts/server-bench.py @@ -25,7 +25,7 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]: ret = [] if dataset_name.lower() == "mmlu": logger.info("Loading MMLU dataset...") - ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore + ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] else: return None if n_prompts >= 0: diff --git a/tools/server/bench/speed-bench/README.md b/tools/server/bench/speed-bench/README.md new file mode 100644 index 0000000000..8d3fcd804c --- /dev/null +++ b/tools/server/bench/speed-bench/README.md @@ -0,0 +1,117 @@ +# SPEED-Bench server benchmark + +A lightweight [SPEED-Bench](https://huggingface.co/datasets/nvidia/SPEED-Bench) client for benchmarking an already-running `llama-server` through its OpenAI-compatible API. It is primarily meant to evaluate speculative decoding (draft model, n-gram, MTP, EAGLE3, ...) by reporting per-category throughput, latency, and draft acceptance. + +The dataset handling follows the [aiperf SPEED-Bench tutorial](https://github.com/ai-dynamo/aiperf/blob/main/docs/tutorials/speed-bench.md), which also documents the dataset layout in more detail. + +## Install + +```bash +pip install -r tools/server/bench/speed-bench/requirements.txt +``` + +## Start a server + +The client does not launch the server, so start `llama-server` yourself first. 
If you care about throughput numbers, set the client `--concurrency` to the server's slot count (`--np`): + +```bash +llama-server \ + -m target.gguf \ + -c 8192 \ + --port 8080 \ + -ngl 99 -fa on \ + --np 1 \ + --jinja +``` + +For speculative decoding, start the server with the appropriate flags for your setup (e.g. a draft model with `-md`, or `--spec-type ngram-mod`). See the [speculative decoding doc](../../../../docs/speculative.md) for details. + +## Run + +```bash +python tools/server/bench/speed-bench/speed_bench.py \ + --url localhost:8080 \ + --bench qualitative \ + --category coding \ + --osl 1024 \ + --concurrency 1 +``` + +## Options + +| Option | Default | Description | +| --- | --- | --- | +| `--url` | `localhost:8080` | Server URL. The scheme and `/v1` are optional and a trailing slash is fine, so `localhost:8080` and `http://localhost:8080/v1/` both work. | +| `--model` | none | Optional `model` field sent in each request. | +| `--bench` | `qualitative` | SPEED-Bench config, e.g. `qualitative`, `throughput_1k`. See [available dataset variants](https://github.com/ai-dynamo/aiperf/blob/main/docs/tutorials/speed-bench.md#available-dataset-variants). | +| `--category` | `all` | Category filter within the bench; comma-separated list or `all`. For `qualitative` the categories are `coding`, `humanities`, `math`, `multilingual`, `qa`, `rag`, `reasoning`, `roleplay`, `stem`, `summarization`, `writing`. For the `throughput_{ISL}` splits they are `high_entropy`, `low_entropy`, `mixed`. | +| `--osl` | `4096` | Output sequence length, mapped to `max_tokens`. | +| `--extra-inputs` | `{"temperature":0}` | Extra request fields as a JSON object. | +| `--concurrency` | `1` | Concurrent client requests; usually match `--np`. | +| `--limit` | none | Max samples per category (handy for smoke tests). | +| `--timeout` | `600` | Per-request timeout in seconds. | +| `--output` | none | Save raw per-request results and the summary to JSON. 
| + +A few common ones: + +- `--category all` runs every category in the bench. +- `--category coding,math` runs just those two. +- `--bench throughput_8k` runs a fixed-input-length throughput split. +- `--limit 8` keeps at most 8 samples per category, which is enough for a quick check. + +The `throughput_{ISL}` splits use fixed input lengths (1k - 32k), so they are handy for long-context testing and for comparing different `llama-server` batching settings (e.g. sweeping `-ub` / `--ubatch-size`) on prompts of a known size. Make sure the server `-c` is large enough for the chosen split. When raising `-ub`, also raise `-b` to at least the same value, since the physical ubatch cannot exceed the logical batch. + +When `--output` is given, the JSON file holds the run `config`, the `selected_samples` / `completed_samples` / `failed_samples` counts, the per-category `summary` rows, and the per-sample `results`. + +## Metrics + +The summary prints one row per category plus an `overall` row: + +- `samples` - how many samples finished successfully. +- `avg_prompt_t/s` - prefill throughput from llama.cpp (`timings.prompt_per_second`), averaged over the category's samples. +- `avg_pred_t/s` - decode throughput from llama.cpp (`timings.predicted_per_second`), averaged over the category's samples. +- `avg_latency` - average end-to-end request latency seen by the client. +- `accept_rate` - `accepted / draft_n` over the category, or `n/a` if nothing was drafted (`draft_n == 0`). + +## Baseline vs speculative decoding + +Save a run from each server with `--output`, then diff the two JSON files with `speed_bench_compare.py`. 
+ +First, start a plain `llama-server` (no speculative decoding) and save a baseline: + +```bash +python tools/server/bench/speed-bench/speed_bench.py \ + --url localhost:8080 \ + --bench qualitative \ + --category all \ + --osl 1024 \ + --concurrency 1 \ + --output baseline.json +``` + +Then restart `llama-server` with speculative decoding enabled and save another run: + +```bash +python tools/server/bench/speed-bench/speed_bench.py \ + --url localhost:8080 \ + --bench qualitative \ + --category all \ + --osl 1024 \ + --concurrency 1 \ + --output spec.json +``` + +Finally compare the two: + +```bash +python tools/server/bench/speed-bench/speed_bench_compare.py \ + --baseline baseline.json \ + --speculative spec.json +``` + +The comparison table adds: + +- `decode_speedup = spec_avg_pred_t/s / base_avg_pred_t/s` +- `latency_speedup = base_avg_latency / spec_avg_latency` + +Keep `--bench`, `--category`, `--osl`, and `--limit` the same across both runs, otherwise they won't be using the same prompts. 
diff --git a/tools/server/bench/speed-bench/requirements.txt b/tools/server/bench/speed-bench/requirements.txt new file mode 100644 index 0000000000..a524c2f519 --- /dev/null +++ b/tools/server/bench/speed-bench/requirements.txt @@ -0,0 +1,3 @@ +datasets +requests +tqdm diff --git a/tools/server/bench/speed-bench/speed_bench.py b/tools/server/bench/speed-bench/speed_bench.py new file mode 100644 index 0000000000..adb378a6bf --- /dev/null +++ b/tools/server/bench/speed-bench/speed_bench.py @@ -0,0 +1,432 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import concurrent.futures +import json +import statistics +import sys +import time +from dataclasses import asdict, dataclass +from typing import Any +from urllib.parse import urlparse + +import requests +from datasets import get_dataset_config_names, load_dataset +from tqdm import tqdm + + +DATASET_REPO = "nvidia/SPEED-Bench" + +@dataclass +class Sample: + id: str + category: str + turns: list[str] + + +@dataclass +class RequestResult: + id: str + category: str + ok: bool + turns: int + latency_s: float + prompt_tokens: int + completion_tokens: int + total_tokens: int + finish_reason: str | None + draft_n: int + draft_n_accepted: int + prompt_ms: float | None + predicted_ms: float | None + prompt_per_second: float | None + predicted_per_second: float | None + error: str | None + + +def normalize_base_url(url: str) -> str: + url = url.strip().rstrip("/") + if not url: + raise ValueError("--url cannot be empty") + if "://" not in url: + url = "http://" + url + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + raise ValueError(f"invalid --url: {url}") + if not parsed.path.rstrip("/").endswith("/v1"): + url = url + "/v1" + return url.rstrip("/") + + +def parse_extra_inputs(value: str) -> dict[str, Any]: + extra = json.loads(value) + if not isinstance(extra, dict): + raise ValueError("--extra-inputs must be a JSON object") + return extra + + +def extract_turns(row: 
dict[str, Any]) -> list[str]: + turns = row.get("turns") + if isinstance(turns, list) and turns: + clean_turns = [str(turn).strip() for turn in turns if turn and str(turn).strip()] + if clean_turns: + return clean_turns + raise ValueError("missing or empty turns") + + +def load_samples(args: argparse.Namespace) -> list[Sample]: + bench_names = get_dataset_config_names(DATASET_REPO) + if args.bench not in bench_names: + raise ValueError( + f"unknown --bench {args.bench!r}; available benches: {', '.join(bench_names)}" + ) + + dataset = load_dataset(DATASET_REPO, name=args.bench, split="test") + categories = list(dict.fromkeys(str(category) for category in dataset["category"])) + requested_categories = None + if args.category != "all": + requested_list = [category.strip() for category in args.category.split(",") if category.strip()] + if not requested_list: + raise ValueError( + f"--category must be 'all' or a comma-separated list; available categories: {', '.join(categories)}" + ) + requested_categories = set(requested_list) + unknown_categories = [category for category in requested_list if category not in categories] + if unknown_categories: + unknown = ", ".join(unknown_categories) + raise ValueError( + f"unknown --category {unknown!r} for bench {args.bench!r}; " + f"available categories: all, {', '.join(categories)}" + ) + + samples: list[Sample] = [] + samples_per_category: dict[str, int] = {} + skipped = 0 + for index, row_raw in enumerate(dataset): + row = dict(row_raw) + category_raw = row.get("category") + if not isinstance(category_raw, str) or not category_raw.strip(): + skipped += 1 + continue + category = category_raw.strip() + if requested_categories is not None and category not in requested_categories: + continue + if args.limit is not None and samples_per_category.get(category, 0) >= args.limit: + continue + + try: + turns = extract_turns(row) + except ValueError: + skipped += 1 + continue + question_id = row.get("question_id") + if not 
isinstance(question_id, str) or not question_id.strip(): + skipped += 1 + continue + sample_id = question_id.strip() + samples.append(Sample(id=sample_id, category=category, turns=turns)) + samples_per_category[category] = samples_per_category.get(category, 0) + 1 + + if not samples: + raise RuntimeError(f"no samples selected from bench={args.bench} category={args.category}") + + if skipped: + print(f"speed_bench: skipped {skipped} rows without usable turns") + return samples + + +def parse_completion_response(data: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any], str | None, str]: + usage = data.get("usage") or {} + timings = data.get("timings") or {} + finish_reason = None + content = "" + choices = data.get("choices") + if isinstance(choices, list) and choices and isinstance(choices[0], dict): + choice = choices[0] + finish_reason = choice.get("finish_reason") + message = choice.get("message") + if isinstance(message, dict) and isinstance(message.get("content"), str): + content = message["content"] + elif isinstance(choice.get("text"), str): + content = choice["text"] + return usage, timings, finish_reason, content + + +def run_request( + endpoint: str, + model: str | None, + messages: list[dict[str, str]], + osl: int, + extra_inputs: dict[str, Any], + timeout: float, +) -> tuple[dict[str, Any], float]: + payload: dict[str, Any] = { + "messages": messages, + "max_tokens": osl, + "stream": False, + } + if model: + payload["model"] = model + payload.update(extra_inputs) + payload["max_tokens"] = osl + + start = time.perf_counter() + response = requests.post(endpoint, json=payload, timeout=timeout) + latency_s = time.perf_counter() - start + if response.status_code != 200: + body = response.text[:500].replace("\n", "\\n") + raise RuntimeError(f"HTTP {response.status_code}: {body}") + return response.json(), latency_s + + +def run_one( + sample: Sample, + endpoint: str, + model: str | None, + osl: int, + extra_inputs: dict[str, Any], + timeout: float, +) -> 
RequestResult: + selected_turns = sample.turns + messages: list[dict[str, str]] = [] + total_latency_s = 0.0 + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + draft_n = 0 + draft_n_accepted = 0 + prompt_ms = 0.0 + predicted_ms = 0.0 + prompt_per_second = None + predicted_per_second = None + finish_reason: str | None = None + try: + for turn in selected_turns: + messages.append({"role": "user", "content": turn}) + data, latency_s = run_request(endpoint, model, messages, osl, extra_inputs, timeout) + total_latency_s += latency_s + usage, timings, finish_reason, assistant_text = parse_completion_response(data) + + turn_prompt_tokens = int(usage.get("prompt_tokens") or timings.get("prompt_n") or 0) + turn_completion_tokens_count = int(usage.get("completion_tokens") or timings.get("predicted_n") or 0) + turn_total_tokens_count = int(usage.get("total_tokens") or (turn_prompt_tokens + turn_completion_tokens_count)) + prompt_tokens += turn_prompt_tokens + completion_tokens += turn_completion_tokens_count + total_tokens += turn_total_tokens_count + draft_n += int(timings.get("draft_n") or 0) + draft_n_accepted += int(timings.get("draft_n_accepted") or 0) + prompt_ms += float(timings.get("prompt_ms") or 0) + predicted_ms += float(timings.get("predicted_ms") or 0) + if len(selected_turns) == 1 and isinstance(timings.get("prompt_per_second"), (int, float)): + prompt_per_second = float(timings["prompt_per_second"]) + if len(selected_turns) == 1 and isinstance(timings.get("predicted_per_second"), (int, float)): + predicted_per_second = float(timings["predicted_per_second"]) + + messages.append({"role": "assistant", "content": assistant_text}) + + if total_tokens == 0: + total_tokens = prompt_tokens + completion_tokens + if len(selected_turns) > 1: + prompt_per_second = (prompt_tokens / (prompt_ms / 1000)) if prompt_ms > 0 else None + predicted_per_second = (completion_tokens / (predicted_ms / 1000)) if predicted_ms > 0 else None + + return RequestResult( + 
id=sample.id, + category=sample.category, + ok=True, + turns=len(selected_turns), + latency_s=total_latency_s, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + finish_reason=finish_reason, + draft_n=draft_n, + draft_n_accepted=draft_n_accepted, + prompt_ms=prompt_ms if prompt_ms > 0 else None, + predicted_ms=predicted_ms if predicted_ms > 0 else None, + prompt_per_second=prompt_per_second, + predicted_per_second=predicted_per_second, + error=None, + ) + except Exception as exc: + return RequestResult( + id=sample.id, + category=sample.category, + ok=False, + turns=len(selected_turns), + latency_s=total_latency_s, + prompt_tokens=0, + completion_tokens=0, + total_tokens=0, + finish_reason=None, + draft_n=0, + draft_n_accepted=0, + prompt_ms=None, + predicted_ms=None, + prompt_per_second=None, + predicted_per_second=None, + error=str(exc), + ) + + +def summarize_group(category: str, results: list[RequestResult]) -> dict[str, Any]: + ok_results = [result for result in results if result.ok] + latencies = [result.latency_s for result in ok_results] + server_prompt_speeds = [ + result.prompt_per_second + for result in ok_results + if result.prompt_per_second is not None + ] + server_completion_speeds = [ + result.predicted_per_second + for result in ok_results + if result.predicted_per_second is not None + ] + turns = sum(result.turns for result in ok_results) + draft_n = sum(result.draft_n for result in ok_results) + accepted = sum(result.draft_n_accepted for result in ok_results) + + return { + "category": category, + "requests": len(ok_results), + "turns": turns, + "failed": len(results) - len(ok_results), + "avg_prompt_t_s": statistics.mean(server_prompt_speeds) if server_prompt_speeds else None, + "avg_pred_t_s": statistics.mean(server_completion_speeds) if server_completion_speeds else None, + "avg_latency": statistics.mean(latencies) if latencies else None, + "draft_n": draft_n, + "accepted": accepted, + 
"accept_rate": (accepted / draft_n) if draft_n > 0 else None, + } + + +def fmt_value(value: Any, kind: str = "") -> str: + if value is None: + return "n/a" + if kind == "int": + return str(int(value)) + if kind == "rate": + return f"{float(value):.4f}" + if kind == "seconds": + return f"{float(value):.3f}s" + if kind == "speed": + return f"{float(value):.2f}" + if kind == "speedup": + return f"{float(value):.2f}x" + return str(value) + + +def print_table(rows: list[dict[str, Any]]) -> None: + columns = [ + ("category", "category", ""), + ("samples", "requests", "int"), + ("avg_prompt_t/s", "avg_prompt_t_s", "speed"), + ("avg_pred_t/s", "avg_pred_t_s", "speed"), + ("avg_latency", "avg_latency", "seconds"), + ("accept_rate", "accept_rate", "rate"), + ] + print_rows(rows, columns) + + +def print_rows(rows: list[dict[str, Any]], columns: list[tuple[str, str, str]]) -> None: + rendered_rows = [] + for row in rows: + rendered_rows.append([fmt_value(row.get(key), kind) for _, key, kind in columns]) + + widths = [len(header) for header, _, _ in columns] + for rendered in rendered_rows: + for i, cell in enumerate(rendered): + widths[i] = max(widths[i], len(cell)) + + header = " ".join(header.ljust(widths[i]) for i, (header, _, _) in enumerate(columns)) + print(header) + print(" ".join("-" * width for width in widths)) + for rendered in rendered_rows: + print(" ".join(cell.ljust(widths[i]) for i, cell in enumerate(rendered))) + + +def save_output(path: str, args: argparse.Namespace, samples: list[Sample], results: list[RequestResult], summary: list[dict[str, Any]]) -> None: + payload = { + "config": { + "url": args.url, + "model": args.model, + "bench": args.bench, + "category": args.category, + "osl": args.osl, + "concurrency": args.concurrency, + "extra_inputs": args.extra_inputs, + }, + "selected_samples": len(samples), + "completed_samples": sum(1 for result in results if result.ok), + "failed_samples": sum(1 for result in results if not result.ok), + "summary": summary, 
+ "results": [asdict(result) for result in results], + } + with open(path, "w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Run SPEED-Bench against an OpenAI-compatible llama-server.") + parser.add_argument("--url", default="localhost:8080", help="Server URL, for example localhost:8080 or http://localhost:8080/v1") + parser.add_argument("--model", default=None, help="Optional model name to send in OpenAI requests") + parser.add_argument("--bench", default="qualitative", help="SPEED-Bench config to run, for example qualitative or throughput_1k") + parser.add_argument("--category", default="all", help="Category to run within the selected bench; use all for no category filter") + parser.add_argument("--osl", type=int, default=4096, help="Output sequence length, mapped to max_tokens") + parser.add_argument("--extra-inputs", default='{"temperature":0}', help="Extra request fields as a JSON object") + parser.add_argument("--concurrency", type=int, default=1, help="Concurrent client requests; usually match llama-server --np") + parser.add_argument("--limit", type=int, default=None, help="Optional sample limit per category for smoke tests") + parser.add_argument("--timeout", type=float, default=600, help="Per-request timeout in seconds") + parser.add_argument("--output", default=None, help="Optional path to save raw results JSON") + args = parser.parse_args(argv) + try: + base_url = normalize_base_url(args.url) + endpoint = base_url + "/chat/completions" + extra_inputs = parse_extra_inputs(args.extra_inputs) + args.extra_inputs = extra_inputs + samples = load_samples(args) + except Exception as exc: + print(f"speed_bench: setup failed: {exc}", file=sys.stderr) + return 2 + + print(f"speed_bench: loaded {len(samples)} samples from bench={args.bench} category={args.category}") + + results: list[RequestResult] = [] + started = time.perf_counter() 
+ with concurrent.futures.ThreadPoolExecutor(max_workers=args.concurrency) as executor: + futures = [ + executor.submit(run_one, sample, endpoint, args.model, args.osl, extra_inputs, args.timeout) + for sample in samples + ] + for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="speed_bench", unit="sample"): + result = future.result() + results.append(result) + + elapsed = time.perf_counter() - started + categories = list(dict.fromkeys(sample.category for sample in samples)) + summary = [ + summarize_group(category, [result for result in results if result.category == category]) + for category in categories + ] + summary.append(summarize_group("overall", results)) + print() + print(f"Summary (elapsed={elapsed:.2f}s)") + print_table(summary) + + if args.output: + save_output(args.output, args, samples, results, summary) + print(f"\nspeed_bench: wrote {args.output}") + + failed = sum(1 for result in results if not result.ok) + if failed: + print(f"\nspeed_bench: {failed} samples failed", file=sys.stderr) + first_error = next((result.error for result in results if result.error), None) + if first_error: + print(f"first error: {first_error}", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/server/bench/speed-bench/speed_bench_compare.py b/tools/server/bench/speed-bench/speed_bench_compare.py new file mode 100644 index 0000000000..070ab57db5 --- /dev/null +++ b/tools/server/bench/speed-bench/speed_bench_compare.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import sys +from typing import Any + +from speed_bench import fmt_value, print_rows + + +def load_summary(path: str) -> list[dict[str, Any]]: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + summary = data.get("summary") + if not isinstance(summary, list): + raise ValueError(f"{path} does not contain a summary list") + return 
summary + + +def compare_rows(baseline: list[dict[str, Any]], speculative: list[dict[str, Any]]) -> list[dict[str, Any]]: + baseline_by_category = {row["category"]: row for row in baseline} + comparisons = [] + for row in speculative: + base = baseline_by_category.get(row["category"]) + if not base: + continue + base_speed = base.get("avg_pred_t_s") + spec_speed = row.get("avg_pred_t_s") + base_latency = base.get("avg_latency") + spec_latency = row.get("avg_latency") + comparisons.append( + { + "category": row["category"], + "base_avg_pred_t_s": base_speed, + "spec_avg_pred_t_s": spec_speed, + "decode_speedup": (spec_speed / base_speed) if base_speed and spec_speed else None, + "base_avg_latency": base_latency, + "spec_avg_latency": spec_latency, + "latency_speedup": (base_latency / spec_latency) if base_latency and spec_latency else None, + "accept_rate": row.get("accept_rate"), + } + ) + return comparisons + + +def print_comparison(rows: list[dict[str, Any]]) -> None: + if not rows: + print("No overlapping categories found for comparison.") + return + columns = [ + ("category", "category", ""), + ("base_avg_pred_t/s", "base_avg_pred_t_s", "speed"), + ("spec_avg_pred_t/s", "spec_avg_pred_t_s", "speed"), + ("decode_speedup", "decode_speedup", "speedup"), + ("base_avg_latency", "base_avg_latency", "seconds"), + ("spec_avg_latency", "spec_avg_latency", "seconds"), + ("latency_speedup", "latency_speedup", "speedup"), + ("accept_rate", "accept_rate", "rate"), + ] + print_rows(rows, columns) + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Compare two SPEED-Bench runs (baseline vs speculative).") + parser.add_argument("--baseline", required=True, help="Baseline results JSON produced by speed_bench.py --output") + parser.add_argument("--speculative", required=True, help="Speculative decoding results JSON produced by speed_bench.py --output") + args = parser.parse_args(argv) + + try: + baseline = 
load_summary(args.baseline) + speculative = load_summary(args.speculative) + except Exception as exc: + print(f"speed_bench_compare: failed to load inputs: {exc}", file=sys.stderr) + return 2 + + comparisons = compare_rows(baseline, speculative) + print(f"Comparison: baseline={args.baseline} speculative={args.speculative}") + print_comparison(comparisons) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())