diff options
| author | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
|---|---|---|
| committer | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
| commit | 171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch) | |
| tree | 2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /src/gemini/streaming.py | |
| parent | 9f14edf2b97286e02830d528038b32d5b31aaa0a (diff) | |
| parent | 0278c87f062a9ae7d617b92be22b175558a05086 (diff) | |
| download | gemini-py-main.tar.gz gemini-py-main.zip | |
Diffstat (limited to 'src/gemini/streaming.py')
| -rw-r--r-- | src/gemini/streaming.py | 86 |
1 files changed, 86 insertions, 0 deletions
import json
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .types import GenerateContentResponse, StreamChunk


def _parse_candidate(c: dict[str, Any]) -> dict[str, Any]:
    """Normalize one raw candidate dict from the wire format.

    Converts the camelCase wire key ``finishReason`` to snake_case,
    defaults a candidate content's role to ``"model"``, and copies
    ``index`` through when present.  Keys absent from the input are
    simply omitted from the result.
    """
    out: dict[str, Any] = {}
    content = c.get("content", {})
    if content:
        out["content"] = {
            "role": content.get("role", "model"),
            "parts": content.get("parts", []),
        }
    if "finishReason" in c:
        out["finish_reason"] = c["finishReason"]
    if "index" in c:
        out["index"] = c["index"]
    return out


def _parse_response(data: dict[str, Any]) -> GenerateContentResponse:
    """Build a GenerateContentResponse from one decoded SSE JSON payload.

    ``data`` may either be the response object itself or wrap it under a
    ``"response"`` key; ``traceId`` is always read from the outer dict.
    """
    # Deferred import kept (presumably to avoid an import cycle with
    # .types — TODO confirm), but hoisted out of the per-candidate loop:
    # the original re-ran the import statement on every iteration.
    from .types import Candidate, Content, UsageMetadata

    inner = data.get("response", data)

    candidates = []
    for c in inner.get("candidates", []):
        parsed = _parse_candidate(c)
        content_data = parsed.get("content")
        candidates.append(
            Candidate(
                content=Content(**content_data) if content_data else None,
                finish_reason=parsed.get("finish_reason"),
                index=parsed.get("index"),
            )
        )

    usage_raw = inner.get("usageMetadata", {})
    usage = (
        UsageMetadata(
            prompt_token_count=usage_raw.get("promptTokenCount"),
            candidates_token_count=usage_raw.get("candidatesTokenCount"),
            total_token_count=usage_raw.get("totalTokenCount"),
            thoughts_token_count=usage_raw.get("thoughtsTokenCount"),
        )
        if usage_raw
        else None
    )

    return GenerateContentResponse(
        candidates=candidates,
        usage_metadata=usage,
        model_version=inner.get("modelVersion"),
        response_id=data.get("traceId"),
    )


async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """Issue a streaming request and yield one StreamChunk per SSE event.

    Consecutive ``data:`` lines of one event are joined with newlines
    before JSON decoding, per the SSE format.  Events whose payload is
    not valid JSON are skipped rather than aborting the stream.

    Raises:
        httpx.HTTPStatusError: if the server responds with an error status.
    """

    def _decode(lines: list[str]):
        # Join multi-line "data:" fields, decode, and wrap; None on bad JSON.
        try:
            data = json.loads("\n".join(lines))
        except json.JSONDecodeError:
            return None
        return StreamChunk(
            response=_parse_response(data),
            trace_id=data.get("traceId"),
            raw=data,
        )

    async with client.stream(method, url, headers=headers, json=body) as resp:
        resp.raise_for_status()
        buffer: list[str] = []
        async for line in resp.aiter_lines():
            if line.startswith("data:"):
                # The SSE spec allows "data:" with or without a single
                # space after the colon; the original accepted only "data: ".
                buffer.append(line[5:].removeprefix(" ").strip())
            elif line == "" and buffer:
                # A blank line terminates the current event.
                pending, buffer = buffer, []
                chunk = _decode(pending)
                if chunk is not None:
                    yield chunk
        if buffer:
            # NOTE(review): the SSE spec discards an event not terminated by
            # a blank line, but some servers omit the final separator — flush
            # so the last chunk is not silently dropped.
            chunk = _decode(buffer)
            if chunk is not None:
                yield chunk
