aboutsummaryrefslogtreecommitdiffstats
path: root/src/gemini/streaming.py
diff options
context:
space:
mode:
authorClaude <claude@anthropic.com>2026-03-04 19:14:55 +0100
committerClaude <claude@anthropic.com>2026-03-04 19:14:55 +0100
commit171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch)
tree2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /src/gemini/streaming.py
parent9f14edf2b97286e02830d528038b32d5b31aaa0a (diff)
parent0278c87f062a9ae7d617b92be22b175558a05086 (diff)
downloadgemini-py-main.tar.gz
gemini-py-main.zip
Add initial versionHEADmain
Diffstat (limited to 'src/gemini/streaming.py')
-rw-r--r--src/gemini/streaming.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/src/gemini/streaming.py b/src/gemini/streaming.py
new file mode 100644
index 0000000..1b83b4d
--- /dev/null
+++ b/src/gemini/streaming.py
@@ -0,0 +1,86 @@
+import json
+from collections.abc import AsyncIterator
+from typing import Any
+
+import httpx
+
+from .types import GenerateContentResponse, StreamChunk
+
+
+def _parse_candidate(c: dict[str, Any]) -> dict[str, Any]:
+ out: dict[str, Any] = {}
+ content = c.get("content", {})
+ if content:
+ out["content"] = {
+ "role": content.get("role", "model"),
+ "parts": content.get("parts", []),
+ }
+ if "finishReason" in c:
+ out["finish_reason"] = c["finishReason"]
+ if "index" in c:
+ out["index"] = c["index"]
+ return out
+
+
def _parse_response(data: dict[str, Any]) -> GenerateContentResponse:
    """Convert a decoded JSON payload into a GenerateContentResponse.

    Args:
        data: Decoded JSON payload. May be either the response object
            itself or a wrapper carrying it under a ``"response"`` key.

    Returns:
        A GenerateContentResponse with parsed candidates, usage metadata
        (None when ``usageMetadata`` is absent or empty), the model
        version, and the outer payload's ``traceId`` as ``response_id``.
    """
    # Deferred import (presumably to avoid a circular dependency with
    # .types); hoisted here so it runs once per call instead of once per
    # candidate inside the loop, as the original did.
    from .types import Candidate, Content, UsageMetadata

    # Some payloads wrap the response body under a "response" key.
    inner = data.get("response", data)

    candidates = []
    for c in inner.get("candidates", []):
        parsed = _parse_candidate(c)
        content_data = parsed.get("content")
        candidates.append(
            Candidate(
                content=Content(**content_data) if content_data else None,
                finish_reason=parsed.get("finish_reason"),
                index=parsed.get("index"),
            )
        )

    usage_raw = inner.get("usageMetadata", {})
    usage = (
        UsageMetadata(
            prompt_token_count=usage_raw.get("promptTokenCount"),
            candidates_token_count=usage_raw.get("candidatesTokenCount"),
            total_token_count=usage_raw.get("totalTokenCount"),
            thoughts_token_count=usage_raw.get("thoughtsTokenCount"),
        )
        if usage_raw
        else None
    )

    return GenerateContentResponse(
        candidates=candidates,
        usage_metadata=usage,
        model_version=inner.get("modelVersion"),
        # NOTE(review): traceId is read from the OUTER payload, not
        # `inner` — looks intentional for wrapped responses; confirm.
        response_id=data.get("traceId"),
    )
+
+
async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """Issue a streaming request and yield one StreamChunk per SSE event.

    Opens a streaming request with httpx, accumulates ``data:`` lines
    into events (separated by blank lines per the SSE format),
    JSON-decodes each event, and yields it as a StreamChunk. Events
    whose payload is not valid JSON are skipped silently.

    Args:
        client: httpx async client used to issue the request.
        method: HTTP method, e.g. ``"POST"``.
        url: Request URL.
        headers: Request headers.
        body: JSON-serializable request body.

    Yields:
        StreamChunk for each successfully decoded SSE event.

    Raises:
        httpx.HTTPStatusError: If the response status indicates an error.
    """

    def _to_chunk(raw_text: str) -> StreamChunk | None:
        """Decode one buffered SSE payload; None if it is not valid JSON."""
        try:
            data = json.loads(raw_text)
        except json.JSONDecodeError:
            return None
        return StreamChunk(
            response=_parse_response(data),
            trace_id=data.get("traceId"),
            raw=data,
        )

    async with client.stream(method, url, headers=headers, json=body) as resp:
        resp.raise_for_status()
        buffer: list[str] = []
        async for line in resp.aiter_lines():
            if line.startswith("data: "):
                buffer.append(line[6:].strip())
            elif line == "" and buffer:
                # Blank line terminates the current event.
                chunk = _to_chunk("\n".join(buffer))
                buffer = []
                if chunk is not None:
                    yield chunk
        # Bug fix: flush a trailing event that was not followed by a
        # terminating blank line before the stream closed — the original
        # loop silently dropped it.
        if buffer:
            chunk = _to_chunk("\n".join(buffer))
            if chunk is not None:
                yield chunk