"""Parsing helpers for generate-content API responses and their SSE streams.

Converts raw camelCase JSON payloads into the project's typed models
(`Candidate`, `Content`, `UsageMetadata`, `GenerateContentResponse`) and
decodes `data:`-framed server-sent-event streams into `StreamChunk`s.
"""

import json
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .types import (
    Candidate,
    Content,
    GenerateContentResponse,
    StreamChunk,
    UsageMetadata,
)


def _parse_candidate(c: dict[str, Any]) -> dict[str, Any]:
    """Normalize one raw candidate dict to snake_case keys.

    Only keys present in the input appear in the output; ``content`` is
    reshaped to ``{"role": ..., "parts": [...]}`` with ``role`` defaulting
    to ``"model"`` and ``parts`` to an empty list.
    """
    out: dict[str, Any] = {}
    content = c.get("content", {})
    if content:
        out["content"] = {
            "role": content.get("role", "model"),
            "parts": content.get("parts", []),
        }
    if "finishReason" in c:
        out["finish_reason"] = c["finishReason"]
    if "index" in c:
        out["index"] = c["index"]
    return out


def _parse_response(data: dict[str, Any]) -> GenerateContentResponse:
    """Build a `GenerateContentResponse` from a raw JSON payload.

    Accepts both a bare response object and one wrapped under a
    ``"response"`` key. Missing candidates/usage metadata yield an empty
    candidate list and ``usage_metadata=None`` respectively.
    """
    # Some payloads wrap the body under "response"; fall back to the
    # top-level dict when that key is absent.
    inner = data.get("response", data)

    candidates: list[Candidate] = []
    for raw in inner.get("candidates", []):
        parsed = _parse_candidate(raw)
        content_data = parsed.get("content")
        candidates.append(
            Candidate(
                content=Content(**content_data) if content_data else None,
                finish_reason=parsed.get("finish_reason"),
                index=parsed.get("index"),
            )
        )

    usage_raw = inner.get("usageMetadata", {})
    usage = (
        UsageMetadata(
            prompt_token_count=usage_raw.get("promptTokenCount"),
            candidates_token_count=usage_raw.get("candidatesTokenCount"),
            total_token_count=usage_raw.get("totalTokenCount"),
            thoughts_token_count=usage_raw.get("thoughtsTokenCount"),
        )
        if usage_raw
        else None
    )

    return GenerateContentResponse(
        candidates=candidates,
        usage_metadata=usage,
        model_version=inner.get("modelVersion"),
        # NOTE(review): traceId is intentionally read from the OUTER envelope
        # (`data`), not `inner` — presumably it only exists at the top level;
        # confirm against the server payloads.
        response_id=data.get("traceId"),
    )


async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """Stream a request and yield one `StreamChunk` per decoded SSE event.

    Args:
        client: Open async HTTP client used to issue the streaming request.
        method: HTTP method (e.g. ``"POST"``).
        url: Target URL.
        headers: Request headers.
        body: JSON-serializable request body.

    Yields:
        A `StreamChunk` for every ``data:``-framed event whose accumulated
        payload parses as JSON; undecodable events are skipped.

    Raises:
        httpx.HTTPStatusError: If the response status is an error.
    """
    async with client.stream(method, url, headers=headers, json=body) as resp:
        resp.raise_for_status()
        buffer: list[str] = []

        def _decode(lines: list[str]) -> StreamChunk | None:
            """Join buffered data lines and decode one event; None on bad JSON."""
            try:
                data = json.loads("\n".join(lines))
            except json.JSONDecodeError:
                return None
            return StreamChunk(
                response=_parse_response(data),
                trace_id=data.get("traceId"),
                raw=data,
            )

        async for line in resp.aiter_lines():
            if line.startswith("data: "):
                buffer.append(line[6:].strip())
            elif line == "" and buffer:
                # Blank line terminates an SSE event: decode what we buffered.
                chunk = _decode(buffer)
                buffer = []
                if chunk is not None:
                    yield chunk

        # BUGFIX: flush a final event that the stream ended on without a
        # terminating blank line — previously it was silently dropped.
        if buffer:
            chunk = _decode(buffer)
            if chunk is not None:
                yield chunk