about summary refs log tree commit diff stats
path: root/test
diff options
context:
space:
mode:
author    Claude <claude@anthropic.com> 2026-03-04 19:14:55 +0100
committer Claude <claude@anthropic.com> 2026-03-04 19:14:55 +0100
commit171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch)
tree2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /test
parent9f14edf2b97286e02830d528038b32d5b31aaa0a (diff)
parent0278c87f062a9ae7d617b92be22b175558a05086 (diff)
download gemini-py-main.tar.gz
gemini-py-main.zip
Add initial version (HEAD, main)
Diffstat (limited to 'test')
-rw-r--r--test/Dockerfile.gemini-cli11
-rw-r--r--test/Dockerfile.gemini-py15
-rw-r--r--test/entrypoint.sh48
-rw-r--r--test/gemini-config/settings.json7
-rw-r--r--test/mitmproxy/capture.py77
-rw-r--r--test/scripts/compare.py159
-rw-r--r--test/scripts/send_test.py36
-rw-r--r--test/scripts/tool_test.py100
8 files changed, 453 insertions, 0 deletions
diff --git a/test/Dockerfile.gemini-cli b/test/Dockerfile.gemini-cli
new file mode 100644
index 0000000..44e0c38
--- /dev/null
+++ b/test/Dockerfile.gemini-cli
@@ -0,0 +1,11 @@
+FROM node:20-bookworm
+
+RUN npm install -g @google/gemini-cli
+
+COPY test/entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
+
+WORKDIR /workspace
+
+ENTRYPOINT ["/entrypoint.sh"]
+CMD ["sh", "-c", "gemini --yolo -p \"${TEST_PROMPT:-Say hi in exactly 3 words}\""]
diff --git a/test/Dockerfile.gemini-py b/test/Dockerfile.gemini-py
new file mode 100644
index 0000000..fd256c7
--- /dev/null
+++ b/test/Dockerfile.gemini-py
@@ -0,0 +1,15 @@
+FROM python:3.12-bookworm
+
+WORKDIR /app
+
+COPY pyproject.toml README.md ./
+COPY src/ ./src/
+
+RUN pip install --no-cache-dir .
+
+COPY test/scripts/ /app/scripts/
+COPY test/entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
+
+ENTRYPOINT ["/entrypoint.sh"]
+CMD ["python", "/app/scripts/send_test.py"]
diff --git a/test/entrypoint.sh b/test/entrypoint.sh
new file mode 100644
index 0000000..6fa7518
--- /dev/null
+++ b/test/entrypoint.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+set -e
+
+CERT_PATH="/mitmproxy-ca/mitmproxy-ca-cert.pem"
+
+echo "Waiting for mitmproxy CA certificate..."
+for i in $(seq 1 30); do
+ if [ -f "$CERT_PATH" ]; then
+ echo "Found mitmproxy CA cert."
+ break
+ fi
+ sleep 1
+done
+
+if [ ! -f "$CERT_PATH" ]; then
+ echo "WARNING: mitmproxy CA cert not found at $CERT_PATH"
+fi
+
+export SSL_CERT_FILE="$CERT_PATH"
+export REQUESTS_CA_BUNDLE="$CERT_PATH"
+export NODE_EXTRA_CA_CERTS="$CERT_PATH"
+
+# Build ~/.gemini for gemini-cli from env vars
+if [ -n "$GEMINI_REFRESH_TOKEN" ]; then
+ GEMINI_DIR="/tmp/gemini-home/.gemini"
+ mkdir -p "$GEMINI_DIR"
+
+ python3 - <<EOF
+import json, os, time
+data = {
+ "refresh_token": os.environ["GEMINI_REFRESH_TOKEN"],
+ "access_token": os.environ.get("GEMINI_ACCESS_TOKEN", ""),
+ "expiry_date": int(os.environ.get("GEMINI_TOKEN_EXPIRY", "0")),
+ "token_type": "Bearer",
+}
+with open("$GEMINI_DIR/oauth_creds.json", "w") as f:
+ json.dump(data, f, indent=2)
+os.chmod("$GEMINI_DIR/oauth_creds.json", 0o600)
+EOF
+
+ if [ -f "/gemini-settings/settings.json" ]; then
+ cp /gemini-settings/settings.json "$GEMINI_DIR/settings.json"
+ fi
+
+ export HOME="/tmp/gemini-home"
+fi
+
+exec "$@"
diff --git a/test/gemini-config/settings.json b/test/gemini-config/settings.json
new file mode 100644
index 0000000..3032793
--- /dev/null
+++ b/test/gemini-config/settings.json
@@ -0,0 +1,7 @@
+{
+ "security": {
+ "auth": {
+ "selectedType": "oauth-personal"
+ }
+ }
+}
diff --git a/test/mitmproxy/capture.py b/test/mitmproxy/capture.py
new file mode 100644
index 0000000..5ddd756
--- /dev/null
+++ b/test/mitmproxy/capture.py
@@ -0,0 +1,77 @@
+import json
+import os
+import time
+
+from mitmproxy import http
+
+
+class CaptureAddon:
+ def __init__(self):
+ self.output_dir = os.environ.get("CAPTURE_DIR", "/captures")
+ os.makedirs(self.output_dir, exist_ok=True)
+ self.counter = 0
+
+ def request(self, flow: http.HTTPFlow):
+ if "googleapis.com" not in (flow.request.host or ""):
+ return
+
+ self.counter += 1
+ flow.metadata["capture_id"] = self.counter
+
+ body = None
+ try:
+ body = json.loads(flow.request.get_text())
+ except Exception:
+ body = flow.request.get_text()
+
+ source = flow.request.headers.get("x-capture-source") or (
+ "gemini-cli" if "GeminiCLI" in flow.request.headers.get("user-agent", "") else "unknown"
+ )
+
+ data = {
+ "capture_id": self.counter,
+ "source": source,
+ "timestamp": time.time(),
+ "method": flow.request.method,
+ "url": flow.request.pretty_url,
+ "headers": dict(flow.request.headers),
+ "body": body,
+ }
+
+ path = os.path.join(self.output_dir, f"req_{self.counter:04d}.json")
+ with open(path, "w") as f:
+ json.dump(data, f, indent=2, default=str)
+
+ print(
+ f"[capture] #{self.counter} [{source}] {flow.request.method} {flow.request.pretty_url}"
+ )
+
+ def response(self, flow: http.HTTPFlow):
+ if "googleapis.com" not in (flow.request.host or ""):
+ return
+
+ capture_id = flow.metadata.get("capture_id")
+ if not capture_id:
+ return
+
+ body = None
+ try:
+ body = json.loads(flow.response.get_text())
+ except Exception:
+ body = flow.response.get_text()[:4000]
+
+ path = os.path.join(self.output_dir, f"res_{capture_id:04d}.json")
+ with open(path, "w") as f:
+ json.dump(
+ {
+ "status": flow.response.status_code,
+ "headers": dict(flow.response.headers),
+ "body": body,
+ },
+ f,
+ indent=2,
+ default=str,
+ )
+
+
+addons = [CaptureAddon()]
diff --git a/test/scripts/compare.py b/test/scripts/compare.py
new file mode 100644
index 0000000..d6dc858
--- /dev/null
+++ b/test/scripts/compare.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+"""Compare captured mitmproxy requests between gemini-cli and gemini-py."""
+
+import json
+import sys
+from pathlib import Path
+
+CAPTURES_DIR = sys.argv[1] if len(sys.argv) > 1 else "./test/captures"
+
+IGNORE_HEADERS = {
+ "authorization",
+ "content-length",
+ "host",
+ "connection",
+ "proxy-connection",
+ "proxy-authorization",
+ "transfer-encoding",
+ "accept-encoding",
+ "x-capture-source",
+ "user-agent",
+}
+
+# Session-specific fields inside request body
+VOLATILE_REQUEST_KEYS = {"session_id", "contents"}
+
+
+def load_captures(directory):
+ caps = []
+ for f in sorted(Path(directory).glob("req_*.json")):
+ with open(f) as fh:
+ d = json.load(fh)
+ # Only compare the main content generation calls
+ url = d.get("url", "")
+ u = url.lower()
+ if "generatecontent" in u and "loadcodeassist" not in u and "counttokens" not in u:
+ caps.append(d)
+ return caps
+
+
+def normalize_headers(headers):
+ return {k.lower(): v for k, v in headers.items() if k.lower() not in IGNORE_HEADERS}
+
+
+def normalize_body(body):
+ if not isinstance(body, dict):
+ return body
+ out = {}
+ for k, v in body.items():
+ if k == "user_prompt_id":
+ out[k] = "<redacted>"
+ elif k == "request" and isinstance(v, dict):
+ out[k] = {
+ rk: ("<redacted>" if rk in VOLATILE_REQUEST_KEYS else rv) for rk, rv in v.items()
+ }
+ else:
+ out[k] = v
+ return out
+
+
+def fmt(v):
+ if isinstance(v, (dict, list)):
+ return json.dumps(v, indent=2, default=str)
+ return str(v)
+
+
+def compare_dicts(label, d1, d2, n1, n2):
+ keys = sorted(set(list(d1.keys()) + list(d2.keys())))
+ diffs = 0
+ for key in keys:
+ v1 = d1.get(key, "<missing>")
+ v2 = d2.get(key, "<missing>")
+ if v1 == v2:
+ print(f" \033[32m✓\033[0m {key}")
+ else:
+ diffs += 1
+ print(f" \033[31m✗\033[0m {key}")
+ print(f" {n1}: {fmt(v1)}")
+ print(f" {n2}: {fmt(v2)}")
+ if diffs == 0:
+ print(f" \033[32m{label}: IDENTICAL\033[0m")
+ else:
+ print(f" \033[31m{label}: {diffs} difference(s)\033[0m")
+ return diffs
+
+
+def dump_request(cap):
+ src = cap.get("source", "?")
+ print(f" Source: {src}")
+ print(f" URL: {cap['url']}")
+ headers = normalize_headers(cap.get("headers", {}))
+ body = cap.get("body", {})
+ print(f" Headers ({len(headers)}):")
+ for k in sorted(headers):
+ print(f" {k}: {headers[k]}")
+ if isinstance(body, dict):
+ print(f" Body keys: {sorted(body.keys())}")
+ print(f" model: {body.get('model', 'N/A')}")
+ print(f" project: {body.get('project', 'N/A')}")
+ req = body.get("request", {})
+ if isinstance(req, dict):
+ print(f" generationConfig: {req.get('generationConfig', 'N/A')}")
+
+
+def main():
+ caps = load_captures(CAPTURES_DIR)
+ if not caps:
+ print(f"No generateContent captures found in {CAPTURES_DIR}")
+ print("Run gemini-cli and gemini-py through the proxy first.")
+ sys.exit(1)
+
+ by_source = {}
+ for c in caps:
+ src = c.get("source", "unknown")
+ by_source.setdefault(src, []).append(c)
+
+ sources = {k: len(v) for k, v in by_source.items()}
+ print(f"Found {len(caps)} generateContent request(s)")
+ print(f"Sources: {sources}\n")
+
+ for i, cap in enumerate(caps):
+ print(f"--- Request #{cap.get('capture_id', i + 1)} ---")
+ dump_request(cap)
+ print()
+
+ sources = list(by_source.keys())
+ if len(sources) == 2 and all(len(v) >= 1 for v in by_source.values()):
+ n1, n2 = sources
+
+ def pick(caps):
+ # Prefer streamGenerateContent (actual response) over routing calls
+ streaming = [c for c in caps if "streamgeneratecontent" in c.get("url", "").lower()]
+ return streaming[-1] if streaming else caps[-1]
+
+ c1, c2 = pick(by_source[n1]), pick(by_source[n2])
+
+ print("=" * 60)
+ print(f"COMPARING: {n1} vs {n2}")
+ print("=" * 60)
+
+ h1 = normalize_headers(c1.get("headers", {}))
+ h2 = normalize_headers(c2.get("headers", {}))
+ total = compare_dicts("Headers", h1, h2, n1, n2)
+
+ b1 = normalize_body(c1.get("body", {}))
+ b2 = normalize_body(c2.get("body", {}))
+ total += compare_dicts("Body", b1, b2, n1, n2)
+
+ print(f"\n{'=' * 60}")
+ if total == 0:
+ print("\033[32mRESULT: Requests are identical on the wire.\033[0m")
+ else:
+ print(f"\033[31mRESULT: {total} total difference(s).\033[0m")
+ sys.exit(0 if total == 0 else 1)
+ else:
+ print("Tip: Clear test/captures/ and run one request from each client to compare.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/scripts/send_test.py b/test/scripts/send_test.py
new file mode 100644
index 0000000..2b5fe53
--- /dev/null
+++ b/test/scripts/send_test.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+import asyncio
+import os
+
+from gemini import GeminiClient, GeminiOptions
+
+PROMPT = os.environ.get("TEST_PROMPT", "Say hi in exactly 3 words")
+
+CONFIGS = [
+ {"label": "flash-lite", "model": "gemini-2.5-flash-lite"},
+ {"label": "flash", "model": "gemini-2.5-flash"},
+ {"label": "pro", "model": "gemini-2.5-pro"},
+]
+
+
+async def run_one(cfg):
+ opts = GeminiOptions(model=cfg["model"])
+ async with GeminiClient(options=opts) as client:
+ text = ""
+ async for chunk in client.send_message_stream(PROMPT):
+ text += chunk.text_delta
+ print(f" [{cfg['label']}] {text.strip()[:80]}")
+
+
+async def main():
+ mode = os.environ.get("TEST_MODE", "single")
+ configs = CONFIGS if mode == "all" else [CONFIGS[0]]
+
+ for cfg in configs:
+ print(f">>> {cfg['label']}: model={cfg['model']}")
+ await run_one(cfg)
+ print()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/test/scripts/tool_test.py b/test/scripts/tool_test.py
new file mode 100644
index 0000000..2e5d6c9
--- /dev/null
+++ b/test/scripts/tool_test.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+import asyncio
+import json
+
+from gemini import FunctionDeclaration, GeminiClient, GeminiOptions
+
+TOOLS = [
+ FunctionDeclaration(
+ name="get_weather",
+ description="Get the current weather for a city",
+ parameters={
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "City name, e.g. Paris, Tokyo",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "Temperature unit (default: celsius)",
+ },
+ },
+ "required": ["location"],
+ },
+ ),
+]
+
+FAKE_WEATHER = {
+ "Paris": {"temperature": 18, "condition": "cloudy"},
+ "London": {"temperature": 14, "condition": "rainy"},
+ "Tokyo": {"temperature": 26, "condition": "sunny"},
+}
+
+
+def handle_tool(name: str, args: dict) -> dict:
+ if name == "get_weather":
+ city = args.get("location", "")
+ unit = args.get("unit", "celsius")
+ data = FAKE_WEATHER.get(city, {"temperature": 20, "condition": "unknown"})
+ temp = data["temperature"]
+ if unit == "fahrenheit":
+ temp = round(temp * 9 / 5 + 32)
+ return {"location": city, "temperature": temp, "unit": unit, "condition": data["condition"]}
+ raise ValueError(f"Unknown tool: {name}")
+
+
+async def run_non_streaming():
+ print("=== Non-streaming tool use ===")
+ opts = GeminiOptions(model="gemini-2.5-flash", tools=TOOLS)
+ async with GeminiClient(options=opts) as client:
+ prompt = "What's the weather in Paris and Tokyo? Use celsius."
+ print(f"User: {prompt}")
+ response = await client.send_message(prompt)
+
+ while response.tool_calls:
+ results = []
+ for tc in response.tool_calls:
+ print(f" [tool call] {tc.name}({json.dumps(tc.args)})")
+ result = handle_tool(tc.name, tc.args)
+ print(f" [tool result] {json.dumps(result)}")
+ results.append((tc.name, result))
+ response = await client.send_tool_results(results)
+
+ print(f"Model: {response.text.strip()}")
+
+
+async def run_streaming():
+ print("\n=== Streaming tool use ===")
+ opts = GeminiOptions(model="gemini-2.5-flash", tools=TOOLS)
+ async with GeminiClient(options=opts) as client:
+ prompt = "What's the weather in London right now?"
+ print(f"User: {prompt}")
+
+ final = None
+ async for chunk in client.send_message_stream(prompt):
+ final = chunk
+
+ if final and final.tool_calls:
+ results = []
+ for tc in final.tool_calls:
+ print(f" [tool call] {tc.name}({json.dumps(tc.args)})")
+ result = handle_tool(tc.name, tc.args)
+ print(f" [tool result] {json.dumps(result)}")
+ results.append((tc.name, result))
+ print("Model: ", end="", flush=True)
+ async for chunk in client.send_tool_results_stream(results):
+ print(chunk.text_delta, end="", flush=True)
+ print()
+ elif final:
+ print(f"Model: {final.response.text.strip() if final.response else ''}")
+
+
+async def main():
+ await run_non_streaming()
+ await run_streaming()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())