forked from pgadmin-org/pgadmin4
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathchat.py
More file actions
303 lines (255 loc) · 9.95 KB
/
chat.py
File metadata and controls
303 lines (255 loc) · 9.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""LLM chat functionality with database tool integration.
This module provides high-level functions for running LLM conversations
that can use database tools to query and inspect PostgreSQL databases.
"""
import json
from collections.abc import Generator
from typing import Optional, Union
from pgadmin.llm.client import get_llm_client, is_llm_available
from pgadmin.llm.models import Message, LLMResponse, StopReason
from pgadmin.llm.tools import DATABASE_TOOLS, execute_tool, DatabaseToolError
from pgadmin.llm.utils import get_max_tool_iterations
# Default system prompt for the database assistant.  Used by both the
# blocking and streaming chat entry points whenever the caller does not
# supply a custom system_prompt.  The prompt itself tells the model that
# tool queries are read-only and capped at 1000 rows — keep this text in
# sync with the actual tool behavior if that ever changes.
DEFAULT_SYSTEM_PROMPT = (
    "You are a PostgreSQL database assistant integrated into pgAdmin 4. "
    "You have access to tools that allow you to query the database and "
    "inspect its schema.\n\n"
    "When helping users:\n"
    "1. First understand the database structure using get_database_schema "
    "or get_table_info\n"
    "2. Write efficient SQL queries to answer questions about the data\n"
    "3. Explain your findings clearly and concisely\n"
    "4. If a query might return many rows, consider using LIMIT or "
    "aggregations\n\n"
    "Important:\n"
    "- All queries run in READ ONLY mode - you cannot modify data\n"
    "- Results are limited to 1000 rows\n"
    "- Always validate your understanding of the schema before writing "
    "complex queries"
)
def chat_with_database(
    user_message: str,
    sid: int,
    did: int,
    conversation_history: Optional[list[Message]] = None,
    system_prompt: Optional[str] = None,
    max_tool_iterations: Optional[int] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None
) -> tuple[str, list[Message]]:
    """
    Run a blocking LLM conversation that may call database tools.

    The LLM is repeatedly invoked until it produces a final text answer:
    whenever it stops to request tool calls, those tools are executed
    against the given server/database and their results are appended to
    the transcript before the next round.

    Args:
        user_message: The user's message/question
        sid: Server ID for database connection
        did: Database ID for database connection
        conversation_history: Optional list of previous messages
        system_prompt: Optional custom system prompt (uses default if None)
        max_tool_iterations: Maximum number of tool call
            rounds. Uses preference setting if None.
        provider: Optional LLM provider override
        model: Optional model override

    Returns:
        Tuple of (final_response_text, updated_conversation_history)

    Raises:
        LLMClientError: If the LLM request fails
        RuntimeError: If LLM is not available or max iterations exceeded
    """
    if not is_llm_available():
        raise RuntimeError("LLM is not configured. Please configure an LLM "
                           "provider in Preferences > AI.")

    client = get_llm_client(provider=provider, model=model)
    if not client:
        raise RuntimeError("Failed to create LLM client")

    # Working copy of the transcript: prior turns (if any) plus the new
    # user question.  The caller's list is never mutated.
    history: list[Message] = (
        list(conversation_history) if conversation_history else []
    )
    history.append(Message.user(user_message))

    prompt = DEFAULT_SYSTEM_PROMPT if system_prompt is None else system_prompt

    # Resolve the iteration cap from preferences when not given explicitly.
    limit = (get_max_tool_iterations()
             if max_tool_iterations is None else max_tool_iterations)

    def _run_tool(call) -> Message:
        """Execute one requested tool call, wrapping any failure as an
        error-flagged tool-result message instead of raising."""
        try:
            outcome = execute_tool(
                tool_name=call.name,
                arguments=call.arguments,
                sid=sid,
                did=did
            )
            return Message.tool_result(
                tool_call_id=call.id,
                content=json.dumps(outcome, default=str),
                is_error=False
            )
        except (DatabaseToolError, ValueError) as exc:
            # Expected tool-level failures: report them back to the model.
            return Message.tool_result(
                tool_call_id=call.id,
                content=json.dumps({"error": str(exc)}),
                is_error=True
            )
        except Exception as exc:
            # Anything else is still surfaced to the model rather than
            # aborting the whole conversation.
            return Message.tool_result(
                tool_call_id=call.id,
                content=json.dumps({
                    "error": f"Unexpected error: {str(exc)}"
                }),
                is_error=True
            )

    for _round in range(limit):
        reply = client.chat(
            messages=history,
            tools=DATABASE_TOOLS,
            system_prompt=prompt
        )
        history.append(reply.to_message())

        # Anything other than a tool-use stop means the model is done.
        if reply.stop_reason != StopReason.TOOL_USE:
            return reply.content, history

        # Feed every tool outcome back into the transcript, then loop.
        history.extend(_run_tool(tc) for tc in reply.tool_calls)

    raise RuntimeError(
        f"Exceeded maximum tool iterations ({limit})"
    )
def chat_with_database_stream(
    user_message: str,
    sid: int,
    did: int,
    conversation_history: Optional[list[Message]] = None,
    system_prompt: Optional[str] = None,
    max_tool_iterations: Optional[int] = None,
    provider: Optional[str] = None,
    model: Optional[str] = None
) -> Generator[Union[str, tuple], None, None]:
    """
    Streaming variant of chat_with_database.

    Text chunks from the LLM are yielded as plain strings while the
    response streams in.  When the model requests tools, a
    ('tool_use', [tool_names]) 2-tuple is yielded so the caller can reset
    its streaming UI, the tools run silently, and the loop continues.
    The final item is always the 3-tuple
    ('complete', final_response_text, updated_conversation_history).

    Raises:
        LLMClientError: If the LLM request fails.
        RuntimeError: If LLM is not available or max iterations exceeded.
    """
    if not is_llm_available():
        raise RuntimeError("LLM is not configured. Please configure an LLM "
                           "provider in Preferences > AI.")

    client = get_llm_client(provider=provider, model=model)
    if not client:
        raise RuntimeError("Failed to create LLM client")

    # Build the working transcript without mutating the caller's list.
    history: list[Message] = (
        list(conversation_history) if conversation_history else []
    )
    history.append(Message.user(user_message))

    prompt = DEFAULT_SYSTEM_PROMPT if system_prompt is None else system_prompt
    limit = (get_max_tool_iterations()
             if max_tool_iterations is None else max_tool_iterations)

    def _run_tool(call) -> Message:
        """Execute one tool call; failures become error tool-results."""
        try:
            outcome = execute_tool(
                tool_name=call.name,
                arguments=call.arguments,
                sid=sid,
                did=did
            )
            return Message.tool_result(
                tool_call_id=call.id,
                content=json.dumps(outcome, default=str),
                is_error=False
            )
        except (DatabaseToolError, ValueError) as exc:
            return Message.tool_result(
                tool_call_id=call.id,
                content=json.dumps({"error": str(exc)}),
                is_error=True
            )
        except Exception as exc:
            return Message.tool_result(
                tool_call_id=call.id,
                content=json.dumps({
                    "error": f"Unexpected error: {str(exc)}"
                }),
                is_error=True
            )

    for _round in range(limit):
        # Relay text chunks to the caller as they arrive; the terminal
        # LLMResponse object carries the stop reason and tool calls.
        final = None
        for piece in client.chat_stream(
            messages=history,
            tools=DATABASE_TOOLS,
            system_prompt=prompt
        ):
            if isinstance(piece, LLMResponse):
                final = piece
            elif isinstance(piece, str):
                yield piece

        if final is None:
            raise RuntimeError("No response received from LLM")
        history.append(final.to_message())

        if final.stop_reason != StopReason.TOOL_USE:
            # Done: the 3-tuple shape distinguishes this terminal event
            # from the 2-tuple tool_use event below.
            yield ('complete', final.content, history)
            return

        # Tell the caller which tools are about to run so it can reset
        # streaming state and show a thinking indicator.
        yield ('tool_use', [tc.name for tc in final.tool_calls])

        history.extend(_run_tool(tc) for tc in final.tool_calls)

    raise RuntimeError(
        f"Exceeded maximum tool iterations ({limit})"
    )
def single_query(
    question: str,
    sid: int,
    did: int,
    provider: Optional[str] = None,
    model: Optional[str] = None
) -> str:
    """
    Answer a one-off question about the database.

    Convenience wrapper around chat_with_database for callers that do
    not need to keep conversation history between questions.

    Args:
        question: The question to ask
        sid: Server ID
        did: Database ID
        provider: Optional LLM provider override
        model: Optional model override

    Returns:
        The LLM's response text

    Raises:
        LLMClientError: If the LLM request fails
        RuntimeError: If LLM is not available
    """
    answer, _history = chat_with_database(
        user_message=question,
        sid=sid,
        did=did,
        provider=provider,
        model=model
    )
    # The updated transcript is discarded — this is a one-shot helper.
    return answer