aboutsummaryrefslogtreecommitdiffstats
path: root/src/gemini/types.py
blob: 7e00e2a5a309943caf214c86cc1172740ff2a30e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import Any

from pydantic import BaseModel


class Part(BaseModel):
    """One piece of a Gemini content turn: text, a thought, or a function call/response."""

    # Plain text payload, when this part carries text.
    text: str | None = None
    # True when this part is model "thinking" output rather than answer text.
    thought: bool | None = None
    # Opaque signature string (presumably paired with thought parts — confirm against API docs).
    thought_signature: str | None = None
    # Raw function-call payload as returned by the API.
    function_call: dict[str, Any] | None = None
    # Raw function-response payload sent back to the model.
    function_response: dict[str, Any] | None = None

    # NOTE(review): populate_by_name only has an effect when field aliases are declared;
    # none are declared here, so camelCase API keys (e.g. "thoughtSignature") would land
    # in the extra fields rather than the snake_case attributes — confirm intent.
    model_config = {"populate_by_name": True, "extra": "allow"}


class Content(BaseModel):
    """A single conversation turn: a role plus its ordered list of raw part dicts."""

    # Speaker role (values not validated here).
    role: str
    # Parts are kept as raw API dicts (camelCase keys such as "functionCall"),
    # not parsed into Part models; consumers index them as plain dicts.
    parts: list[dict[str, Any]]

    model_config = {"extra": "allow"}


class GenerationConfig(BaseModel):
    """Tunable generation parameters, serializable to the camelCase wire format."""

    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    candidate_count: int | None = None
    max_output_tokens: int | None = None
    stop_sequences: list[str] | None = None
    response_mime_type: str | None = None
    thinking_config: dict[str, Any] | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}

    def to_api(self) -> dict[str, Any]:
        """Serialize to the API's camelCase request dict, omitting fields left as None.

        Key order in the returned dict matches the declaration order above.
        """
        # snake_case attribute -> camelCase wire key, in wire order.
        wire_keys = (
            ("temperature", "temperature"),
            ("top_p", "topP"),
            ("top_k", "topK"),
            ("candidate_count", "candidateCount"),
            ("max_output_tokens", "maxOutputTokens"),
            ("stop_sequences", "stopSequences"),
            ("response_mime_type", "responseMimeType"),
            ("thinking_config", "thinkingConfig"),
        )
        payload: dict[str, Any] = {}
        for attr, wire_key in wire_keys:
            value = getattr(self, attr)
            if value is not None:
                payload[wire_key] = value
        return payload


class UsageMetadata(BaseModel):
    """Token-accounting counts reported alongside a response."""

    # Tokens consumed by the prompt.
    prompt_token_count: int | None = None
    # Tokens produced across candidates.
    candidates_token_count: int | None = None
    # Total tokens for the request.
    total_token_count: int | None = None
    # Tokens spent on thinking output.
    thoughts_token_count: int | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}


class Candidate(BaseModel):
    """One generated candidate: its content, stop reason, and position."""

    # The candidate's generated content (may be absent).
    content: Content | None = None
    # Why generation stopped (raw API string; values not validated here).
    finish_reason: str | None = None
    # Candidate index within the response.
    index: int | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}


class GenerateContentResponse(BaseModel):
    """Parsed generateContent response with convenience accessors over candidates."""

    candidates: list[Candidate] | None = None
    usage_metadata: UsageMetadata | None = None
    model_version: str | None = None
    response_id: str | None = None

    model_config = {"populate_by_name": True, "extra": "allow"}

    def _part_dicts(self) -> list[dict[str, Any]]:
        """Collect every dict-typed part across all candidates, in order."""
        found: list[dict[str, Any]] = []
        for cand in self.candidates or []:
            if cand.content and cand.content.parts:
                found.extend(p for p in cand.content.parts if isinstance(p, dict))
        return found

    @property
    def text(self) -> str:
        """Concatenated answer text, excluding thought parts and empty text."""
        return "".join(
            p["text"]
            for p in self._part_dicts()
            if not p.get("thought") and p.get("text")
        )

    @property
    def tool_calls(self) -> "list[ToolCall]":
        """Every "functionCall" part, wrapped as a ToolCall."""
        return [
            ToolCall(
                name=p["functionCall"]["name"],
                args=p["functionCall"].get("args", {}),
            )
            for p in self._part_dicts()
            if "functionCall" in p
        ]

    @property
    def thinking(self) -> str:
        """Concatenated text of thought parts only."""
        return "".join(
            p["text"]
            for p in self._part_dicts()
            if p.get("thought") and p.get("text")
        )


class StreamChunk(BaseModel):
    """One streamed chunk: an optional parsed response plus the raw payload dict."""

    response: GenerateContentResponse | None = None
    trace_id: str | None = None
    # Mutable default is safe: pydantic copies field defaults per instance.
    raw: dict[str, Any] = {}

    @property
    def text_delta(self) -> str:
        """Text carried by this chunk, or "" when there is no parsed response."""
        return self.response.text if self.response else ""

    @property
    def tool_calls(self) -> "list[ToolCall]":
        """Tool calls carried by this chunk ([] when there is no parsed response)."""
        return self.response.tool_calls if self.response else []


class ToolCall(BaseModel):
    """A function call requested by the model: its name and argument mapping."""

    # Function name as given in the part's "functionCall" payload.
    name: str
    # Argument mapping (pydantic copies the mutable default per instance).
    args: dict[str, Any] = {}


class FunctionDeclaration(BaseModel):
    """A callable tool exposed to the model, serializable to the API declaration schema."""

    name: str
    description: str = ""
    parameters: dict[str, Any] | None = None

    model_config = {"extra": "allow"}

    def to_api(self) -> dict[str, Any]:
        """Build the wire-format declaration.

        An empty description and an absent (None) parameters schema are omitted;
        key order is name, description, parameters.
        """
        optional: dict[str, Any] = {
            # Empty string normalizes to None so the filter below drops it.
            "description": self.description or None,
            "parameters": self.parameters,
        }
        return {"name": self.name, **{k: v for k, v in optional.items() if v is not None}}


class GeminiOptions(BaseModel):
    """Client-side defaults and per-request settings for Gemini calls."""

    # Target model identifier.
    model: str = "gemini-2.5-pro"
    # Sampling / length defaults (field names mirror GenerationConfig).
    max_output_tokens: int = 32768
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 64
    # Thinking token budget; None presumably means provider default — confirm.
    thinking_budget: int | None = None
    # Whether to use the streaming endpoint.
    stream: bool = True
    # Optional system instruction text.
    system_prompt: str | None = None
    # Opaque session identifier; semantics defined by the caller — confirm.
    session_id: str | None = None
    # Path to a credentials file; loading happens elsewhere.
    credentials_path: str | None = None
    # Tool declarations to advertise to the model.
    tools: list[FunctionDeclaration] | None = None