diff options
| author | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
|---|---|---|
| committer | Claude <claude@anthropic.com> | 2026-03-04 19:14:55 +0100 |
| commit | 171c5b86ef05974426ba5c5d8547c8025977d1a2 (patch) | |
| tree | 2a1193e2bb81a6341e55d0b883a3fc33f77f8be1 /test/scripts/tool_test.py | |
| parent | 9f14edf2b97286e02830d528038b32d5b31aaa0a (diff) | |
| parent | 0278c87f062a9ae7d617b92be22b175558a05086 (diff) | |
| download | gemini-py-main.tar.gz gemini-py-main.zip | |
Diffstat (limited to 'test/scripts/tool_test.py')
| -rw-r--r-- | test/scripts/tool_test.py | 100 |
1 file changed, 100 insertions, 0 deletions
#!/usr/bin/env python3
"""Manual test script for Gemini tool (function) calling.

Exercises both the non-streaming and streaming request paths against a
single fake `get_weather` tool backed by a small in-memory table.
"""
import asyncio
import json

from gemini import FunctionDeclaration, GeminiClient, GeminiOptions

# Single tool exposed to the model: a weather lookup with an optional unit.
TOOLS = [
    FunctionDeclaration(
        name="get_weather",
        description="Get the current weather for a city",
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City name, e.g. Paris, Tokyo",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "Temperature unit (default: celsius)",
                },
            },
            "required": ["location"],
        },
    ),
]

# Canned responses keyed by city; unknown cities fall back to a default below.
FAKE_WEATHER = {
    "Paris": {"temperature": 18, "condition": "cloudy"},
    "London": {"temperature": 14, "condition": "rainy"},
    "Tokyo": {"temperature": 26, "condition": "sunny"},
}


def handle_tool(name: str, args: dict) -> dict:
    """Execute a tool call locally and return its JSON-able result.

    Only ``get_weather`` is supported; any other name raises ValueError.
    Temperatures are stored in celsius and converted on request.
    """
    if name != "get_weather":
        raise ValueError(f"Unknown tool: {name}")
    city = args.get("location", "")
    unit = args.get("unit", "celsius")
    record = FAKE_WEATHER.get(city, {"temperature": 20, "condition": "unknown"})
    temperature = record["temperature"]
    if unit == "fahrenheit":
        # Stored values are celsius; convert and round for display.
        temperature = round(temperature * 9 / 5 + 32)
    return {
        "location": city,
        "temperature": temperature,
        "unit": unit,
        "condition": record["condition"],
    }


async def run_non_streaming():
    """Drive one non-streaming conversation through the tool-call loop."""
    print("=== Non-streaming tool use ===")
    opts = GeminiOptions(model="gemini-2.5-flash", tools=TOOLS)
    async with GeminiClient(options=opts) as client:
        prompt = "What's the weather in Paris and Tokyo? Use celsius."
        print(f"User: {prompt}")
        response = await client.send_message(prompt)

        # Keep resolving tool calls until the model produces a final answer.
        while response.tool_calls:
            outcomes = []
            for call in response.tool_calls:
                print(f" [tool call] {call.name}({json.dumps(call.args)})")
                payload = handle_tool(call.name, call.args)
                print(f" [tool result] {json.dumps(payload)}")
                outcomes.append((call.name, payload))
            response = await client.send_tool_results(outcomes)

        print(f"Model: {response.text.strip()}")


async def run_streaming():
    """Drive one streaming conversation, resolving tool calls at the end."""
    print("\n=== Streaming tool use ===")
    opts = GeminiOptions(model="gemini-2.5-flash", tools=TOOLS)
    async with GeminiClient(options=opts) as client:
        prompt = "What's the weather in London right now?"
        print(f"User: {prompt}")

        # Drain the stream; only the last chunk is inspected for tool calls.
        final = None
        async for chunk in client.send_message_stream(prompt):
            final = chunk

        if final and final.tool_calls:
            outcomes = []
            for call in final.tool_calls:
                print(f" [tool call] {call.name}({json.dumps(call.args)})")
                payload = handle_tool(call.name, call.args)
                print(f" [tool result] {json.dumps(payload)}")
                outcomes.append((call.name, payload))
            # Stream the follow-up answer delta by delta.
            print("Model: ", end="", flush=True)
            async for chunk in client.send_tool_results_stream(outcomes):
                print(chunk.text_delta, end="", flush=True)
            print()
        elif final:
            print(f"Model: {final.response.text.strip() if final.response else ''}")


async def main():
    """Run both test scenarios in sequence."""
    await run_non_streaming()
    await run_streaming()


if __name__ == "__main__":
    asyncio.run(main())
