diff options
| author | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
|---|---|---|
| committer | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
| commit | 171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch) | |
| tree | 2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /src | |
| parent | 9f14edf2b97286e02830d528038b32d5b31aaa0a (diff) | |
| parent | 0278c87f062a9ae7d617b92be22b175558a05086 (diff) | |
| download | gemini-py-main.tar.gz gemini-py-main.zip | |
Diffstat (limited to 'src')
| -rw-r--r-- | src/gemini/__init__.py | 53 | ||||
| -rw-r--r-- | src/gemini/auth.py | 91 | ||||
| -rw-r--r-- | src/gemini/cli.py | 60 | ||||
| -rw-r--r-- | src/gemini/client.py | 285 | ||||
| -rw-r--r-- | src/gemini/models.py | 24 | ||||
| -rw-r--r-- | src/gemini/streaming.py | 86 | ||||
| -rw-r--r-- | src/gemini/types.py | 175 |
7 files changed, 774 insertions, 0 deletions
# Reconstructed source of the seven files added by this commit (the original
# page is a flattened cgit diff dump). Files are separated by path markers.

# =========================== src/gemini/__init__.py ===========================
"""
gemini-py: Python SDK for Gemini API.

Reverse-engineered from Gemini CLI v0.31.0. Uses OAuth credentials from
~/.gemini/oauth_creds.json (created by logging in via the Gemini CLI).

Example:
    import asyncio
    from gemini import GeminiClient

    async def main():
        async with GeminiClient() as client:
            # Streaming
            async for chunk in client.send_message_stream("Hello!"):
                print(chunk.text_delta, end="", flush=True)
            print()

            # Non-streaming
            response = await client.send_message("What is 2+2?")
            print(response.text)

    asyncio.run(main())
"""

from .client import GeminiClient, query
from .models import Model, list_models
from .types import (
    Content,
    FunctionDeclaration,
    GeminiOptions,
    GenerateContentResponse,
    GenerationConfig,
    StreamChunk,
    ToolCall,
    UsageMetadata,
)

__version__ = "0.1.0"

__all__ = [
    "GeminiClient",
    "query",
    "Model",
    "list_models",
    "GeminiOptions",
    "GenerateContentResponse",
    "StreamChunk",
    "Content",
    "GenerationConfig",
    "UsageMetadata",
    "FunctionDeclaration",
    "ToolCall",
]

# ============================= src/gemini/auth.py =============================
"""OAuth credential storage and refresh for the Gemini CLI token."""

import json
import os
import time
from pathlib import Path

import httpx

OAUTH_CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
OAUTH_CLIENT_SECRET = "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"
TOKEN_URL = "https://oauth2.googleapis.com/token"
DEFAULT_CREDS_PATH = Path.home() / ".gemini" / "oauth_creds.json"


class OAuthCredentials:
    """
    Resolves credentials in priority order:
      1. Environment variables (GEMINI_REFRESH_TOKEN, GEMINI_ACCESS_TOKEN, GEMINI_TOKEN_EXPIRY)
      2. Credentials file (path argument or ~/.gemini/oauth_creds.json)
    """

    def __init__(self, creds_path: str | None = None):
        self._path = Path(creds_path) if creds_path else DEFAULT_CREDS_PATH
        self._data: dict = {}
        self._from_env = False
        self._load()

    def _load(self):
        """Populate self._data from env vars (preferred) or the creds file."""
        refresh = os.environ.get("GEMINI_REFRESH_TOKEN")
        if refresh:
            self._from_env = True
            # Parse the expiry defensively: a malformed GEMINI_TOKEN_EXPIRY
            # should force a refresh, not crash client construction.
            try:
                expiry = int(os.environ.get("GEMINI_TOKEN_EXPIRY", "0"))
            except ValueError:
                expiry = 0
            self._data = {
                "refresh_token": refresh,
                "access_token": os.environ.get("GEMINI_ACCESS_TOKEN", ""),
                "expiry_date": expiry,
            }
            return

        if self._path.exists():
            self._data = json.loads(self._path.read_text())

    def _save(self):
        # Don't write back when credentials came from env vars
        if self._from_env:
            return
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._path.write_text(json.dumps(self._data, indent=2))
        os.chmod(self._path, 0o600)  # tokens are secrets: owner read/write only

    @property
    def refresh_token(self) -> str | None:
        return self._data.get("refresh_token")

    @property
    def access_token(self) -> str | None:
        return self._data.get("access_token")

    @property
    def expiry_ms(self) -> int:
        # Stored as epoch milliseconds, matching the Gemini CLI's creds file.
        return self._data.get("expiry_date", 0)

    def is_expired(self) -> bool:
        """True when there is no token or it expires within 5 minutes."""
        if not self.access_token:
            return True
        return time.time() * 1000 >= (self.expiry_ms - 300_000)

    async def get_valid_token(self, client: httpx.AsyncClient) -> str:
        """Return a usable access token, refreshing it first if needed."""
        if self.is_expired():
            await self._refresh(client)
        return self.access_token or ""

    async def _refresh(self, client: httpx.AsyncClient):
        """Exchange the refresh token for a new access token and persist it.

        Raises:
            ValueError: no refresh token is available.
            httpx.HTTPStatusError: the token endpoint rejected the request.
        """
        if not self.refresh_token:
            raise ValueError(
                "No refresh token found. Set GEMINI_REFRESH_TOKEN or provide a credentials file.\n"
                "Obtain it by running: gemini (and completing the login flow)"
            )
        resp = await client.post(
            TOKEN_URL,
            data={
                "refresh_token": self.refresh_token,
                "client_id": OAUTH_CLIENT_ID,
                "client_secret": OAUTH_CLIENT_SECRET,
                "grant_type": "refresh_token",
            },
        )
        resp.raise_for_status()
        tokens = resp.json()
        self._data.update(tokens)
        if "expires_in" in tokens:
            # Convert relative seconds to absolute epoch-ms expiry.
            self._data["expiry_date"] = int((time.time() + tokens["expires_in"]) * 1000)
        self._save()

# ============================== src/gemini/cli.py ==============================
"""Minimal command-line front end for GeminiClient."""

import argparse
import asyncio
import sys

from .client import GeminiClient
from .models import list_models
from .types import GeminiOptions


def parse_args():
    p = argparse.ArgumentParser(prog="gemini", description="Query Gemini via CLI")
    p.add_argument("prompt", nargs="?", help="Prompt to send (reads stdin if omitted)")
    p.add_argument("-m", "--model", default="gemini-2.5-pro", help="Model name")
    p.add_argument("-c", "--credentials", default=None, help="Path to oauth_creds.json")
    p.add_argument("--no-stream", action="store_true", help="Non-streaming mode")
    p.add_argument(
        "--thinking",
        type=int,
        default=None,
        metavar="BUDGET",
        help="Enable thinking mode with given token budget",
    )
    p.add_argument("--list-models", action="store_true", help="List available models and exit")
    return p.parse_args()


async def run(args):
    """Send the prompt (arg or stdin) and print the response."""
    prompt = args.prompt or sys.stdin.read().strip()
    if not prompt:
        print("Error: no prompt provided", file=sys.stderr)
        sys.exit(1)

    opts = GeminiOptions(
        model=args.model,
        thinking_budget=args.thinking,
    )

    async with GeminiClient(options=opts, credentials_path=args.credentials) as client:
        if args.no_stream:
            response = await client.send_message(prompt)
            print(response.text)
        else:
            async for chunk in client.send_message_stream(prompt):
                if chunk.text_delta:
                    print(chunk.text_delta, end="", flush=True)
            print()


def main():
    args = parse_args()
    if args.list_models:
        for m in list_models():
            tag = " [default]" if m.is_default else " [preview]" if m.is_preview else ""
            print(f"{m.name}{tag}")
        return
    asyncio.run(run(args))


if __name__ == "__main__":
    main()

# ============================ src/gemini/client.py ============================
"""Async client for the Gemini API via the Code Assist endpoint."""

import asyncio
import os
import re
import uuid
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .auth import OAuthCredentials
from .streaming import _parse_response, parse_sse_stream
from .types import (
    Candidate,
    Content,
    GeminiOptions,
    GenerateContentResponse,
    StreamChunk,
)

BASE_URL = "https://cloudcode-pa.googleapis.com"
API_VERSION = "v1internal"
CLI_VERSION = "0.31.0"
NODE_VERSION = "20.20.0"


class GeminiClient:
    """
    Async Python client for Gemini API via Gemini CLI's Code Assist endpoint.

    Reverse-engineered from Gemini CLI v0.31.0.
    Requires OAuth credentials from `~/.gemini/oauth_creds.json`
    (created by logging into Gemini CLI).

    Example:
        async with GeminiClient() as client:
            # Streaming
            async for chunk in client.send_message_stream("Hello"):
                print(chunk.text_delta, end="", flush=True)

            # Non-streaming
            response = await client.send_message("What is 2+2?")
            print(response.text)
    """

    def __init__(
        self,
        options: GeminiOptions | None = None,
        credentials_path: str | None = None,
    ):
        self.options = options or GeminiOptions()
        self._creds = OAuthCredentials(credentials_path or self.options.credentials_path)
        self._http = httpx.AsyncClient(http2=True, timeout=300.0)
        self._project_id: str | None = None
        self._session_id = self.options.session_id or str(uuid.uuid4())
        self._messages: list[dict[str, Any]] = []

    async def _token(self) -> str:
        return await self._creds.get_valid_token(self._http)

    def _headers(self, token: str) -> dict[str, str]:
        """Build request headers mimicking the Gemini CLI's HTTP fingerprint."""
        headers = {
            "accept": "application/json",
            "accept-encoding": "gzip, deflate, br",
            "authorization": f"Bearer {token}",
            "content-type": "application/json",
            "user-agent": (
                f"GeminiCLI/{CLI_VERSION}/{self.options.model} "
                f"(linux; x64) google-api-nodejs-client/10.6.1"
            ),
            "x-goog-api-client": f"gl-node/{NODE_VERSION}",
        }
        capture_source = os.environ.get("CAPTURE_SOURCE")
        if capture_source:
            headers["x-capture-source"] = capture_source
        return headers

    def _url(self, method: str) -> str:
        return f"{BASE_URL}/{API_VERSION}:{method}"

    async def _ensure_project(self) -> str:
        """Resolve (and cache) the Cloud AI Companion project ID via loadCodeAssist."""
        if self._project_id:
            return self._project_id
        token = await self._token()
        resp = await self._http.post(
            self._url("loadCodeAssist"),
            headers=self._headers(token),
            json={
                "metadata": {
                    "ideType": "IDE_UNSPECIFIED",
                    "platform": "PLATFORM_UNSPECIFIED",
                    "pluginType": "GEMINI",
                }
            },
        )
        resp.raise_for_status()
        data = resp.json()
        self._project_id = data.get("cloudaicompanionProject", "")
        if not self._project_id:
            raise ValueError(
                "Could not determine project ID from loadCodeAssist. "
                "Ensure your Google account has Gemini Code Assist access."
            )
        return self._project_id

    def _generation_config(self) -> dict[str, Any]:
        cfg: dict[str, Any] = {
            "temperature": self.options.temperature,
            "topP": self.options.top_p,
            "topK": self.options.top_k,
            "maxOutputTokens": self.options.max_output_tokens,
        }
        if self.options.thinking_budget is not None:
            cfg["thinkingConfig"] = {
                "includeThoughts": True,
                "thinkingBudget": self.options.thinking_budget,
            }
        return cfg

    def _build_request_body(self, project: str) -> dict[str, Any]:
        inner: dict[str, Any] = {
            "contents": self._messages,
            "generationConfig": self._generation_config(),
            "session_id": self._session_id,
        }
        if self.options.system_prompt:
            inner["systemInstruction"] = {
                "role": "user",
                "parts": [{"text": self.options.system_prompt}],
            }
        if self.options.tools:
            inner["tools"] = [{"functionDeclarations": [t.to_api() for t in self.options.tools]}]
        return {
            "model": self.options.model,
            "project": project,
            "user_prompt_id": uuid.uuid4().hex[:13],
            "request": inner,
        }

    def _add_user_message(self, prompt: str):
        self._messages.append({"role": "user", "parts": [{"text": prompt}]})

    def _add_function_responses(self, results: list[tuple]):
        parts = [{"functionResponse": {"name": n, "response": r}} for n, r in results]
        self._messages.append({"role": "user", "parts": parts})

    def _add_function_response(self, name: str, response: Any):
        self._add_function_responses([(name, response)])

    def _add_assistant_message(self, parts: list[dict[str, Any]]):
        self._messages.append({"role": "model", "parts": parts})

    @staticmethod
    def _backoff_seconds(text: str) -> int:
        """Parse a 'retry in Ns' hint from a 429 body; default 30s, cap 90s."""
        wait = 30
        m = re.search(r"(\d+)s", text)
        if m:
            wait = int(m.group(1)) + 2
        return min(wait, 90)

    async def _retry_post(
        self, url: str, headers: dict[str, str], body: dict[str, Any]
    ) -> httpx.Response:
        """POST with up to 4 attempts, backing off on 429 rate limits."""
        resp = None
        for _attempt in range(4):
            resp = await self._http.post(url, headers=headers, json=body)
            if resp.status_code != 429:
                return resp
            await asyncio.sleep(self._backoff_seconds(resp.text))
        return resp

    async def _execute_stream(self) -> AsyncIterator[StreamChunk]:
        """Stream the current conversation; record the model reply in history."""
        project = await self._ensure_project()
        token = await self._token()
        body = self._build_request_body(project)
        url = self._url("streamGenerateContent") + "?alt=sse"
        headers = {**self._headers(token), "accept": "*/*"}

        for attempt in range(4):
            yielded = False
            try:
                all_parts: list[dict[str, Any]] = []
                async for chunk in parse_sse_stream(self._http, "POST", url, headers, body):
                    if chunk.response and chunk.response.candidates:
                        for candidate in chunk.response.candidates:
                            if candidate.content and candidate.content.parts:
                                all_parts.extend(candidate.content.parts)
                    yielded = True
                    yield chunk
                if all_parts:
                    self._add_assistant_message(all_parts)
                return
            except httpx.HTTPStatusError as e:
                # Only retry rate limits, and only before anything was yielded:
                # re-sending after a partial stream would duplicate output.
                if yielded or e.response.status_code != 429 or attempt == 3:
                    raise
                await asyncio.sleep(self._backoff_seconds(e.response.text))

    async def _execute(self) -> GenerateContentResponse:
        """Non-streaming request; records the model reply in history."""
        project = await self._ensure_project()
        token = await self._token()
        body = self._build_request_body(project)

        resp = await self._retry_post(
            self._url("generateContent"),
            self._headers(token),
            body,
        )
        resp.raise_for_status()
        data = resp.json()

        result = _parse_response(data)
        if result.candidates:
            parts = []
            for c in result.candidates:
                if c.content:
                    parts.extend(c.content.parts)
            self._add_assistant_message(parts)
        return result

    async def send_message_stream(self, prompt: str) -> AsyncIterator[StreamChunk]:
        self._add_user_message(prompt)
        async for chunk in self._execute_stream():
            yield chunk

    async def send_message(self, prompt: str) -> GenerateContentResponse:
        self._add_user_message(prompt)
        return await self._execute()

    async def send_tool_result(self, name: str, response: Any) -> GenerateContentResponse:
        return await self.send_tool_results([(name, response)])

    async def send_tool_result_stream(self, name: str, response: Any) -> AsyncIterator[StreamChunk]:
        async for chunk in self.send_tool_results_stream([(name, response)]):
            yield chunk

    async def send_tool_results(self, results: list[tuple]) -> GenerateContentResponse:
        self._add_function_responses(results)
        return await self._execute()

    async def send_tool_results_stream(self, results: list[tuple]) -> AsyncIterator[StreamChunk]:
        self._add_function_responses(results)
        async for chunk in self._execute_stream():
            yield chunk

    def clear_history(self):
        self._messages = []

    @property
    def history(self) -> list[dict[str, Any]]:
        return self._messages

    async def close(self):
        await self._http.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *_):
        await self.close()


async def query(
    prompt: str,
    model: str = "gemini-2.5-pro",
    credentials_path: str | None = None,
    stream: bool = False,
) -> GenerateContentResponse:
    """
    One-shot query helper.

    Example:
        result = await query("Explain quantum computing in one sentence")
        print(result.text)
    """
    opts = GeminiOptions(model=model)
    async with GeminiClient(options=opts, credentials_path=credentials_path) as client:
        if stream:
            all_parts: list[dict[str, Any]] = []
            last = None
            async for chunk in client.send_message_stream(prompt):
                last = chunk
                if chunk.response and chunk.response.candidates:
                    for c in chunk.response.candidates:
                        if c.content:
                            all_parts.extend(c.content.parts)
            if last and last.response:
                # Aggregate every streamed part so .text holds the full reply,
                # not just the final chunk's delta.
                return GenerateContentResponse(
                    candidates=[Candidate(content=Content(role="model", parts=all_parts))],
                    usage_metadata=last.response.usage_metadata,
                    model_version=last.response.model_version,
                    response_id=last.response.response_id,
                )
            return GenerateContentResponse(candidates=[])
        return await client.send_message(prompt)

# ============================ src/gemini/models.py ============================
from dataclasses import dataclass


@dataclass(frozen=True)
class Model:
    name: str
    is_default: bool = False
    is_preview: bool = False


# Sourced from @google/gemini-cli-core dist/src/config/models.js (VALID_GEMINI_MODELS)
_MODELS: tuple[Model, ...] = (
    Model("gemini-2.5-pro", is_default=True),
    Model("gemini-2.5-flash"),
    Model("gemini-2.5-flash-lite"),
    Model("gemini-3-pro-preview", is_preview=True),
    Model("gemini-3-flash-preview", is_preview=True),
    Model("gemini-3.1-pro-preview", is_preview=True),
    Model("gemini-3.1-pro-preview-customtools", is_preview=True),
)


def list_models() -> list[Model]:
    return list(_MODELS)

# =========================== src/gemini/streaming.py ===========================
"""SSE stream parsing and response deserialization."""

import json
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .types import (
    Candidate,
    Content,
    GenerateContentResponse,
    StreamChunk,
    UsageMetadata,
)


def _parse_candidate(c: dict[str, Any]) -> dict[str, Any]:
    """Normalize one raw candidate dict (camelCase API keys -> snake_case)."""
    out: dict[str, Any] = {}
    content = c.get("content", {})
    if content:
        out["content"] = {
            "role": content.get("role", "model"),
            "parts": content.get("parts", []),
        }
    if "finishReason" in c:
        out["finish_reason"] = c["finishReason"]
    if "index" in c:
        out["index"] = c["index"]
    return out


def _parse_response(data: dict[str, Any]) -> GenerateContentResponse:
    """Parse an API payload (optionally wrapped in a 'response' envelope)."""
    inner = data.get("response", data)
    candidates = []
    for c in inner.get("candidates", []):
        parsed = _parse_candidate(c)
        content_data = parsed.get("content")
        content = Content(**content_data) if content_data else None
        candidates.append(
            Candidate(
                content=content,
                finish_reason=parsed.get("finish_reason"),
                index=parsed.get("index"),
            )
        )

    usage_raw = inner.get("usageMetadata", {})
    usage = (
        UsageMetadata(
            prompt_token_count=usage_raw.get("promptTokenCount"),
            candidates_token_count=usage_raw.get("candidatesTokenCount"),
            total_token_count=usage_raw.get("totalTokenCount"),
            thoughts_token_count=usage_raw.get("thoughtsTokenCount"),
        )
        if usage_raw
        else None
    )

    return GenerateContentResponse(
        candidates=candidates,
        usage_metadata=usage,
        model_version=inner.get("modelVersion"),
        response_id=data.get("traceId"),
    )


async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """Yield a StreamChunk per SSE event ('data:' lines terminated by a blank line)."""

    def _decode(lines: list[str]) -> StreamChunk | None:
        # Malformed/partial JSON events are skipped rather than aborting the stream.
        try:
            data = json.loads("\n".join(lines))
        except json.JSONDecodeError:
            return None
        return StreamChunk(response=_parse_response(data), trace_id=data.get("traceId"), raw=data)

    async with client.stream(method, url, headers=headers, json=body) as resp:
        resp.raise_for_status()
        buffer: list[str] = []
        async for line in resp.aiter_lines():
            if line.startswith("data: "):
                buffer.append(line[6:].strip())
            elif line == "" and buffer:
                chunk = _decode(buffer)
                buffer = []
                if chunk is not None:
                    yield chunk
        # Flush a final event that wasn't followed by a blank line before EOF.
        if buffer:
            chunk = _decode(buffer)
            if chunk is not None:
                yield chunk

# ============================= src/gemini/types.py =============================
"""Pydantic models for requests, responses, and client options."""

from typing import Any

from pydantic import BaseModel


class Part(BaseModel):
    text: str | None = None
    thought: bool | None = None
    thought_signature: str | None = None
    function_call: dict[str, Any] | None = None
    function_response: dict[str, Any] | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}


class Content(BaseModel):
    role: str
    parts: list[dict[str, Any]]

    model_config = {"extra": "allow"}


class GenerationConfig(BaseModel):
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    candidate_count: int | None = None
    max_output_tokens: int | None = None
    stop_sequences: list[str] | None = None
    response_mime_type: str | None = None
    thinking_config: dict[str, Any] | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}

    def to_api(self) -> dict[str, Any]:
        """Serialize set fields to the API's camelCase wire format."""
        out: dict[str, Any] = {}
        if self.temperature is not None:
            out["temperature"] = self.temperature
        if self.top_p is not None:
            out["topP"] = self.top_p
        if self.top_k is not None:
            out["topK"] = self.top_k
        if self.candidate_count is not None:
            out["candidateCount"] = self.candidate_count
        if self.max_output_tokens is not None:
            out["maxOutputTokens"] = self.max_output_tokens
        if self.stop_sequences is not None:
            out["stopSequences"] = self.stop_sequences
        if self.response_mime_type is not None:
            out["responseMimeType"] = self.response_mime_type
        if self.thinking_config is not None:
            out["thinkingConfig"] = self.thinking_config
        return out


class UsageMetadata(BaseModel):
    prompt_token_count: int | None = None
    candidates_token_count: int | None = None
    total_token_count: int | None = None
    thoughts_token_count: int | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}


class Candidate(BaseModel):
    content: Content | None = None
    finish_reason: str | None = None
    index: int | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}


class GenerateContentResponse(BaseModel):
    candidates: list[Candidate] | None = None
    usage_metadata: UsageMetadata | None = None
    model_version: str | None = None
    response_id: str | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}

    @property
    def text(self) -> str:
        """Concatenated visible text from all candidates (thought parts excluded)."""
        if not self.candidates:
            return ""
        parts = []
        for candidate in self.candidates:
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if isinstance(part, dict):
                        if part.get("thought"):
                            continue
                        t = part.get("text")
                        if t:
                            parts.append(t)
        return "".join(parts)

    @property
    def tool_calls(self) -> "list[ToolCall]":
        """All functionCall parts across candidates, as ToolCall objects."""
        calls: list[ToolCall] = []
        if not self.candidates:
            return calls
        for c in self.candidates:
            if c.content and c.content.parts:
                for part in c.content.parts:
                    if isinstance(part, dict) and "functionCall" in part:
                        fc = part["functionCall"]
                        calls.append(ToolCall(name=fc["name"], args=fc.get("args", {})))
        return calls

    @property
    def thinking(self) -> str:
        """Concatenated text of thought parts only."""
        if not self.candidates:
            return ""
        parts = []
        for candidate in self.candidates:
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if isinstance(part, dict) and part.get("thought"):
                        t = part.get("text", "")
                        if t:
                            parts.append(t)
        return "".join(parts)


class StreamChunk(BaseModel):
    response: GenerateContentResponse | None = None
    trace_id: str | None = None
    raw: dict[str, Any] = {}  # pydantic deep-copies mutable defaults per instance

    @property
    def text_delta(self) -> str:
        if self.response:
            return self.response.text
        return ""

    @property
    def tool_calls(self) -> "list[ToolCall]":
        if self.response:
            return self.response.tool_calls
        return []


class ToolCall(BaseModel):
    name: str
    args: dict[str, Any] = {}


class FunctionDeclaration(BaseModel):
    name: str
    description: str = ""
    parameters: dict[str, Any] | None = None

    model_config = {"extra": "allow"}

    def to_api(self) -> dict[str, Any]:
        out: dict[str, Any] = {"name": self.name}
        if self.description:
            out["description"] = self.description
        if self.parameters is not None:
            out["parameters"] = self.parameters
        return out


class GeminiOptions(BaseModel):
    model: str = "gemini-2.5-pro"
    max_output_tokens: int = 32768
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 64
    thinking_budget: int | None = None
    stream: bool = True
    system_prompt: str | None = None
    session_id: str | None = None
    credentials_path: str | None = None
    tools: list[FunctionDeclaration] | None = None
