diff options
| author | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
|---|---|---|
| committer | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
| commit | 171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch) | |
| tree | 2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /test | |
| parent | 9f14edf2b97286e02830d528038b32d5b31aaa0a (diff) | |
| parent | 0278c87f062a9ae7d617b92be22b175558a05086 (diff) | |
| download | gemini-py-main.tar.gz gemini-py-main.zip | |
Diffstat (limited to 'test')
| -rw-r--r-- | test/Dockerfile.gemini-cli | 11 | ||||
| -rw-r--r-- | test/Dockerfile.gemini-py | 15 | ||||
| -rw-r--r-- | test/entrypoint.sh | 48 | ||||
| -rw-r--r-- | test/gemini-config/settings.json | 7 | ||||
| -rw-r--r-- | test/mitmproxy/capture.py | 77 | ||||
| -rw-r--r-- | test/scripts/compare.py | 159 | ||||
| -rw-r--r-- | test/scripts/send_test.py | 36 | ||||
| -rw-r--r-- | test/scripts/tool_test.py | 100 |
8 files changed, 453 insertions, 0 deletions
diff --git a/test/Dockerfile.gemini-cli b/test/Dockerfile.gemini-cli new file mode 100644 index 0000000..44e0c38 --- /dev/null +++ b/test/Dockerfile.gemini-cli @@ -0,0 +1,11 @@ +FROM node:20-bookworm + +RUN npm install -g @google/gemini-cli + +COPY test/entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +WORKDIR /workspace + +ENTRYPOINT ["/entrypoint.sh"] +CMD ["sh", "-c", "gemini --yolo -p \"${TEST_PROMPT:-Say hi in exactly 3 words}\""] diff --git a/test/Dockerfile.gemini-py b/test/Dockerfile.gemini-py new file mode 100644 index 0000000..fd256c7 --- /dev/null +++ b/test/Dockerfile.gemini-py @@ -0,0 +1,15 @@ +FROM python:3.12-bookworm + +WORKDIR /app + +COPY pyproject.toml README.md ./ +COPY src/ ./src/ + +RUN pip install --no-cache-dir . + +COPY test/scripts/ /app/scripts/ +COPY test/entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +ENTRYPOINT ["/entrypoint.sh"] +CMD ["python", "/app/scripts/send_test.py"] diff --git a/test/entrypoint.sh b/test/entrypoint.sh new file mode 100644 index 0000000..6fa7518 --- /dev/null +++ b/test/entrypoint.sh @@ -0,0 +1,48 @@ +#!/bin/bash +set -e + +CERT_PATH="/mitmproxy-ca/mitmproxy-ca-cert.pem" + +echo "Waiting for mitmproxy CA certificate..." +for i in $(seq 1 30); do + if [ -f "$CERT_PATH" ]; then + echo "Found mitmproxy CA cert." + break + fi + sleep 1 +done + +if [ ! -f "$CERT_PATH" ]; then + echo "WARNING: mitmproxy CA cert not found at $CERT_PATH" +fi + +export SSL_CERT_FILE="$CERT_PATH" +export REQUESTS_CA_BUNDLE="$CERT_PATH" +export NODE_EXTRA_CA_CERTS="$CERT_PATH" + +# Build ~/.gemini for gemini-cli from env vars +if [ -n "$GEMINI_REFRESH_TOKEN" ]; then + GEMINI_DIR="/tmp/gemini-home/.gemini" + mkdir -p "$GEMINI_DIR" + + python3 - <<EOF +import json, os, time +data = { + "refresh_token": os.environ["GEMINI_REFRESH_TOKEN"], + "access_token": os.environ.get("GEMINI_ACCESS_TOKEN", ""), + "expiry_date": int(os.environ.get("GEMINI_TOKEN_EXPIRY", "0")), + "token_type": "Bearer", +} +with open("$GEMINI_DIR/oauth_creds.json", "w") as f: + json.dump(data, f, indent=2) +os.chmod("$GEMINI_DIR/oauth_creds.json", 0o600) +EOF + + if [ -f "/gemini-settings/settings.json" ]; then + cp /gemini-settings/settings.json "$GEMINI_DIR/settings.json" + fi + + export HOME="/tmp/gemini-home" +fi + +exec "$@" diff --git a/test/gemini-config/settings.json b/test/gemini-config/settings.json new file mode 100644 index 0000000..3032793 --- /dev/null +++ b/test/gemini-config/settings.json @@ -0,0 +1,7 @@ +{ + "security": { + "auth": { + "selectedType": "oauth-personal" + } + } +} diff --git a/test/mitmproxy/capture.py b/test/mitmproxy/capture.py new file mode 100644 index 0000000..5ddd756 --- /dev/null +++ b/test/mitmproxy/capture.py @@ -0,0 +1,77 @@ +import json +import os +import time + +from mitmproxy import http + + +class CaptureAddon: + def __init__(self): + self.output_dir = os.environ.get("CAPTURE_DIR", "/captures") + os.makedirs(self.output_dir, exist_ok=True) + self.counter = 0 + + def request(self, flow: http.HTTPFlow): + if "googleapis.com" not in (flow.request.host or ""): + return + + self.counter += 1 + flow.metadata["capture_id"] = self.counter + + body = None + try: + body = json.loads(flow.request.get_text()) + except Exception: + body = flow.request.get_text() + + source = flow.request.headers.get("x-capture-source") or ( + "gemini-cli" if "GeminiCLI" in flow.request.headers.get("user-agent", "") else "unknown" + ) + + data = { + "capture_id": self.counter, + "source": source, + "timestamp": time.time(), + "method": flow.request.method, + "url": flow.request.pretty_url, + "headers": dict(flow.request.headers), + "body": body, + } + + path = os.path.join(self.output_dir, f"req_{self.counter:04d}.json") + with open(path, "w") as f: + json.dump(data, f, indent=2, default=str) + + print( + f"[capture] #{self.counter} [{source}] {flow.request.method} {flow.request.pretty_url}" + ) + + def response(self, flow: http.HTTPFlow): + if "googleapis.com" not in (flow.request.host or ""): + return + + capture_id = flow.metadata.get("capture_id") + if not capture_id: + return + + body = None + try: + body = json.loads(flow.response.get_text()) + except Exception: + body = flow.response.get_text()[:4000] + + path = os.path.join(self.output_dir, f"res_{capture_id:04d}.json") + with open(path, "w") as f: + json.dump( + { + "status": flow.response.status_code, + "headers": dict(flow.response.headers), + "body": body, + }, + f, + indent=2, + default=str, + ) + + +addons = [CaptureAddon()] diff --git a/test/scripts/compare.py b/test/scripts/compare.py new file mode 100644 index 0000000..d6dc858 --- /dev/null +++ b/test/scripts/compare.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +"""Compare captured mitmproxy requests between gemini-cli and gemini-py.""" + +import json +import sys +from pathlib import Path + +CAPTURES_DIR = sys.argv[1] if len(sys.argv) > 1 else "./test/captures" + +IGNORE_HEADERS = { + "authorization", + "content-length", + "host", + "connection", + "proxy-connection", + "proxy-authorization", + "transfer-encoding", + "accept-encoding", + "x-capture-source", + "user-agent", +} + +# Session-specific fields inside request body +VOLATILE_REQUEST_KEYS = {"session_id", "contents"} + + +def load_captures(directory): + caps = [] + for f in sorted(Path(directory).glob("req_*.json")): + with open(f) as fh: + d = json.load(fh) + # Only compare the main content generation calls + url = d.get("url", "") + u = url.lower() + if "generatecontent" in u and "loadcodeassist" not in u and "counttokens" not in u: + caps.append(d) + return caps + + +def normalize_headers(headers): + return {k.lower(): v for k, v in headers.items() if k.lower() not in IGNORE_HEADERS} + + +def normalize_body(body): + if not isinstance(body, dict): + return body + out = {} + for k, v in body.items(): + if k == "user_prompt_id": + out[k] = "<redacted>" + elif k == "request" and isinstance(v, dict): + out[k] = { + rk: ("<redacted>" if rk in VOLATILE_REQUEST_KEYS else rv) for rk, rv in v.items() + } + else: + out[k] = v + return out + + +def fmt(v): + if isinstance(v, (dict, list)): + return json.dumps(v, indent=2, default=str) + return str(v) + + +def compare_dicts(label, d1, d2, n1, n2): + keys = sorted(set(list(d1.keys()) + list(d2.keys()))) + diffs = 0 + for key in keys: + v1 = d1.get(key, "<missing>") + v2 = d2.get(key, "<missing>") + if v1 == v2: + print(f" \033[32m✓\033[0m {key}") + else: + diffs += 1 + print(f" \033[31m✗\033[0m {key}") + print(f" {n1}: {fmt(v1)}") + print(f" {n2}: {fmt(v2)}") + if diffs == 0: + print(f" \033[32m{label}: IDENTICAL\033[0m") + else: + print(f" \033[31m{label}: {diffs} difference(s)\033[0m") + return diffs + + +def dump_request(cap): + src = cap.get("source", "?") + print(f" Source: {src}") + print(f" URL: {cap['url']}") + headers = normalize_headers(cap.get("headers", {})) + body = cap.get("body", {}) + print(f" Headers ({len(headers)}):") + for k in sorted(headers): + print(f" {k}: {headers[k]}") + if isinstance(body, dict): + print(f" Body keys: {sorted(body.keys())}") + print(f" model: {body.get('model', 'N/A')}") + print(f" project: {body.get('project', 'N/A')}") + req = body.get("request", {}) + if isinstance(req, dict): + print(f" generationConfig: {req.get('generationConfig', 'N/A')}") + + +def main(): + caps = load_captures(CAPTURES_DIR) + if not caps: + print(f"No generateContent captures found in {CAPTURES_DIR}") + print("Run gemini-cli and gemini-py through the proxy first.") + sys.exit(1) + + by_source = {} + for c in caps: + src = c.get("source", "unknown") + by_source.setdefault(src, []).append(c) + + sources = {k: len(v) for k, v in by_source.items()} + print(f"Found {len(caps)} generateContent request(s)") + print(f"Sources: {sources}\n") + + for i, cap in enumerate(caps): + print(f"--- Request #{cap.get('capture_id', i + 1)} ---") + dump_request(cap) + print() + + sources = list(by_source.keys()) + if len(sources) == 2 and all(len(v) >= 1 for v in by_source.values()): + n1, n2 = sources + + def pick(caps): + # Prefer streamGenerateContent (actual response) over routing calls + streaming = [c for c in caps if "streamgeneratecontent" in c.get("url", "").lower()] + return streaming[-1] if streaming else caps[-1] + + c1, c2 = pick(by_source[n1]), pick(by_source[n2]) + + print("=" * 60) + print(f"COMPARING: {n1} vs {n2}") + print("=" * 60) + + h1 = normalize_headers(c1.get("headers", {})) + h2 = normalize_headers(c2.get("headers", {})) + total = compare_dicts("Headers", h1, h2, n1, n2) + + b1 = normalize_body(c1.get("body", {})) + b2 = normalize_body(c2.get("body", {})) + total += compare_dicts("Body", b1, b2, n1, n2) + + print(f"\n{'=' * 60}") + if total == 0: + print("\033[32mRESULT: Requests are identical on the wire.\033[0m") + else: + print(f"\033[31mRESULT: {total} total difference(s).\033[0m") + sys.exit(0 if total == 0 else 1) + else: + print("Tip: Clear test/captures/ and run one request from each client to compare.") + + +if __name__ == "__main__": + main() diff --git a/test/scripts/send_test.py b/test/scripts/send_test.py new file mode 100644 index 0000000..2b5fe53 --- /dev/null +++ b/test/scripts/send_test.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +import asyncio +import os + +from gemini import GeminiClient, GeminiOptions + +PROMPT = os.environ.get("TEST_PROMPT", "Say hi in exactly 3 words") + +CONFIGS = [ + {"label": "flash-lite", "model": "gemini-2.5-flash-lite"}, + {"label": "flash", "model": "gemini-2.5-flash"}, + {"label": "pro", "model": "gemini-2.5-pro"}, +] + + +async def run_one(cfg): + opts = GeminiOptions(model=cfg["model"]) + async with GeminiClient(options=opts) as client: + text = "" + async for chunk in client.send_message_stream(PROMPT): + text += chunk.text_delta + print(f" [{cfg['label']}] {text.strip()[:80]}") + + +async def main(): + mode = os.environ.get("TEST_MODE", "single") + configs = CONFIGS if mode == "all" else [CONFIGS[0]] + + for cfg in configs: + print(f">>> {cfg['label']}: model={cfg['model']}") + await run_one(cfg) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test/scripts/tool_test.py b/test/scripts/tool_test.py new file mode 100644 index 0000000..2e5d6c9 --- /dev/null +++ b/test/scripts/tool_test.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +import asyncio +import json + +from gemini import FunctionDeclaration, GeminiClient, GeminiOptions + +TOOLS = [ + FunctionDeclaration( + name="get_weather", + description="Get the current weather for a city", + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name, e.g. Paris, Tokyo", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit (default: celsius)", + }, + }, + "required": ["location"], + }, + ), +] + +FAKE_WEATHER = { + "Paris": {"temperature": 18, "condition": "cloudy"}, + "London": {"temperature": 14, "condition": "rainy"}, + "Tokyo": {"temperature": 26, "condition": "sunny"}, +} + + +def handle_tool(name: str, args: dict) -> dict: + if name == "get_weather": + city = args.get("location", "") + unit = args.get("unit", "celsius") + data = FAKE_WEATHER.get(city, {"temperature": 20, "condition": "unknown"}) + temp = data["temperature"] + if unit == "fahrenheit": + temp = round(temp * 9 / 5 + 32) + return {"location": city, "temperature": temp, "unit": unit, "condition": data["condition"]} + raise ValueError(f"Unknown tool: {name}") + + +async def run_non_streaming(): + print("=== Non-streaming tool use ===") + opts = GeminiOptions(model="gemini-2.5-flash", tools=TOOLS) + async with GeminiClient(options=opts) as client: + prompt = "What's the weather in Paris and Tokyo? Use celsius." + print(f"User: {prompt}") + response = await client.send_message(prompt) + + while response.tool_calls: + results = [] + for tc in response.tool_calls: + print(f" [tool call] {tc.name}({json.dumps(tc.args)})") + result = handle_tool(tc.name, tc.args) + print(f" [tool result] {json.dumps(result)}") + results.append((tc.name, result)) + response = await client.send_tool_results(results) + + print(f"Model: {response.text.strip()}") + + +async def run_streaming(): + print("\n=== Streaming tool use ===") + opts = GeminiOptions(model="gemini-2.5-flash", tools=TOOLS) + async with GeminiClient(options=opts) as client: + prompt = "What's the weather in London right now?" + print(f"User: {prompt}") + + final = None + async for chunk in client.send_message_stream(prompt): + final = chunk + + if final and final.tool_calls: + results = [] + for tc in final.tool_calls: + print(f" [tool call] {tc.name}({json.dumps(tc.args)})") + result = handle_tool(tc.name, tc.args) + print(f" [tool result] {json.dumps(result)}") + results.append((tc.name, result)) + print("Model: ", end="", flush=True) + async for chunk in client.send_tool_results_stream(results): + print(chunk.text_delta, end="", flush=True) + print() + elif final: + print(f"Model: {final.response.text.strip() if final.response else ''}") + + +async def main(): + await run_non_streaming() + await run_streaming() + + +if __name__ == "__main__": + asyncio.run(main()) |
