-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Expand file tree
/
Copy pathaio.py
More file actions
623 lines (537 loc) · 22.5 KB
/
aio.py
File metadata and controls
623 lines (537 loc) · 22.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
from collections.abc import AsyncIterator, Callable, Iterable, Sequence
from contextlib import asynccontextmanager
from types import TracebackType
from typing import Any, cast
import aiosqlite
import orjson
import sqlite_vec # type: ignore[import-untyped]
from langgraph.store.base import (
GetOp,
ListNamespacesOp,
Op,
PutOp,
Result,
SearchOp,
TTLConfig,
)
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.sqlite.base import (
_PLACEHOLDER,
BaseSqliteStore,
SqliteIndexConfig,
_decode_ns_text,
_ensure_index_config,
_group_ops,
_row_to_item,
_row_to_search_item,
)
logger = logging.getLogger(__name__)
class AsyncSqliteStore(AsyncBatchedBaseStore, BaseSqliteStore):
    """Asynchronous SQLite-backed store with optional vector search.

    This class provides an asynchronous interface for storing and retrieving data
    using a SQLite database with support for vector search capabilities.

    Examples:
        Basic setup and usage:
        ```python
        from langgraph.store.sqlite import AsyncSqliteStore

        async with AsyncSqliteStore.from_conn_string(":memory:") as store:
            await store.setup()  # Run migrations

            # Store and retrieve data
            await store.aput(("users", "123"), "prefs", {"theme": "dark"})
            item = await store.aget(("users", "123"), "prefs")
        ```

        Vector search using LangChain embeddings:
        ```python
        from langchain_openai import OpenAIEmbeddings
        from langgraph.store.sqlite import AsyncSqliteStore

        async with AsyncSqliteStore.from_conn_string(
            ":memory:",
            index={
                "dims": 1536,
                "embed": OpenAIEmbeddings(),
                "fields": ["text"]  # specify which fields to embed
            }
        ) as store:
            await store.setup()  # Run migrations once

            # Store documents
            await store.aput(("docs",), "doc1", {"text": "Python tutorial"})
            await store.aput(("docs",), "doc2", {"text": "TypeScript guide"})
            await store.aput(("docs",), "doc3", {"text": "Other guide"}, index=False)  # don't index

            # Search by similarity
            results = await store.asearch(("docs",), query="programming guides", limit=2)
        ```

    Warning:
        Make sure to call `setup()` before first use to create necessary tables and indexes.

    Note:
        This class requires the aiosqlite package. Install with `pip install aiosqlite`.
    """

    def __init__(
        self,
        conn: aiosqlite.Connection,
        *,
        deserializer: Callable[[bytes | str | orjson.Fragment], dict[str, Any]]
        | None = None,
        index: SqliteIndexConfig | None = None,
        ttl: TTLConfig | None = None,
    ):
        """Initialize the async SQLite store.

        Args:
            conn: The SQLite database connection.
            deserializer: Optional custom deserializer function for values.
            index: Optional vector search configuration.
            ttl: Optional time-to-live configuration.
        """
        super().__init__()
        self._deserializer = deserializer
        self.conn = conn
        # Serializes all access to the single aiosqlite connection; SQLite
        # connections are not safe for concurrent statement interleaving.
        self.lock = asyncio.Lock()
        # Must be constructed inside a running event loop.
        self.loop = asyncio.get_running_loop()
        self.is_setup = False
        self.index_config = index
        if self.index_config:
            self.embeddings, self.index_config = _ensure_index_config(self.index_config)
        else:
            self.embeddings = None
        self.ttl_config = ttl
        self._ttl_sweeper_task: asyncio.Task[None] | None = None
        self._ttl_stop_event = asyncio.Event()

    @classmethod
    @asynccontextmanager
    async def from_conn_string(
        cls,
        conn_string: str,
        *,
        index: SqliteIndexConfig | None = None,
        ttl: TTLConfig | None = None,
    ) -> AsyncIterator[AsyncSqliteStore]:
        """Create a new AsyncSqliteStore instance from a connection string.

        Args:
            conn_string: The SQLite connection string.
            index: Optional vector search configuration.
            ttl: Optional time-to-live configuration.

        Returns:
            An AsyncSqliteStore instance wrapped in an async context manager.
        """
        # isolation_level=None puts the connection in autocommit mode; the
        # store manages transactions explicitly via BEGIN/COMMIT in _cursor.
        async with aiosqlite.connect(conn_string, isolation_level=None) as conn:
            yield cls(conn, index=index, ttl=ttl)

    async def setup(self) -> None:
        """Set up the store database.

        This method creates the necessary tables in the SQLite database if they don't
        already exist and runs database migrations. It should be called before first use.
        """
        async with self.lock:
            # Idempotent: _cursor() calls this lazily on every operation.
            if self.is_setup:
                return

            # Create migrations table if it doesn't exist
            await self.conn.execute(
                """
                CREATE TABLE IF NOT EXISTS store_migrations (
                    v INTEGER PRIMARY KEY
                )
                """
            )
            # Check current migration version
            async with self.conn.execute(
                "SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1"
            ) as cur:
                row = await cur.fetchone()
                if row is None:
                    version = -1
                else:
                    version = row[0]

            # Apply migrations newer than the recorded version, recording each.
            for v, sql in enumerate(self.MIGRATIONS[version + 1 :], start=version + 1):
                await self.conn.executescript(sql)
                await self.conn.execute(
                    "INSERT INTO store_migrations (v) VALUES (?)", (v,)
                )

            # Apply vector migrations if index config is provided
            if self.index_config:
                # Load the sqlite-vec extension (required for vector tables).
                await self.conn.enable_load_extension(True)
                await self.conn.load_extension(sqlite_vec.loadable_path())
                await self.conn.enable_load_extension(False)
                await self.conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS vector_migrations (
                        v INTEGER PRIMARY KEY
                    )
                    """
                )
                # Check current vector migration version
                async with self.conn.execute(
                    "SELECT v FROM vector_migrations ORDER BY v DESC LIMIT 1"
                ) as cur:
                    row = await cur.fetchone()
                    if row is None:
                        version = -1
                    else:
                        version = row[0]

                # Apply vector migrations
                for v, sql in enumerate(
                    self.VECTOR_MIGRATIONS[version + 1 :], start=version + 1
                ):
                    await self.conn.executescript(sql)
                    await self.conn.execute(
                        "INSERT INTO vector_migrations (v) VALUES (?)", (v,)
                    )

            self.is_setup = True

    @asynccontextmanager
    async def _cursor(
        self, *, transaction: bool = True
    ) -> AsyncIterator[aiosqlite.Cursor]:
        """Get a cursor for the SQLite database.

        Args:
            transaction: Whether to wrap the cursor's work in a transaction.
                On success the transaction is committed; if the body raises,
                it is rolled back and the exception is re-raised.

        Yields:
            An SQLite cursor object.
        """
        if not self.is_setup:
            await self.setup()
        async with self.lock:
            if transaction:
                await self.conn.execute("BEGIN")
            async with self.conn.cursor() as cur:
                try:
                    yield cur
                except BaseException:
                    # Fix: previously COMMIT ran unconditionally in a finally
                    # block, persisting partial writes when a statement in the
                    # middle of a batch failed. Roll back instead.
                    if transaction:
                        await self.conn.execute("ROLLBACK")
                    raise
                else:
                    if transaction:
                        await self.conn.execute("COMMIT")

    async def sweep_ttl(self) -> int:
        """Delete expired store items based on TTL.

        Returns:
            int: The number of deleted items.
        """
        async with self._cursor() as cur:
            await cur.execute(
                """
                DELETE FROM store
                WHERE expires_at IS NOT NULL AND expires_at < CURRENT_TIMESTAMP
                """
            )
            deleted_count = cur.rowcount
            return deleted_count

    async def start_ttl_sweeper(
        self, sweep_interval_minutes: int | None = None
    ) -> asyncio.Task[None]:
        """Periodically delete expired store items based on TTL.

        Args:
            sweep_interval_minutes: Minutes between sweeps; falls back to the
                TTL config's ``sweep_interval_minutes`` and then to 5.

        Returns:
            Task that can be awaited or cancelled.
        """
        if not self.ttl_config:
            # No TTL configured: return a trivial completed-soon task so
            # callers can uniformly await/cancel the result.
            return asyncio.create_task(asyncio.sleep(0))

        if self._ttl_sweeper_task is not None and not self._ttl_sweeper_task.done():
            # A sweeper is already running; reuse it.
            return self._ttl_sweeper_task

        self._ttl_stop_event.clear()

        interval = float(
            sweep_interval_minutes or self.ttl_config.get("sweep_interval_minutes") or 5
        )
        logger.info(f"Starting store TTL sweeper with interval {interval} minutes")

        async def _sweep_loop() -> None:
            # Sleep for the interval but wake immediately if the stop event
            # is set; TimeoutError means "interval elapsed, do a sweep".
            while not self._ttl_stop_event.is_set():
                try:
                    try:
                        await asyncio.wait_for(
                            self._ttl_stop_event.wait(),
                            timeout=interval * 60,
                        )
                        break
                    except asyncio.TimeoutError:
                        pass

                    expired_items = await self.sweep_ttl()
                    if expired_items > 0:
                        logger.info(f"Store swept {expired_items} expired items")
                except asyncio.CancelledError:
                    break
                except Exception as exc:
                    # Keep sweeping on transient failures; log for visibility.
                    logger.exception("Store TTL sweep iteration failed", exc_info=exc)

        task = asyncio.create_task(_sweep_loop())
        task.set_name("ttl_sweeper")
        self._ttl_sweeper_task = task
        return task

    async def stop_ttl_sweeper(self, timeout: float | None = None) -> bool:
        """Stop the TTL sweeper task if it's running.

        Args:
            timeout: Maximum time to wait for the task to stop, in seconds.
                If `None`, wait indefinitely.

        Returns:
            bool: True if the task was successfully stopped or wasn't running,
                False if the timeout was reached before the task stopped.
        """
        if self._ttl_sweeper_task is None or self._ttl_sweeper_task.done():
            return True

        logger.info("Stopping TTL sweeper task")
        self._ttl_stop_event.set()

        if timeout is not None:
            try:
                await asyncio.wait_for(self._ttl_sweeper_task, timeout=timeout)
                success = True
            except asyncio.TimeoutError:
                success = False
        else:
            await self._ttl_sweeper_task
            success = True

        if success:
            self._ttl_sweeper_task = None
            logger.info("TTL sweeper task stopped")
        else:
            logger.warning("Timed out waiting for TTL sweeper task to stop")

        return success

    async def __aenter__(self) -> AsyncSqliteStore:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Ensure the TTL sweeper task is stopped when exiting the context
        if hasattr(self, "_ttl_sweeper_task") and self._ttl_sweeper_task is not None:
            # Set the event to signal the task to stop
            self._ttl_stop_event.set()
            # We don't wait for the task to complete here to avoid blocking
            # The task will clean up itself gracefully

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Execute a batch of operations asynchronously.

        All operations run inside a single transaction; on failure the whole
        batch is rolled back (see `_cursor`).

        Args:
            ops: Iterable of operations to execute.

        Returns:
            List of operation results, positionally aligned with `ops`.
        """
        grouped_ops, num_ops = _group_ops(ops)
        results: list[Result] = [None] * num_ops

        async with self._cursor(transaction=True) as cur:
            if GetOp in grouped_ops:
                await self._batch_get_ops(
                    cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results, cur
                )
            if SearchOp in grouped_ops:
                await self._batch_search_ops(
                    cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
                    results,
                    cur,
                )
            if ListNamespacesOp in grouped_ops:
                await self._batch_list_namespaces_ops(
                    cast(
                        Sequence[tuple[int, ListNamespacesOp]],
                        grouped_ops[ListNamespacesOp],
                    ),
                    results,
                    cur,
                )
            if PutOp in grouped_ops:
                await self._batch_put_ops(
                    cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]), cur
                )

        return results

    async def _batch_get_ops(
        self,
        get_ops: Sequence[tuple[int, GetOp]],
        results: list[Result],
        cur: aiosqlite.Cursor,
    ) -> None:
        """Process batch GET operations.

        Args:
            get_ops: Sequence of GET operations.
            results: List to store results in.
            cur: Database cursor.
        """
        # Group all queries by namespace to execute all operations for each namespace together
        namespace_queries = defaultdict(list)
        for prepared_query in self._get_batch_GET_ops_queries(get_ops):
            namespace_queries[prepared_query.namespace].append(prepared_query)

        # Process each namespace's operations
        for namespace, queries in namespace_queries.items():
            # Execute TTL refresh queries first so GETs observe updated expiry.
            for query in queries:
                if query.kind == "refresh":
                    try:
                        await cur.execute(query.query, query.params)
                    except Exception as e:
                        raise ValueError(
                            f"Error executing TTL refresh: \n{query.query}\n{query.params}\n{e}"
                        ) from e

            # Then execute GET queries and process results
            for query in queries:
                if query.kind == "get":
                    try:
                        await cur.execute(query.query, query.params)
                    except Exception as e:
                        raise ValueError(
                            f"Error executing GET query: \n{query.query}\n{query.params}\n{e}"
                        ) from e
                    rows = await cur.fetchall()
                    # Expected column order: key, value, created_at, updated_at,
                    # then optional expires_at / ttl_minutes (TTL-enabled schema).
                    key_to_row = {
                        row[0]: {
                            "key": row[0],
                            "value": row[1],
                            "created_at": row[2],
                            "updated_at": row[3],
                            "expires_at": row[4] if len(row) > 4 else None,
                            "ttl_minutes": row[5] if len(row) > 5 else None,
                        }
                        for row in rows
                    }

                    # Map each requested key back to its original position.
                    for idx, key in query.items:
                        row = key_to_row.get(key)
                        if row:
                            results[idx] = _row_to_item(
                                namespace, row, loader=self._deserializer
                            )
                        else:
                            results[idx] = None

    async def _batch_put_ops(
        self,
        put_ops: Sequence[tuple[int, PutOp]],
        cur: aiosqlite.Cursor,
    ) -> None:
        """Process batch PUT operations.

        Args:
            put_ops: Sequence of PUT operations.
            cur: Database cursor.
        """
        queries, embedding_request = self._prepare_batch_PUT_queries(put_ops)
        if embedding_request:
            if self.embeddings is None:
                # Should not get here since the embedding config is required
                # to return an embedding_request above
                raise ValueError(
                    "Embedding configuration is required for vector operations "
                    f"(for semantic search). "
                    f"Please provide an Embeddings when initializing the {self.__class__.__name__}."
                )
            query, txt_params = embedding_request
            # Update the params to replace the raw text with the vectors
            vectors = await self.embeddings.aembed_documents(
                [param[-1] for param in txt_params]
            )

            # Convert vectors to SQLite-friendly binary format.
            vector_params = []
            for (ns, k, pathname, _), vector in zip(txt_params, vectors, strict=False):
                vector_params.extend(
                    [ns, k, pathname, sqlite_vec.serialize_float32(vector)]
                )

            queries.append((query, vector_params))

        for query, params in queries:
            await cur.execute(query, params)

    async def _batch_search_ops(
        self,
        search_ops: Sequence[tuple[int, SearchOp]],
        results: list[Result],
        cur: aiosqlite.Cursor,
    ) -> None:
        """Process batch SEARCH operations.

        Args:
            search_ops: Sequence of SEARCH operations.
            results: List to store results in.
            cur: Database cursor.
        """
        prepared_queries, embedding_requests = self._prepare_batch_search_queries(
            search_ops
        )

        # Embed all query texts in one round trip, then splice each vector
        # into its prepared query's parameter list in place of _PLACEHOLDER.
        if embedding_requests and self.embeddings:
            vectors = await self.embeddings.aembed_documents(
                [query for _, query in embedding_requests]
            )
            for (embed_req_idx, _), embedding in zip(
                embedding_requests, vectors, strict=False
            ):
                # The embed_req_idx is the original index in search_ops, which
                # should map positionally onto prepared_queries.
                if embed_req_idx < len(prepared_queries):
                    _params_list: list = prepared_queries[embed_req_idx][1]
                    for i, param in enumerate(_params_list):
                        if param is _PLACEHOLDER:
                            _params_list[i] = sqlite_vec.serialize_float32(embedding)
                else:
                    logger.warning(
                        f"Embedding request index {embed_req_idx} out of bounds for prepared_queries."
                    )

        for (original_op_idx, _), (query, params, needs_refresh) in zip(
            search_ops, prepared_queries, strict=False
        ):
            await cur.execute(query, params)
            rows = await cur.fetchall()

            # Refresh TTL expiry for returned items when the op requests it.
            if needs_refresh and rows and self.ttl_config:
                keys_to_refresh = []
                for row_data in rows:
                    # Assuming row_data[0] is prefix (text), row_data[1] is key (text)
                    # These are raw text values directly from the DB.
                    keys_to_refresh.append((row_data[0], row_data[1]))

                if keys_to_refresh:
                    # Batch the refresh into one UPDATE per prefix.
                    updates_by_prefix = defaultdict(list)
                    for prefix_text, key_text in keys_to_refresh:
                        updates_by_prefix[prefix_text].append(key_text)
                    for prefix_text, key_list in updates_by_prefix.items():
                        placeholders = ",".join(["?"] * len(key_list))
                        update_query = f"""
                            UPDATE store
                            SET expires_at = DATETIME(CURRENT_TIMESTAMP, '+' || ttl_minutes || ' minutes')
                            WHERE prefix = ? AND key IN ({placeholders}) AND ttl_minutes IS NOT NULL
                        """
                        update_params = (prefix_text, *key_list)
                        try:
                            await cur.execute(update_query, update_params)
                        except Exception as e:
                            # Best-effort: a failed refresh must not fail the search.
                            logger.error(
                                f"Error during TTL refresh update for search: {e}"
                            )

            # Process rows into items.
            # NOTE(review): dispatching on the substring "score" in the SQL text
            # is fragile — it assumes only vector-search queries mention "score".
            if "score" in query:  # Vector search query
                items = [
                    _row_to_search_item(
                        _decode_ns_text(row[0]),  # prefix
                        {
                            "key": row[1],  # key
                            "value": row[2],  # value
                            "created_at": row[3],
                            "updated_at": row[4],
                            "expires_at": row[5] if len(row) > 5 else None,
                            "ttl_minutes": row[6] if len(row) > 6 else None,
                            "score": row[7] if len(row) > 7 else None,
                        },
                        loader=self._deserializer,
                    )
                    for row in rows
                ]
            else:  # Regular search query
                items = [
                    _row_to_search_item(
                        _decode_ns_text(row[0]),  # prefix
                        {
                            "key": row[1],  # key
                            "value": row[2],  # value
                            "created_at": row[3],
                            "updated_at": row[4],
                            "expires_at": row[5] if len(row) > 5 else None,
                            "ttl_minutes": row[6] if len(row) > 6 else None,
                        },
                        loader=self._deserializer,
                    )
                    for row in rows
                ]

            results[original_op_idx] = items

    async def _batch_list_namespaces_ops(
        self,
        list_ops: Sequence[tuple[int, ListNamespacesOp]],
        results: list[Result],
        cur: aiosqlite.Cursor,
    ) -> None:
        """Process batch LIST NAMESPACES operations.

        Args:
            list_ops: Sequence of LIST NAMESPACES operations.
            results: List to store results in.
            cur: Database cursor.
        """
        queries = self._get_batch_list_namespaces_queries(list_ops)
        for (query, params), (idx, _) in zip(queries, list_ops, strict=False):
            await cur.execute(query, params)
            rows = await cur.fetchall()
            results[idx] = [_decode_ns_text(row[0]) for row in rows]