1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
import json
from collections.abc import AsyncIterator
from typing import Any
import httpx
from .types import GenerateContentResponse, StreamChunk
def _parse_candidate(c: dict[str, Any]) -> dict[str, Any]:
out: dict[str, Any] = {}
content = c.get("content", {})
if content:
out["content"] = {
"role": content.get("role", "model"),
"parts": content.get("parts", []),
}
if "finishReason" in c:
out["finish_reason"] = c["finishReason"]
if "index" in c:
out["index"] = c["index"]
return out
def _parse_response(data: dict[str, Any]) -> GenerateContentResponse:
    """Convert a raw API payload into a GenerateContentResponse.

    Accepts either a bare response dict or one wrapped in a ``response``
    envelope. Candidates are normalized via ``_parse_candidate`` and
    usage metadata is mapped from camelCase to snake_case fields.
    """
    # Hoisted out of the loop body: the original re-ran this import on
    # every candidate iteration (and had a second local import below).
    from .types import Candidate, Content, UsageMetadata

    inner = data.get("response", data)
    candidates = []
    for c in inner.get("candidates", []):
        parsed = _parse_candidate(c)
        content_data = parsed.get("content")
        content = Content(**content_data) if content_data else None
        candidates.append(
            Candidate(
                content=content,
                finish_reason=parsed.get("finish_reason"),
                index=parsed.get("index"),
            )
        )
    usage_raw = inner.get("usageMetadata", {})
    usage = (
        UsageMetadata(
            prompt_token_count=usage_raw.get("promptTokenCount"),
            candidates_token_count=usage_raw.get("candidatesTokenCount"),
            total_token_count=usage_raw.get("totalTokenCount"),
            thoughts_token_count=usage_raw.get("thoughtsTokenCount"),
        )
        if usage_raw
        else None
    )
    return GenerateContentResponse(
        candidates=candidates,
        usage_metadata=usage,
        model_version=inner.get("modelVersion"),
        # NOTE(review): traceId is read from the OUTER envelope (`data`),
        # not `inner` — looks intentional for wrapped responses; confirm.
        response_id=data.get("traceId"),
    )
async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """Issue a streaming request and yield one StreamChunk per SSE event.

    Lines beginning with ``data: `` are accumulated until a blank line
    terminates the event; multi-line data is joined with "\\n" per the
    SSE spec before JSON decoding. Events that fail to decode as JSON
    are skipped. Raises ``httpx.HTTPStatusError`` on non-2xx responses.
    """

    def _decode(lines: list[str]) -> StreamChunk | None:
        """Decode one buffered event; return None if it is not valid JSON."""
        try:
            data = json.loads("\n".join(lines))
        except json.JSONDecodeError:
            return None
        return StreamChunk(
            response=_parse_response(data), trace_id=data.get("traceId"), raw=data
        )

    async with client.stream(method, url, headers=headers, json=body) as resp:
        resp.raise_for_status()
        buffer: list[str] = []
        async for line in resp.aiter_lines():
            # NOTE(review): the SSE spec also permits "data:" with no
            # space; this only matches "data: " as the original did.
            if line.startswith("data: "):
                buffer.append(line[6:].strip())
            elif line == "" and buffer:
                chunk = _decode(buffer)
                buffer = []
                if chunk is not None:
                    yield chunk
        # Fix: flush a final event that the server did not terminate with
        # a blank line — the original silently dropped it.
        if buffer:
            chunk = _decode(buffer)
            if chunk is not None:
                yield chunk
|