# packages/multillm-agentwrap/src/multillm_agentwrap/provider.py
"""
Agent wrapper provider implementation.

Wraps chat providers to provide agentic capabilities.
"""

import inspect
import json
import sys
from collections.abc import AsyncIterator
from typing import Any

from multillm import (
    BaseAgentProvider,
    AgentMessage,
    AgentOptions,
    Tool,
    ProviderError,
    load_provider_config,
    merge_config,
)


class AgentWrapProvider(BaseAgentProvider):
    """
    Agent provider that wraps a plain chat provider with agentic capabilities.

    The model parameter should be the chat provider and model to wrap.
    For example, when using "agentwrap/google/gemini":
    - Provider: "agentwrap"
    - Model: "google/gemini" (passed to this provider)

    This provider will:
    1. Call the specified chat provider internally via chat_complete()
    2. Run a tool-execution loop
    3. Maintain the conversation history across turns
    4. Provide agentic multi-turn interactions

    Usage:
        # Via client
        client = multillm.Client()
        async for msg in client.agent_run("agentwrap/google/gemini", "Hello"):
            print(msg)

        # With tools
        async for msg in client.agent_run(
            "agentwrap/openai/gpt-4",
            "What's 2+2?",
            tools=[calculate_tool],
        ):
            print(msg)
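
        # calculate_tool above is a sketch; `name` and `handler` match how
        # this provider consumes Tool, while `description` is an assumed
        # field of the Tool schema. Handlers receive the parsed arguments
        # dict and return something str()-able.
        def calculate(args: dict) -> str:
            return str(args["a"] + args["b"])

        calculate_tool = Tool(
            name="calculate",
            description="Add two numbers",
            handler=calculate,
        )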
    """

    PROVIDER_NAME = "agentwrap"

    def __init__(self, config: dict[str, Any] | None = None):
        super().__init__(config)
        self._client = None

    def _get_client(self):
        """Get or create client instance for making chat API calls."""
        if self._client is None:
            # Import here to avoid circular dependency
            from multillm import Client
            self._client = Client()
        return self._client

    def _build_options(self, options: AgentOptions | None) -> dict[str, Any]:
        """Build options dict for wrapped provider."""
        if options is None:
            return {}

        opts = {}
        if options.system_prompt:
            opts["system_prompt"] = options.system_prompt

        # Merge with extra options (temperature, max_tokens, etc.)
        if options.extra:
            opts.update(options.extra)

        return opts

    async def _execute_tool(
        self,
        tool_call: dict,
        tools: list[Tool] | None,
    ) -> dict:
        """
        Execute a tool call and return the result.

        Args:
            tool_call: Tool call from chat response (OpenAI format)
            tools: List of available tools with handlers

        Returns:
            Tool result dict with 'content' key
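
        Example (illustrative OpenAI-format tool call):
            {
                "id": "call_abc123",
                "function": {
                    "name": "calculate",
                    "arguments": '{"a": 2, "b": 2}',
                },
            }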
        """
        function_name = tool_call["function"]["name"]
        function_args = tool_call["function"].get("arguments", {})

        # OpenAI-format tool calls usually carry arguments as a JSON string;
        # decode so handlers always receive a dict
        if isinstance(function_args, str):
            function_args = json.loads(function_args)

        # Find the tool with matching name
        if tools:
            for tool in tools:
                if tool.name == function_name:
                    # Execute the tool handler
                    try:
                        result = tool.handler(function_args)
                        # Await the result if the handler is async
                        if inspect.isawaitable(result):
                            result = await result

                        # Return formatted result
                        return {"content": str(result)}

                    except Exception as e:
                        return {
                            "content": f"Error executing tool: {e}",
                            "is_error": True
                        }

        # Tool not found or no handlers
        return {
            "content": f"Tool '{function_name}' not found",
            "is_error": True
        }

    async def run(
        self,
        prompt: str,
        options: AgentOptions | None = None,
        tools: list[Tool] | None = None,
    ) -> AsyncIterator[AgentMessage]:
        """
        Run agentic workflow with the wrapped chat provider.

        Args:
            prompt: User message to send
            options: Agent options (max_turns, system_prompt, etc.)
            tools: Optional tools the agent can use

        Yields:
            AgentMessage objects representing the agent's actions and responses
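
        Example (sketch, assuming a provider configured with a
        wrapped_model such as "google/gemini"):
            async for msg in provider.run("What's 2+2?", tools=[calculate_tool]):
                if msg.type == "text":
                    print(msg.content)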
        """
        # Yield session start message
        yield AgentMessage(
            type="system",
            content="Agentic session started",
            raw=None,
        )

        # Get wrapped model from config
        # When client routes "agentwrap/google/gemini", we receive "google/gemini" as model
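        # e.g. merged_config is expected to contain {"wrapped_model": "google/gemini"}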
        file_config = load_provider_config(self.PROVIDER_NAME)
        merged_config = merge_config(file_config, self.config, {})
        wrapped_model = merged_config.get("wrapped_model")

        if not wrapped_model:
            raise ProviderError(
                "AgentWrap provider requires 'wrapped_model' in config. "
                "When using via client, the model should be specified as 'agentwrap/provider/model'."
            )

        # Build options for chat API
        chat_options = self._build_options(options)

        # Get max turns
        max_turns = options.max_turns if options and options.max_turns else 10

        # Initialize conversation history
        messages = []

        # Add system prompt if provided
        if options and options.system_prompt:
            messages.append({
                "role": "system",
                "content": options.system_prompt
            })

        # Add user message
        messages.append({
            "role": "user",
            "content": prompt
        })
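
        # At this point the history resembles (system entry only present
        # when a system_prompt was given):
        #   [{"role": "system", "content": ...},
        #    {"role": "user", "content": prompt}]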

        # Get client
        client = self._get_client()

        # Tool execution loop
        final_text = ""
        for turn in range(max_turns):
            try:
                # Call chat_complete with wrapped model
                response = await client.chat_complete(
                    wrapped_model,
                    messages,
                    tools=tools,
                    **chat_options
                )

                # Get text from response
                text = response.choices[0].message.content or ""
                tool_calls = response.choices[0].message.tool_calls or []

                # Add assistant message to history; include tool_calls only
                # when present, since some chat APIs reject a null tool_calls key
                assistant_message: dict[str, Any] = {
                    "role": "assistant",
                    "content": text,
                }
                if tool_calls:
                    assistant_message["tool_calls"] = tool_calls
                messages.append(assistant_message)

                # Yield text message if present
                if text:
                    final_text = text
                    yield AgentMessage(
                        type="text",
                        content=text,
                        raw=response,
                    )

                # Check if we're done (no tool calls)
                if not tool_calls:
                    break

                # Process tool calls
                for tool_call in tool_calls:
                    # Yield tool use message
                    yield AgentMessage(
                        type="tool_use",
                        tool_name=tool_call["function"]["name"],
                        tool_input=tool_call["function"].get("arguments", {}),
                        raw=tool_call,
                    )

                    # Execute tool if handler available
                    tool_result = await self._execute_tool(tool_call, tools)

                    # Yield tool result message
                    yield AgentMessage(
                        type="tool_result",
                        tool_name=tool_call["function"]["name"],
                        tool_result=tool_result["content"],
                        raw=tool_result,
                    )

                    # Add tool result to message history
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "content": tool_result["content"]
                    })

            except Exception as e:
                # Yield error and stop
                error_msg = f"Error in agentic loop: {e}"
                print(f"\n{error_msg}", file=sys.stderr)
                yield AgentMessage(
                    type="error",
                    content=error_msg,
                    raw=e,
                )
                raise ProviderError(error_msg) from e

        # Yield final result
        yield AgentMessage(
            type="result",
            content=final_text,
            raw=None,
        )

    async def run_interactive(
        self,
        options: AgentOptions | None = None,
        tools: list[Tool] | None = None,
    ):
        """
        Interactive sessions are not yet implemented for agentwrap.

        Use multiple calls to run() instead.
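
        Multi-turn sketch (each run() call starts a fresh history in this
        implementation, so restate any needed context in the prompt):
            async for msg in provider.run("First question"):
                ...
            async for msg in provider.run("Follow-up question"):
                ...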
        """
        raise NotImplementedError(
            "Interactive sessions not yet implemented for agentwrap provider. "
            "Use multiple calls to run() for multi-turn conversations."
        )