diff options
| author | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
|---|---|---|
| committer | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
| commit | 171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch) | |
| tree | 2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /src/gemini/client.py | |
| parent | 9f14edf2b97286e02830d528038b32d5b31aaa0a (diff) | |
| parent | 0278c87f062a9ae7d617b92be22b175558a05086 (diff) | |
| download | gemini-py-main.tar.gz gemini-py-main.zip | |
Diffstat (limited to 'src/gemini/client.py')
| -rw-r--r-- | src/gemini/client.py | 285 |
1 file changed, 285 insertions, 0 deletions
import asyncio
import os
import re
import uuid
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .auth import OAuthCredentials
from .streaming import parse_sse_stream
from .types import (
    GeminiOptions,
    GenerateContentResponse,
    StreamChunk,
)

BASE_URL = "https://cloudcode-pa.googleapis.com"
API_VERSION = "v1internal"
CLI_VERSION = "0.31.0"
NODE_VERSION = "20.20.0"


class GeminiClient:
    """
    Async Python client for Gemini API via Gemini CLI's Code Assist endpoint.

    Reverse-engineered from Gemini CLI v0.31.0.
    Requires OAuth credentials from `~/.gemini/oauth_creds.json`
    (created by logging into Gemini CLI).

    Example:
        async with GeminiClient() as client:
            # Streaming
            async for chunk in client.send_message_stream("Hello"):
                print(chunk.text_delta, end="", flush=True)

            # Non-streaming
            response = await client.send_message("What is 2+2?")
            print(response.text)
    """

    def __init__(
        self,
        options: GeminiOptions | None = None,
        credentials_path: str | None = None,
    ):
        """Create a client.

        Args:
            options: Generation/model options; defaults to ``GeminiOptions()``.
            credentials_path: Override for the OAuth credentials file; falls
                back to ``options.credentials_path``.
        """
        self.options = options or GeminiOptions()
        self._creds = OAuthCredentials(credentials_path or self.options.credentials_path)
        # Long timeout: streamed generations can run for minutes.
        self._http = httpx.AsyncClient(http2=True, timeout=300.0)
        self._project_id: str | None = None  # lazily resolved via loadCodeAssist
        self._session_id = self.options.session_id or str(uuid.uuid4())
        # Conversation history in the API's `contents` wire format.
        self._messages: list[dict[str, Any]] = []

    async def _token(self) -> str:
        """Return a valid (refreshed if needed) OAuth access token."""
        return await self._creds.get_valid_token(self._http)

    def _headers(self, token: str) -> dict[str, str]:
        """Build the request headers the Code Assist endpoint expects.

        The user-agent mimics Gemini CLI so the endpoint accepts the call.
        """
        headers = {
            "accept": "application/json",
            "accept-encoding": "gzip, deflate, br",
            "authorization": f"Bearer {token}",
            "content-type": "application/json",
            "user-agent": (
                f"GeminiCLI/{CLI_VERSION}/{self.options.model} "
                f"(linux; x64) google-api-nodejs-client/10.6.1"
            ),
            "x-goog-api-client": f"gl-node/{NODE_VERSION}",
        }
        # Optional debugging tag, passed through from the environment.
        capture_source = os.environ.get("CAPTURE_SOURCE")
        if capture_source:
            headers["x-capture-source"] = capture_source
        return headers

    def _url(self, method: str) -> str:
        """Full endpoint URL for an RPC-style method name."""
        return f"{BASE_URL}/{API_VERSION}:{method}"

    async def _ensure_project(self) -> str:
        """Resolve and cache the Cloud AI Companion project ID.

        Calls ``loadCodeAssist`` once; subsequent calls return the cached ID.

        Raises:
            ValueError: if the response carries no project ID (account lacks
                Gemini Code Assist access).
        """
        if self._project_id:
            return self._project_id
        token = await self._token()
        resp = await self._http.post(
            self._url("loadCodeAssist"),
            headers=self._headers(token),
            json={
                "metadata": {
                    "ideType": "IDE_UNSPECIFIED",
                    "platform": "PLATFORM_UNSPECIFIED",
                    "pluginType": "GEMINI",
                }
            },
        )
        resp.raise_for_status()
        data = resp.json()
        self._project_id = data.get("cloudaicompanionProject", "")
        if not self._project_id:
            raise ValueError(
                "Could not determine project ID from loadCodeAssist. "
                "Ensure your Google account has Gemini Code Assist access."
            )
        return self._project_id

    def _generation_config(self) -> dict[str, Any]:
        """Translate ``GeminiOptions`` into the API's generationConfig dict."""
        cfg: dict[str, Any] = {
            "temperature": self.options.temperature,
            "topP": self.options.top_p,
            "topK": self.options.top_k,
            "maxOutputTokens": self.options.max_output_tokens,
        }
        if self.options.thinking_budget is not None:
            cfg["thinkingConfig"] = {
                "includeThoughts": True,
                "thinkingBudget": self.options.thinking_budget,
            }
        return cfg

    def _build_request_body(self, project: str) -> dict[str, Any]:
        """Assemble the full request body for generateContent/streamGenerateContent."""
        inner: dict[str, Any] = {
            "contents": self._messages,
            "generationConfig": self._generation_config(),
            "session_id": self._session_id,
        }
        if self.options.system_prompt:
            # The wire format carries the system prompt as a user-role part.
            inner["systemInstruction"] = {
                "role": "user",
                "parts": [{"text": self.options.system_prompt}],
            }
        if self.options.tools:
            inner["tools"] = [{"functionDeclarations": [t.to_api() for t in self.options.tools]}]
        return {
            "model": self.options.model,
            "project": project,
            # Matches the CLI's short opaque per-prompt ID.
            "user_prompt_id": uuid.uuid4().hex[:13],
            "request": inner,
        }

    def _add_user_message(self, prompt: str) -> None:
        """Append a plain-text user turn to the history."""
        self._messages.append({"role": "user", "parts": [{"text": prompt}]})

    def _add_function_responses(self, results: list[tuple[str, Any]]) -> None:
        """Append tool results as a single user turn of functionResponse parts.

        Args:
            results: ``(function_name, response)`` pairs.
        """
        parts = [{"functionResponse": {"name": n, "response": r}} for n, r in results]
        self._messages.append({"role": "user", "parts": parts})

    def _add_function_response(self, name: str, response: Any) -> None:
        """Append a single tool result (convenience wrapper)."""
        self._add_function_responses([(name, response)])

    def _add_assistant_message(self, parts: list[dict[str, Any]]) -> None:
        """Append a model turn (already in wire-format parts) to the history."""
        self._messages.append({"role": "model", "parts": parts})

    @staticmethod
    def _retry_wait_seconds(text: str) -> int:
        """Parse a server-suggested retry delay out of a 429 response body.

        Looks for the first ``<N>s`` token; adds 2s of slack; defaults to 30s
        when no hint is present; caps the wait at 90s.
        """
        wait = 30
        m = re.search(r"(\d+)s", text)
        if m:
            wait = int(m.group(1)) + 2
        return min(wait, 90)

    async def _retry_post(
        self, url: str, headers: dict[str, str], body: dict[str, Any]
    ) -> httpx.Response:
        """POST with up to 3 retries on HTTP 429, honoring the server's delay hint.

        Returns the first non-429 response, or the final 429 response after
        all attempts are exhausted (no sleep after the last attempt).
        """
        resp: httpx.Response | None = None
        for attempt in range(4):
            resp = await self._http.post(url, headers=headers, json=body)
            if resp.status_code != 429 or attempt == 3:
                return resp
            await asyncio.sleep(self._retry_wait_seconds(resp.text))
        assert resp is not None  # range(4) always executes; for type checkers
        return resp

    async def _execute_stream(self) -> AsyncIterator[StreamChunk]:
        """Run a streamGenerateContent call, yielding chunks as they arrive.

        On success, the accumulated model parts are committed to the history.
        Retries the whole request up to 3 times on HTTP 429.

        NOTE(review): if a 429 surfaces after some chunks were already
        yielded, the retry re-streams from the start, so a consumer may see
        duplicated chunks — confirm whether the server can 429 mid-stream.
        """
        project = await self._ensure_project()
        token = await self._token()
        body = self._build_request_body(project)
        url = self._url("streamGenerateContent") + "?alt=sse"
        headers = {**self._headers(token), "accept": "*/*"}

        for attempt in range(4):
            try:
                all_parts: list[dict[str, Any]] = []
                async for chunk in parse_sse_stream(self._http, "POST", url, headers, body):
                    if chunk.response and chunk.response.candidates:
                        for candidate in chunk.response.candidates:
                            if candidate.content and candidate.content.parts:
                                all_parts.extend(candidate.content.parts)
                    yield chunk
                if all_parts:
                    self._add_assistant_message(all_parts)
                return
            except httpx.HTTPStatusError as e:
                if e.response.status_code != 429 or attempt == 3:
                    raise
                await asyncio.sleep(self._retry_wait_seconds(e.response.text))

    async def _execute(self) -> GenerateContentResponse:
        """Run a non-streaming generateContent call and record the model turn."""
        project = await self._ensure_project()
        token = await self._token()
        body = self._build_request_body(project)

        resp = await self._retry_post(
            self._url("generateContent"),
            self._headers(token),
            body,
        )
        resp.raise_for_status()
        data = resp.json()

        # Local import mirrors the original module's deliberate deferral
        # (avoids an import cycle with .streaming at module load time).
        from .streaming import _parse_response

        result = _parse_response(data)
        if result.candidates:
            parts = []
            for c in result.candidates:
                if c.content:
                    parts.extend(c.content.parts)
            self._add_assistant_message(parts)
        return result

    async def send_message_stream(self, prompt: str) -> AsyncIterator[StreamChunk]:
        """Send a user message and stream the model's response chunks."""
        self._add_user_message(prompt)
        async for chunk in self._execute_stream():
            yield chunk

    async def send_message(self, prompt: str) -> GenerateContentResponse:
        """Send a user message and return the complete response."""
        self._add_user_message(prompt)
        return await self._execute()

    async def send_tool_result(self, name: str, response: Any) -> GenerateContentResponse:
        """Send one tool result and return the complete follow-up response."""
        return await self.send_tool_results([(name, response)])

    async def send_tool_result_stream(self, name: str, response: Any) -> AsyncIterator[StreamChunk]:
        """Send one tool result and stream the follow-up response."""
        async for chunk in self.send_tool_results_stream([(name, response)]):
            yield chunk

    async def send_tool_results(
        self, results: list[tuple[str, Any]]
    ) -> GenerateContentResponse:
        """Send multiple tool results and return the complete follow-up response."""
        self._add_function_responses(results)
        return await self._execute()

    async def send_tool_results_stream(
        self, results: list[tuple[str, Any]]
    ) -> AsyncIterator[StreamChunk]:
        """Send multiple tool results and stream the follow-up response."""
        self._add_function_responses(results)
        async for chunk in self._execute_stream():
            yield chunk

    def clear_history(self) -> None:
        """Forget all prior turns (the session ID is kept)."""
        self._messages = []

    @property
    def history(self) -> list[dict[str, Any]]:
        """The live conversation history (wire format; mutations affect the client)."""
        return self._messages

    async def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        await self._http.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *_):
        await self.close()


async def query(
    prompt: str,
    model: str = "gemini-2.5-pro",
    credentials_path: str | None = None,
    stream: bool = False,
) -> GenerateContentResponse:
    """
    One-shot query helper.

    Args:
        prompt: The user prompt.
        model: Model name to use.
        credentials_path: Optional override for the OAuth credentials file.
        stream: When True, consume the SSE stream and return the final chunk's
            response (or an empty response if nothing arrived). Note this is
            only the LAST chunk's payload, not an aggregate of all chunks.

    Example:
        result = await query("Explain quantum computing in one sentence")
        print(result.text)
    """
    opts = GeminiOptions(model=model)
    async with GeminiClient(options=opts, credentials_path=credentials_path) as client:
        if stream:
            last = None
            async for chunk in client.send_message_stream(prompt):
                last = chunk
            if last and last.response:
                return last.response
            return GenerateContentResponse(candidates=[])
        return await client.send_message(prompt)
