"""Async client for the Anthropic Messages API with streaming and tool-use loops."""

from __future__ import annotations

import asyncio
import hashlib
import os
import platform
import uuid
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .streaming import StreamParser, parse_sse_stream
from .types import AssistantMessage, StreamChunk, ToolHandler, ToolResult, ToolUse

MODELS: dict[str, dict[str, str]] = {
    "claude-haiku-4-5-20251001": {"family": "haiku", "display": "Haiku 4.5"},
    "claude-sonnet-4-6": {"family": "sonnet", "display": "Sonnet 4.6"},
    "claude-opus-4-6": {"family": "opus", "display": "Opus 4.6"},
}

BETA_BASE = (
    "oauth-2025-04-20,interleaved-thinking-2025-05-14,"
    "prompt-caching-scope-2026-01-05,claude-code-20250219"
)
BETA_ADAPTIVE = (
    "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,"
    "prompt-caching-scope-2026-01-05,effort-2025-11-24,adaptive-thinking-2026-01-28"
)


def _supports_adaptive(model: str) -> bool:
    """Return True for model families that accept the adaptive-thinking betas."""
    m = model.lower()
    return "sonnet" in m or "opus" in m


class ChatClient:
    BASE_URL = "https://api.anthropic.com"
    API_VERSION = "2023-06-01"

    @staticmethod
    def list_models() -> dict[str, dict[str, str]]:
        return dict(MODELS)

    def __init__(
        self,
        api_key: str | None = None,
        model: str = "claude-sonnet-4-6",
        max_tokens: int = 8192,
        system: str | None = None,
        timeout: float = 300.0,
        tools: list[dict[str, Any]] | None = None,
        max_retries: int = 3,
        backoff_factor: float = 1.0,
    ):
        self.api_key = (
            api_key
            or os.getenv("ANTHROPIC_API_KEY")
            or os.getenv("CLAUDE_CODE_OAUTH_TOKEN")
        )
        if not self.api_key:
            raise ValueError(
                "API key required: set ANTHROPIC_API_KEY or CLAUDE_CODE_OAUTH_TOKEN"
            )
        self.model = model
        self.max_tokens = max_tokens
        self.system = system
        self.tools = tools
        self.max_retries = max_retries
        self.backoff_factor = backoff_factor
        self.client = httpx.AsyncClient(http2=True, timeout=timeout)
        # Stable per-machine device id; account/session ids are fresh UUIDs
        # unless CLAUDE_ACCOUNT_UUID is provided.
        self._device_id = hashlib.sha256(platform.node().encode()).hexdigest()
        self._account_uuid = os.getenv("CLAUDE_ACCOUNT_UUID") or str(uuid.uuid4())
        self._session_id = str(uuid.uuid4())

    def _headers(self, model: str | None = None) -> dict[str, str]:
        m = model or self.model
        beta = BETA_ADAPTIVE if _supports_adaptive(m) else BETA_BASE
        return {
            "accept": "application/json",
            "accept-language": "*",
            "anthropic-beta": beta,
            "anthropic-dangerous-direct-browser-access": "true",
            "anthropic-version": self.API_VERSION,
            "authorization": f"Bearer {self.api_key}",
            "content-type": "application/json",
            "sec-fetch-mode": "cors",
            "user-agent": "claude-cli/2.1.63 (external, sdk-cli)",
            "x-app": "cli",
            "x-stainless-arch": "x64",
            "x-stainless-lang": "js",
            "x-stainless-os": "Linux",
            "x-stainless-package-version": "0.74.0",
            "x-stainless-retry-count": "0",
            "x-stainless-runtime": "node",
            "x-stainless-runtime-version": "v20.20.0",
            "x-stainless-timeout": "600",
        }

    def _metadata(self) -> dict[str, str]:
        return {
            "user_id": f"user_{self._device_id}_account_{self._account_uuid}_session_{self._session_id}"
        }

    @staticmethod
    def _normalize_messages(
        messages: str | list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        if isinstance(messages, str):
            return [{"role": "user", "content": messages}]
        return list(messages)

    def _body(
        self,
        messages: list[dict[str, Any]],
        stream: bool,
        max_tokens: int | None = None,
        system: str | None = None,
        model: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any]:
        m = model or self.model
        body: dict[str, Any] = {
            "model": m,
            "max_tokens": max_tokens or self.max_tokens,
            "messages": messages,
            "metadata": self._metadata(),
            "stream": stream,
        }
        sys_text = system or self.system
        if sys_text:
            body["system"] = sys_text
        effective_tools = tools or self.tools
        if effective_tools:
            body["tools"] = effective_tools
        return body

    async def chat(
        self,
        messages: str | list[dict[str, Any]],
        *,
        model: str | None = None,
        max_tokens: int | None = None,
        system: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> AssistantMessage:
        msgs = self._normalize_messages(messages)
        m = model or self.model
        body = self._body(
            msgs, stream=False, max_tokens=max_tokens, system=system, model=m, tools=tools
        )
        url = f"{self.BASE_URL}/v1/messages"
        last_error: Exception | None = None
        backoff = self.backoff_factor
        for attempt in range(self.max_retries):
            try:
                resp = await self.client.post(url, headers=self._headers(m), json=body)
                resp.raise_for_status()
                return AssistantMessage.model_validate(resp.json())
            except (httpx.ConnectError, httpx.ConnectTimeout) as e:
                # Retry only connection-level failures with exponential backoff;
                # HTTP errors raised by raise_for_status() propagate immediately.
                last_error = e
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(backoff)
                    backoff *= 2
        raise last_error  # type: ignore[misc]

    async def stream(
        self,
        messages: str | list[dict[str, Any]],
        *,
        model: str | None = None,
        max_tokens: int | None = None,
        system: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[StreamChunk]:
        msgs = self._normalize_messages(messages)
        m = model or self.model
        body = self._body(
            msgs, stream=True, max_tokens=max_tokens, system=system, model=m, tools=tools
        )
        url = f"{self.BASE_URL}/v1/messages"
        last_error: Exception | None = None
        backoff = self.backoff_factor
        for attempt in range(self.max_retries):
            try:
                async for chunk in parse_sse_stream(
                    self.client, "POST", url, self._headers(m), body
                ):
                    yield chunk
                return
            except (httpx.ConnectError, httpx.ConnectTimeout) as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(backoff)
                    backoff *= 2
        if last_error:
            raise last_error

    async def collect(
        self,
        messages: str | list[dict[str, Any]],
        *,
        model: str | None = None,
        max_tokens: int | None = None,
        system: str | None = None,
        tools: list[dict[str, Any]] | None = None,
    ) -> str:
        parser = StreamParser()
        async for chunk in self.stream(
            messages, model=model, max_tokens=max_tokens, system=system, tools=tools
        ):
            parser.add_chunk(chunk)
        msg = parser.to_dict()
        parts = []
        for block in msg.get("content", []):
            if block.get("type") == "text":
                parts.append(block["text"])
        return "".join(parts)

    async def run(
        self,
        messages: str | list[dict[str, Any]],
        *,
        tools: list[dict[str, Any]] | None = None,
        tool_handler: ToolHandler,
        model: str | None = None,
        max_tokens: int | None = None,
        system: str | None = None,
        max_turns: int = 10,
    ) -> AssistantMessage:
        msgs = self._normalize_messages(messages)
        for _ in range(max_turns):
            resp = await self.chat(
                msgs,
                model=model,
                max_tokens=max_tokens,
                system=system,
                tools=tools,
            )
            if not resp.has_tool_use:
                return resp
            msgs.append({"role": "assistant", "content": resp.content})
            tool_results = []
            for tc in resp.tool_calls:
                try:
                    result = await tool_handler(tc.name, tc.input)
                    tool_results.append(
                        ToolResult(
                            type="tool_result",
                            tool_use_id=tc.id,
                            content=result,
                            is_error=False,
                        )
                    )
                except Exception as exc:
                    tool_results.append(
                        ToolResult(
                            type="tool_result",
                            tool_use_id=tc.id,
                            content=str(exc),
                            is_error=True,
                        )
                    )
            msgs.append(
                {
                    "role": "user",
                    "content": [r.model_dump() for r in tool_results],
                }
            )
        # max_turns exhausted: return the last response, which may still
        # contain unresolved tool_use blocks.
        return resp

    async def run_stream(
        self,
        messages: str | list[dict[str, Any]],
        *,
        tools: list[dict[str, Any]] | None = None,
        tool_handler: ToolHandler,
        model: str | None = None,
        max_tokens: int | None = None,
        system: str | None = None,
        max_turns: int = 10,
    ) -> AsyncIterator[StreamChunk | ToolUse | ToolResult]:
        msgs = self._normalize_messages(messages)
        for _ in range(max_turns):
            parser = StreamParser()
            async for chunk in self.stream(
                msgs,
                model=model,
                max_tokens=max_tokens,
                system=system,
                tools=tools,
            ):
                parser.add_chunk(chunk)
                yield chunk
            msg_dict = parser.to_dict()
            resp = AssistantMessage.model_validate(msg_dict)
            if not resp.has_tool_use:
                return
            msgs.append({"role": "assistant", "content": resp.content})
            tool_results = []
            for tc in resp.tool_calls:
                yield tc
                try:
                    result = await tool_handler(tc.name, tc.input)
                    tr = ToolResult(
                        type="tool_result",
                        tool_use_id=tc.id,
                        content=result,
                        is_error=False,
                    )
                except Exception as exc:
                    tr = ToolResult(
                        type="tool_result",
                        tool_use_id=tc.id,
                        content=str(exc),
                        is_error=True,
                    )
                yield tr
                tool_results.append(tr)
            msgs.append(
                {
                    "role": "user",
                    "content": [r.model_dump() for r in tool_results],
                }
            )

    async def close(self):
        await self.client.aclose()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
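

# Illustrative usage sketch, not part of the client API: shows collect() for a
# plain completion and run() for the tool loop. The "get_time" tool and
# handle_tool helper are hypothetical examples; assumes ANTHROPIC_API_KEY or
# CLAUDE_CODE_OAUTH_TOKEN is set in the environment.
if __name__ == "__main__":
    import datetime
    import json

    TOOLS = [
        {
            "name": "get_time",
            "description": "Return the current UTC time as an ISO-8601 string.",
            "input_schema": {"type": "object", "properties": {}},
        }
    ]

    async def handle_tool(name: str, args: dict[str, Any]) -> str:
        # Example handler: dispatch on tool name and return a string result.
        # Exceptions raised here become is_error tool_result blocks in run().
        if name == "get_time":
            return datetime.datetime.now(datetime.timezone.utc).isoformat()
        raise ValueError(f"unknown tool: {name}")

    async def main() -> None:
        async with ChatClient() as client:
            # One-shot completion, concatenating streamed text deltas.
            print(await client.collect("Say hello in five words."))
            # Agentic loop: run() resolves tool_use blocks via handle_tool
            # until the model stops requesting tools or max_turns is hit.
            final = await client.run(
                "What time is it in UTC?",
                tools=TOOLS,
                tool_handler=handle_tool,
            )
            print(json.dumps(final.model_dump(), indent=2, default=str))

    asyncio.run(main())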