aboutsummaryrefslogtreecommitdiffstats
path: root/src/gemini/client.py
diff options
context:
space:
mode:
authorClaude <claude@anthropic.com>2026-03-04 19:14:55 +0100
committerClaude <claude@anthropic.com>2026-03-04 19:14:55 +0100
commit171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch)
tree2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /src/gemini/client.py
parent9f14edf2b97286e02830d528038b32d5b31aaa0a (diff)
parent0278c87f062a9ae7d617b92be22b175558a05086 (diff)
downloadgemini-py-main.tar.gz
gemini-py-main.zip
Add initial versionHEADmain
Diffstat (limited to 'src/gemini/client.py')
-rw-r--r--src/gemini/client.py285
1 files changed, 285 insertions, 0 deletions
diff --git a/src/gemini/client.py b/src/gemini/client.py
new file mode 100644
index 0000000..3e8dda1
--- /dev/null
+++ b/src/gemini/client.py
@@ -0,0 +1,285 @@
+import asyncio
+import os
+import re
+import uuid
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from .auth import OAuthCredentials
+from .streaming import parse_sse_stream
+from .types import (
+ GeminiOptions,
+ GenerateContentResponse,
+ StreamChunk,
+)
+
BASE_URL = "https://cloudcode-pa.googleapis.com"  # Code Assist API host used by Gemini CLI
API_VERSION = "v1internal"  # internal API surface; methods are RPC-style (host/version:method)
CLI_VERSION = "0.31.0"  # Gemini CLI version reported in the User-Agent header
NODE_VERSION = "20.20.0"  # Node.js version reported in the x-goog-api-client header
+
+
class GeminiClient:
    """
    Async Python client for Gemini API via Gemini CLI's Code Assist endpoint.

    Reverse-engineered from Gemini CLI v0.31.0.
    Requires OAuth credentials from `~/.gemini/oauth_creds.json`
    (created by logging into Gemini CLI).

    Example:
        async with GeminiClient() as client:
            # Streaming
            async for chunk in client.send_message_stream("Hello"):
                print(chunk.text_delta, end="", flush=True)

            # Non-streaming
            response = await client.send_message("What is 2+2?")
            print(response.text)
    """

    def __init__(
        self,
        options: GeminiOptions | None = None,
        credentials_path: str | None = None,
    ):
        """
        Args:
            options: Generation/session options; defaults to GeminiOptions().
            credentials_path: Overrides options.credentials_path when given.
        """
        self.options = options or GeminiOptions()
        self._creds = OAuthCredentials(credentials_path or self.options.credentials_path)
        # Generous timeout: long generations (especially with thinking) are slow.
        self._http = httpx.AsyncClient(http2=True, timeout=300.0)
        # Resolved lazily via loadCodeAssist on first request.
        self._project_id: str | None = None
        self._session_id = self.options.session_id or str(uuid.uuid4())
        # Conversation history in the API's `contents` format.
        self._messages: list[dict[str, Any]] = []

    async def _token(self) -> str:
        """Return a valid OAuth access token (refreshing if necessary)."""
        return await self._creds.get_valid_token(self._http)

    def _headers(self, token: str) -> dict[str, str]:
        """Build request headers mimicking Gemini CLI's HTTP fingerprint."""
        headers = {
            "accept": "application/json",
            "accept-encoding": "gzip, deflate, br",
            "authorization": f"Bearer {token}",
            "content-type": "application/json",
            "user-agent": (
                f"GeminiCLI/{CLI_VERSION}/{self.options.model} "
                f"(linux; x64) google-api-nodejs-client/10.6.1"
            ),
            "x-goog-api-client": f"gl-node/{NODE_VERSION}",
        }
        # Optional passthrough tag for traffic-capture tooling.
        capture_source = os.environ.get("CAPTURE_SOURCE")
        if capture_source:
            headers["x-capture-source"] = capture_source
        return headers

    def _url(self, method: str) -> str:
        """Full endpoint URL for an RPC-style `method` name."""
        return f"{BASE_URL}/{API_VERSION}:{method}"

    async def _ensure_project(self) -> str:
        """Resolve and cache the Code Assist project ID (one-time call).

        Raises:
            ValueError: if loadCodeAssist returns no project ID.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        if self._project_id:
            return self._project_id
        token = await self._token()
        resp = await self._http.post(
            self._url("loadCodeAssist"),
            headers=self._headers(token),
            json={
                "metadata": {
                    "ideType": "IDE_UNSPECIFIED",
                    "platform": "PLATFORM_UNSPECIFIED",
                    "pluginType": "GEMINI",
                }
            },
        )
        resp.raise_for_status()
        data = resp.json()
        self._project_id = data.get("cloudaicompanionProject", "")
        if not self._project_id:
            raise ValueError(
                "Could not determine project ID from loadCodeAssist. "
                "Ensure your Google account has Gemini Code Assist access."
            )
        return self._project_id

    def _generation_config(self) -> dict[str, Any]:
        """Translate GeminiOptions into the API's generationConfig dict."""
        cfg: dict[str, Any] = {
            "temperature": self.options.temperature,
            "topP": self.options.top_p,
            "topK": self.options.top_k,
            "maxOutputTokens": self.options.max_output_tokens,
        }
        if self.options.thinking_budget is not None:
            cfg["thinkingConfig"] = {
                "includeThoughts": True,
                "thinkingBudget": self.options.thinking_budget,
            }
        return cfg

    def _build_request_body(self, project: str) -> dict[str, Any]:
        """Assemble the full request body for (stream)generateContent."""
        inner: dict[str, Any] = {
            "contents": self._messages,
            "generationConfig": self._generation_config(),
            "session_id": self._session_id,
        }
        if self.options.system_prompt:
            inner["systemInstruction"] = {
                "role": "user",
                "parts": [{"text": self.options.system_prompt}],
            }
        if self.options.tools:
            inner["tools"] = [{"functionDeclarations": [t.to_api() for t in self.options.tools]}]
        return {
            "model": self.options.model,
            "project": project,
            "user_prompt_id": uuid.uuid4().hex[:13],
            "request": inner,
        }

    def _add_user_message(self, prompt: str):
        """Append a plain-text user turn to history."""
        self._messages.append({"role": "user", "parts": [{"text": prompt}]})

    def _add_function_responses(self, results: list[tuple]):
        """Append tool results as functionResponse parts in a user turn."""
        parts = [{"functionResponse": {"name": n, "response": r}} for n, r in results]
        self._messages.append({"role": "user", "parts": parts})

    def _add_function_response(self, name: str, response: Any):
        """Single-result convenience wrapper for _add_function_responses."""
        self._add_function_responses([(name, response)])

    def _add_assistant_message(self, parts: list[dict[str, Any]]):
        """Append a model turn (already in API part format) to history."""
        self._messages.append({"role": "model", "parts": parts})

    @staticmethod
    def _retry_delay(error_text: str) -> int:
        """Parse a server-suggested retry delay (e.g. "...30s...") out of a
        429 error body; default 30s, small buffer added, capped at 90s."""
        wait = 30
        m = re.search(r"(\d+)s", error_text)
        if m:
            wait = int(m.group(1)) + 2  # small buffer past the server's hint
        return min(wait, 90)

    async def _retry_post(
        self, url: str, headers: dict[str, str], body: dict[str, Any]
    ) -> httpx.Response:
        """POST with up to 3 retries on HTTP 429 (rate limit).

        Returns the last response without raising; the caller is expected to
        call raise_for_status(). Fix vs. original: no pointless sleep after
        the final attempt.
        """
        resp: httpx.Response | None = None
        for attempt in range(4):
            resp = await self._http.post(url, headers=headers, json=body)
            if resp.status_code != 429:
                return resp
            if attempt < 3:  # don't sleep when there is no retry left
                await asyncio.sleep(self._retry_delay(resp.text))
        return resp

    async def _execute_stream(self) -> AsyncIterator[StreamChunk]:
        """Run a streaming generateContent call, yielding chunks as they
        arrive and recording the complete model turn in history on success.

        Retries up to 3 times on HTTP 429; other HTTP errors propagate.
        """
        project = await self._ensure_project()
        token = await self._token()
        body = self._build_request_body(project)
        url = self._url("streamGenerateContent") + "?alt=sse"
        headers = {**self._headers(token), "accept": "*/*"}

        for attempt in range(4):
            try:
                all_parts: list[dict[str, Any]] = []
                async for chunk in parse_sse_stream(self._http, "POST", url, headers, body):
                    if chunk.response and chunk.response.candidates:
                        for candidate in chunk.response.candidates:
                            if candidate.content and candidate.content.parts:
                                all_parts.extend(candidate.content.parts)
                    yield chunk
                if all_parts:
                    self._add_assistant_message(all_parts)
                return
            except httpx.HTTPStatusError as e:
                if e.response.status_code != 429 or attempt == 3:
                    raise
                await asyncio.sleep(self._retry_delay(e.response.text))

    async def _execute(self) -> GenerateContentResponse:
        """Run a non-streaming generateContent call and record the model
        turn in history.

        Raises:
            httpx.HTTPStatusError: on a non-2xx final response.
        """
        project = await self._ensure_project()
        token = await self._token()
        body = self._build_request_body(project)

        resp = await self._retry_post(
            self._url("generateContent"),
            self._headers(token),
            body,
        )
        resp.raise_for_status()
        data = resp.json()

        # Local import keeps the private parser out of module scope.
        from .streaming import _parse_response

        result = _parse_response(data)
        if result.candidates:
            parts = []
            for c in result.candidates:
                # Guard `parts` too (matches _execute_stream): content may
                # exist with parts=None, and extend(None) would raise.
                if c.content and c.content.parts:
                    parts.extend(c.content.parts)
            # Only record a turn when there is actual content, consistent
            # with the streaming path.
            if parts:
                self._add_assistant_message(parts)
        return result

    async def send_message_stream(self, prompt: str) -> AsyncIterator[StreamChunk]:
        """Append `prompt` as a user turn and stream the model's reply."""
        self._add_user_message(prompt)
        async for chunk in self._execute_stream():
            yield chunk

    async def send_message(self, prompt: str) -> GenerateContentResponse:
        """Append `prompt` as a user turn and return the complete reply."""
        self._add_user_message(prompt)
        return await self._execute()

    async def send_tool_result(self, name: str, response: Any) -> GenerateContentResponse:
        """Send one tool result and return the model's follow-up."""
        return await self.send_tool_results([(name, response)])

    async def send_tool_result_stream(self, name: str, response: Any) -> AsyncIterator[StreamChunk]:
        """Send one tool result and stream the model's follow-up."""
        async for chunk in self.send_tool_results_stream([(name, response)]):
            yield chunk

    async def send_tool_results(self, results: list[tuple]) -> GenerateContentResponse:
        """Send multiple (name, response) tool results; return the reply."""
        self._add_function_responses(results)
        return await self._execute()

    async def send_tool_results_stream(self, results: list[tuple]) -> AsyncIterator[StreamChunk]:
        """Send multiple (name, response) tool results; stream the reply."""
        self._add_function_responses(results)
        async for chunk in self._execute_stream():
            yield chunk

    def clear_history(self):
        """Drop all conversation history (session ID is kept)."""
        self._messages = []

    @property
    def history(self) -> list[dict[str, Any]]:
        """Live (mutable) conversation history in API `contents` format."""
        return self._messages

    async def close(self):
        """Close the underlying HTTP client."""
        await self._http.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, *_):
        await self.close()
+
+
async def query(
    prompt: str,
    model: str = "gemini-2.5-pro",
    credentials_path: str | None = None,
    stream: bool = False,
) -> GenerateContentResponse:
    """
    One-shot query helper: creates a client, sends one prompt, closes.

    Args:
        prompt: The user message to send.
        model: Gemini model name to use.
        credentials_path: Override for the OAuth credentials file path.
        stream: If True, consume the SSE stream and return the final
            chunk's response; otherwise perform a single non-streaming call.

    Example:
        result = await query("Explain quantum computing in one sentence")
        print(result.text)
    """
    opts = GeminiOptions(model=model)
    async with GeminiClient(options=opts, credentials_path=credentials_path) as client:
        if not stream:
            return await client.send_message(prompt)

        # NOTE(review): only the LAST chunk's response is returned; earlier
        # chunks' parts are not aggregated into it. (The original collected
        # them into a list it never read — dead code, removed.) Use
        # GeminiClient.send_message_stream directly to observe every chunk.
        last: StreamChunk | None = None
        async for chunk in client.send_message_stream(prompt):
            last = chunk
        if last and last.response:
            return last.response
        return GenerateContentResponse(candidates=[])