aboutsummaryrefslogtreecommitdiffstats
path: root/src/claude/streaming.py
blob: eec7de57a5a051f5bbe24586607ca05211f01b94 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
Streaming support for Claude API.

Handles Server-Sent Events (SSE) parsing and stream management.
"""

import json
import logging
from collections.abc import AsyncIterator
from typing import Any

import httpx

from .types import StreamChunk


async def parse_sse_stream(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    headers: dict[str, str],
    json_data: dict[str, Any],
) -> AsyncIterator[StreamChunk]:
    """
    Parse Server-Sent Events stream from Claude API.

    Note: Claude Code sends Accept: application/json but still receives
    text/event-stream responses when stream: true is in the request body.

    Each ``data:`` line is decoded as a standalone JSON object. ``event:``
    lines are ignored because the event type is read from the payload's
    ``type`` field instead. Blank lines, ``:`` comment lines, other SSE
    fields, empty data, malformed JSON and payloads that are not JSON
    objects are all skipped.

    Args:
        client: httpx AsyncClient
        method: HTTP method (usually "POST")
        url: API endpoint URL
        headers: Request headers
        json_data: Request body JSON

    Yields:
        StreamChunk objects with parsed event data

    Raises:
        httpx.HTTPStatusError: If the response status is >= 400. The error
            body is read (and logged) first, so it is also available on the
            exception's ``response``.
    """
    async with client.stream(method, url, headers=headers, json=json_data) as response:
        if response.status_code >= 400:
            # Read the body before raising so it can be logged and remains
            # readable from the raised exception's response object.
            error_text = await response.aread()
            logging.getLogger(__name__).error(
                "API Error Response: %s",
                error_text.decode("utf-8", errors="replace"),
            )
            response.raise_for_status()

        # Parse SSE manually
        async for raw_line in response.aiter_lines():
            line = raw_line.strip()

            # Only "data:" fields carry payloads; blank event separators,
            # ":" comments and "event:" fields need no handling here.
            if not line.startswith("data:"):
                continue

            data_str = line[5:].strip()
            if not data_str:
                continue

            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                # Skip malformed JSON
                continue

            # A valid but non-object payload (number, array, string) has no
            # .get(); skip it instead of aborting the whole stream.
            if not isinstance(data, dict):
                continue

            event_type = data.get("type", "unknown")

            # Surface incremental text so callers can print it as it arrives.
            text_delta = None
            if event_type == "content_block_delta":
                delta = data.get("delta")
                if isinstance(delta, dict) and delta.get("type") == "text_delta":
                    text_delta = delta.get("text", "")

            yield StreamChunk(
                event_type=event_type,
                data=data,
                text_delta=text_delta,
                content_block=data.get("content_block"),
            )


class StreamParser:
    """
    Parses and accumulates streaming events into complete messages.

    Feed every chunk of a stream to ``add_chunk``. Once the stream ends,
    ``to_dict`` returns a dict shaped like a non-streaming message
    (id, type, role, content, model, stop_reason, stop_sequence, usage).

    Usage:
        parser = StreamParser()
        async for chunk in stream:
            parser.add_chunk(chunk)
            if chunk.text_delta:
                print(chunk.text_delta, end='', flush=True)

        message = parser.to_dict()
    """

    def __init__(self):
        # Every chunk passed to add_chunk, in arrival order.
        self.chunks: list = []
        # Message metadata, filled from the message_start event.
        self.message_id: str | None = None
        self.model: str | None = None
        self.role: str = "assistant"
        # Completed content blocks, in the order they were closed.
        self.content_blocks: list = []
        # Block being accumulated between content_block_start and _stop.
        self.current_block: dict[str, Any] | None = None
        # Running text of the current block, mirrored into current_block["text"].
        self.current_block_text: str = ""
        self.stop_reason: str | None = None
        self.stop_sequence: str | None = None
        # Token usage: seeded by message_start, updated by message_delta.
        self.usage: dict[str, Any] | None = None
        # True while current_block["input"] holds raw, not-yet-parsed JSON text.
        self._input_is_raw_json: bool = False

    def add_chunk(self, chunk: "StreamChunk"):
        """
        Process a stream chunk and update the accumulated state.

        Only ``chunk.event_type`` and ``chunk.data`` are read. Events not
        handled below (e.g. ``message_stop``, ``ping``) are recorded in
        ``self.chunks`` and otherwise ignored.
        """
        self.chunks.append(chunk)

        event_type = chunk.event_type
        data = chunk.data or {}

        if event_type == "message_start":
            message = data.get("message") or {}
            self.message_id = message.get("id")
            self.model = message.get("model")
            self.role = message.get("role", "assistant")
            # message_start carries the input-side token counts; keep them as
            # the base that later message_delta usage is merged into.
            usage = message.get("usage")
            if isinstance(usage, dict):
                self.usage = dict(usage)

        elif event_type == "content_block_start":
            if self.current_block:
                # No stop arrived for the previous block; close it first.
                self._finish_current_block()
            self._start_block(data.get("content_block") or {})

        elif event_type == "content_block_delta":
            self._apply_delta(data.get("delta") or {})

        elif event_type == "content_block_stop":
            self._finish_current_block()

        elif event_type == "message_delta":
            delta = data.get("delta") or {}
            # Only overwrite fields the event actually carries.
            if "stop_reason" in delta:
                self.stop_reason = delta["stop_reason"]
            if "stop_sequence" in delta:
                self.stop_sequence = delta["stop_sequence"]
            usage = data.get("usage")
            if isinstance(usage, dict):
                # Merge rather than replace so counts reported only in
                # message_start (e.g. input_tokens) are not lost.
                self.usage = {**(self.usage or {}), **usage}

    def _start_block(self, content_block: dict[str, Any]):
        """Begin accumulating a new block from a content_block_start payload."""
        block = dict(content_block)
        self.current_block = block
        # Seed from any text already in the start payload so it is kept.
        initial_text = block.get("text", "")
        self.current_block_text = initial_text if isinstance(initial_text, str) else ""
        # tool_use input arrives as input_json_delta fragments; accumulate the
        # raw JSON text from "" instead of the start payload's placeholder.
        self._input_is_raw_json = block.get("type") == "tool_use"
        if self._input_is_raw_json:
            block["input"] = ""

    def _apply_delta(self, delta: dict[str, Any]):
        """Apply one content_block_delta payload to the current block."""
        delta_type = delta.get("type")
        block = self.current_block

        if delta_type == "text_delta":
            self.current_block_text += delta.get("text", "")
            if block:
                block["text"] = self.current_block_text

        elif delta_type == "input_json_delta":
            if block:
                if not self._input_is_raw_json:
                    # Block types other than tool_use may open with a
                    # non-string placeholder (e.g. {}); appending text to it
                    # would raise, so restart accumulation from "".
                    block["input"] = ""
                    self._input_is_raw_json = True
                block["input"] += delta.get("partial_json", "")

        elif delta_type == "thinking_delta":
            if block:
                block["thinking"] = (block.get("thinking") or "") + delta.get("thinking", "")

        elif delta_type == "signature_delta":
            if block:
                block["signature"] = (block.get("signature") or "") + delta.get("signature", "")

    def _finish_current_block(self):
        """Finish current content block and add to blocks list."""
        block = self.current_block
        if not block:
            return

        if self._input_is_raw_json:
            raw_input = block.get("input", "")
            if isinstance(raw_input, str):
                if not raw_input.strip():
                    # A tool call with no arguments may stream no JSON at all;
                    # represent that as an empty object, not an empty string.
                    block["input"] = {}
                else:
                    try:
                        block["input"] = json.loads(raw_input)
                    except json.JSONDecodeError:
                        pass  # Keep as string if not valid JSON

        self.content_blocks.append(block)
        self.current_block = None
        self.current_block_text = ""
        self._input_is_raw_json = False

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the parsed stream to a message dict.

        A block still open (stream ended before content_block_stop) is
        closed first. ``content`` is the parser's own list, not a copy.
        """
        if self.current_block:
            self._finish_current_block()

        return {
            "id": self.message_id or "",
            "type": "message",
            "role": self.role,
            "content": self.content_blocks,
            "model": self.model or "",
            "stop_reason": self.stop_reason,
            "stop_sequence": self.stop_sequence,
            "usage": self.usage,
        }