aboutsummaryrefslogtreecommitdiffstats
path: root/src/gemini/streaming.py
blob: 1b83b4d09bc8b1d3d07435febd4aca4a06dcc162 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import logging
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .types import GenerateContentResponse, StreamChunk


def _parse_candidate(c: dict[str, Any]) -> dict[str, Any]:
    out: dict[str, Any] = {}
    content = c.get("content", {})
    if content:
        out["content"] = {
            "role": content.get("role", "model"),
            "parts": content.get("parts", []),
        }
    if "finishReason" in c:
        out["finish_reason"] = c["finishReason"]
    if "index" in c:
        out["index"] = c["index"]
    return out


def _parse_response(data: dict[str, Any]) -> GenerateContentResponse:
    """Convert one decoded stream payload into a ``GenerateContentResponse``.

    The payload is either the response object itself, or an envelope that
    holds it under ``"response"``. Candidates, usage metadata and model
    version are read from the response object. ``response_id`` is read from
    the top-level ``traceId``.

    Missing or ``null`` fields are tolerated:
    - ``candidates`` gives an empty candidate list.
    - ``usageMetadata`` gives ``usage_metadata=None``.
    - ``response`` falls back to treating the payload itself as the response.
    """
    # Kept function-scoped as before, but executed once per call instead of
    # once per candidate.
    from .types import Candidate, Content, UsageMetadata

    inner = data.get("response")
    if not isinstance(inner, dict):
        # No envelope (or a null / non-object one): the payload is the response.
        inner = data

    candidates = []
    # `or []` guards against an explicit "candidates": null.
    for raw_candidate in inner.get("candidates") or []:
        parsed = _parse_candidate(raw_candidate)
        content_data = parsed.get("content")
        candidates.append(
            Candidate(
                content=Content(**content_data) if content_data else None,
                finish_reason=parsed.get("finish_reason"),
                index=parsed.get("index"),
            )
        )

    usage_raw = inner.get("usageMetadata")
    usage = (
        UsageMetadata(
            prompt_token_count=usage_raw.get("promptTokenCount"),
            candidates_token_count=usage_raw.get("candidatesTokenCount"),
            total_token_count=usage_raw.get("totalTokenCount"),
            thoughts_token_count=usage_raw.get("thoughtsTokenCount"),
        )
        if usage_raw
        else None
    )

    # NOTE(review): response_id comes only from the top-level traceId; confirm
    # that no response-level id field should be preferred when present.
    return GenerateContentResponse(
        candidates=candidates,
        usage_metadata=usage,
        model_version=inner.get("modelVersion"),
        response_id=data.get("traceId"),
    )


def _event_to_chunk(data_lines: list[str]) -> "StreamChunk | None":
    """Decode the ``data`` lines of one SSE event into a ``StreamChunk``.

    The lines are joined with ``"\\n"``, as SSE does for multi-line data
    fields, and parsed as JSON. Returns ``None``, after a debug log, when the
    payload is not valid JSON or is not a JSON object, so the caller can skip
    the event.
    """
    raw_text = "\n".join(data_lines)
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError:
        logging.getLogger(__name__).debug("Skipping non-JSON SSE event: %r", raw_text[:200])
        return None
    if not isinstance(data, dict):
        logging.getLogger(__name__).debug("Skipping non-object SSE event: %r", raw_text[:200])
        return None
    return StreamChunk(response=_parse_response(data), trace_id=data.get("traceId"), raw=data)


async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    body: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """Send a streaming request and yield one ``StreamChunk`` per SSE event.

    Args:
        client: HTTP client used to open the streaming request.
        method: HTTP method.
        url: Request URL.
        headers: Request headers.
        body: Request body, sent as JSON.

    Yields:
        A ``StreamChunk`` for every event whose ``data`` payload decodes to a
        JSON object. Other events are skipped and logged at debug level.
        ``event:``/``id:``/comment lines are ignored.

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status`` when the response
            status indicates an error.
    """
    async with client.stream(method, url, headers=headers, json=body) as resp:
        resp.raise_for_status()
        pending: list[str] = []
        async for line in resp.aiter_lines():
            if line.startswith("data:"):
                # The space after "data:" is optional in SSE; strip() removes
                # it together with any surrounding whitespace.
                pending.append(line[5:].strip())
            elif line == "" and pending:
                # A blank line terminates the current event.
                chunk = _event_to_chunk(pending)
                pending = []
                if chunk is not None:
                    yield chunk
        # The server may close the stream without a trailing blank line; don't
        # silently drop that final event. A truncated payload still fails JSON
        # decoding in _event_to_chunk and is skipped.
        if pending:
            chunk = _event_to_chunk(pending)
            if chunk is not None:
                yield chunk