aboutsummaryrefslogtreecommitdiffstats
path: root/test/scripts/compare.py
diff options
context:
space:
mode:
authorClaude <claude@anthropic.com>2026-03-04 19:14:55 +0100
committerClaude <claude@anthropic.com>2026-03-04 19:14:55 +0100
commit171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch)
tree2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /test/scripts/compare.py
parent9f14edf2b97286e02830d528038b32d5b31aaa0a (diff)
parent0278c87f062a9ae7d617b92be22b175558a05086 (diff)
downloadgemini-py-main.tar.gz
gemini-py-main.zip
Add initial versionHEADmain
Diffstat (limited to 'test/scripts/compare.py')
-rw-r--r--test/scripts/compare.py159
1 files changed, 159 insertions, 0 deletions
diff --git a/test/scripts/compare.py b/test/scripts/compare.py
new file mode 100644
index 0000000..d6dc858
--- /dev/null
+++ b/test/scripts/compare.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+"""Compare captured mitmproxy requests between gemini-cli and gemini-py."""
+
+import json
+import sys
+from pathlib import Path
+
+CAPTURES_DIR = sys.argv[1] if len(sys.argv) > 1 else "./test/captures"
+
+IGNORE_HEADERS = {
+ "authorization",
+ "content-length",
+ "host",
+ "connection",
+ "proxy-connection",
+ "proxy-authorization",
+ "transfer-encoding",
+ "accept-encoding",
+ "x-capture-source",
+ "user-agent",
+}
+
+# Session-specific fields inside request body
+VOLATILE_REQUEST_KEYS = {"session_id", "contents"}
+
+
+def load_captures(directory):
+ caps = []
+ for f in sorted(Path(directory).glob("req_*.json")):
+ with open(f) as fh:
+ d = json.load(fh)
+ # Only compare the main content generation calls
+ url = d.get("url", "")
+ u = url.lower()
+ if "generatecontent" in u and "loadcodeassist" not in u and "counttokens" not in u:
+ caps.append(d)
+ return caps
+
+
+def normalize_headers(headers):
+ return {k.lower(): v for k, v in headers.items() if k.lower() not in IGNORE_HEADERS}
+
+
+def normalize_body(body):
+ if not isinstance(body, dict):
+ return body
+ out = {}
+ for k, v in body.items():
+ if k == "user_prompt_id":
+ out[k] = "<redacted>"
+ elif k == "request" and isinstance(v, dict):
+ out[k] = {
+ rk: ("<redacted>" if rk in VOLATILE_REQUEST_KEYS else rv) for rk, rv in v.items()
+ }
+ else:
+ out[k] = v
+ return out
+
+
+def fmt(v):
+ if isinstance(v, (dict, list)):
+ return json.dumps(v, indent=2, default=str)
+ return str(v)
+
+
+def compare_dicts(label, d1, d2, n1, n2):
+ keys = sorted(set(list(d1.keys()) + list(d2.keys())))
+ diffs = 0
+ for key in keys:
+ v1 = d1.get(key, "<missing>")
+ v2 = d2.get(key, "<missing>")
+ if v1 == v2:
+ print(f" \033[32m✓\033[0m {key}")
+ else:
+ diffs += 1
+ print(f" \033[31m✗\033[0m {key}")
+ print(f" {n1}: {fmt(v1)}")
+ print(f" {n2}: {fmt(v2)}")
+ if diffs == 0:
+ print(f" \033[32m{label}: IDENTICAL\033[0m")
+ else:
+ print(f" \033[31m{label}: {diffs} difference(s)\033[0m")
+ return diffs
+
+
+def dump_request(cap):
+ src = cap.get("source", "?")
+ print(f" Source: {src}")
+ print(f" URL: {cap['url']}")
+ headers = normalize_headers(cap.get("headers", {}))
+ body = cap.get("body", {})
+ print(f" Headers ({len(headers)}):")
+ for k in sorted(headers):
+ print(f" {k}: {headers[k]}")
+ if isinstance(body, dict):
+ print(f" Body keys: {sorted(body.keys())}")
+ print(f" model: {body.get('model', 'N/A')}")
+ print(f" project: {body.get('project', 'N/A')}")
+ req = body.get("request", {})
+ if isinstance(req, dict):
+ print(f" generationConfig: {req.get('generationConfig', 'N/A')}")
+
+
+def main():
+ caps = load_captures(CAPTURES_DIR)
+ if not caps:
+ print(f"No generateContent captures found in {CAPTURES_DIR}")
+ print("Run gemini-cli and gemini-py through the proxy first.")
+ sys.exit(1)
+
+ by_source = {}
+ for c in caps:
+ src = c.get("source", "unknown")
+ by_source.setdefault(src, []).append(c)
+
+ sources = {k: len(v) for k, v in by_source.items()}
+ print(f"Found {len(caps)} generateContent request(s)")
+ print(f"Sources: {sources}\n")
+
+ for i, cap in enumerate(caps):
+ print(f"--- Request #{cap.get('capture_id', i + 1)} ---")
+ dump_request(cap)
+ print()
+
+ sources = list(by_source.keys())
+ if len(sources) == 2 and all(len(v) >= 1 for v in by_source.values()):
+ n1, n2 = sources
+
+ def pick(caps):
+ # Prefer streamGenerateContent (actual response) over routing calls
+ streaming = [c for c in caps if "streamgeneratecontent" in c.get("url", "").lower()]
+ return streaming[-1] if streaming else caps[-1]
+
+ c1, c2 = pick(by_source[n1]), pick(by_source[n2])
+
+ print("=" * 60)
+ print(f"COMPARING: {n1} vs {n2}")
+ print("=" * 60)
+
+ h1 = normalize_headers(c1.get("headers", {}))
+ h2 = normalize_headers(c2.get("headers", {}))
+ total = compare_dicts("Headers", h1, h2, n1, n2)
+
+ b1 = normalize_body(c1.get("body", {}))
+ b2 = normalize_body(c2.get("body", {}))
+ total += compare_dicts("Body", b1, b2, n1, n2)
+
+ print(f"\n{'=' * 60}")
+ if total == 0:
+ print("\033[32mRESULT: Requests are identical on the wire.\033[0m")
+ else:
+ print(f"\033[31mRESULT: {total} total difference(s).\033[0m")
+ sys.exit(0 if total == 0 else 1)
+ else:
+ print("Tip: Clear test/captures/ and run one request from each client to compare.")
+
+
+if __name__ == "__main__":
+ main()