Skip to content

Commit e657a8e

Browse files
authored
Merge pull request #86 from CaliLuke/feat/driver-injection
feat: Driver injection, FunctionQuery, and generator fixes
2 parents 5ccdbf0 + ad5f40b commit e657a8e

9 files changed

Lines changed: 432 additions & 17 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type-bridge provides Pythonic/TypeScript abstractions over TypeDB's native TypeQ
2323

2424
| Package | Version | Install |
2525
|---------|---------|---------|
26-
| [Python](./packages/python) | 1.1.0 | `pip install type-bridge` |
26+
| [Python](./packages/python) | 1.2.2 | `pip install type-bridge` |
2727
| [TypeScript](./packages/typescript) | 0.1.0 | `npm install @type-bridge/type-bridge` |
2828

2929
## Quick Start

packages/python/docs/api/crud.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,6 +1381,87 @@ count = user_manager.filter(Age.gt(Age(65))).delete() # Returns count
13811381
# Use filter().delete() instead for filter-based deletion
13821382
```
13831383

1384+
## Database Configuration
1385+
1386+
### Basic Connection
1387+
1388+
```python
1389+
from type_bridge import Database
1390+
1391+
# Default connection
1392+
db = Database() # localhost:1729, database="typedb"
1393+
1394+
# Custom connection
1395+
db = Database(
1396+
address="192.168.1.100:1729",
1397+
database="mydb",
1398+
username="admin",
1399+
password="secret"
1400+
)
1401+
db.connect()
1402+
1403+
# Context manager (auto-connects and closes)
1404+
with Database(database="mydb") as db:
1405+
person_manager = Person.manager(db)
1406+
# ... operations ...
1407+
```
1408+
1409+
### Driver Injection
1410+
1411+
For advanced use cases, you can inject an external `Driver` instance instead of having `Database` create one internally. This enables:
1412+
1413+
- **Connection sharing** across multiple `Database` instances
1414+
- **Resource pooling** with custom driver management
1415+
- **Easier testing** via mock driver injection
1416+
1417+
```python
1418+
from typedb.driver import TypeDB, Credentials, DriverOptions
1419+
1420+
# Create a shared driver
1421+
driver = TypeDB.driver(
1422+
"localhost:1729",
1423+
Credentials("admin", "password"),
1424+
DriverOptions()
1425+
)
1426+
1427+
# Multiple databases share one connection
1428+
db1 = Database(database="project_a", driver=driver)
1429+
db2 = Database(database="project_b", driver=driver)
1430+
1431+
# Use both databases
1432+
with db1.transaction("write") as tx:
1433+
Person.manager(tx).insert(alice)
1434+
1435+
with db2.transaction("read") as tx:
1436+
results = Artifact.manager(tx).all()
1437+
1438+
# Close databases (only clears references, doesn't close driver)
1439+
db1.close()
1440+
db2.close()
1441+
1442+
# Close driver when done (caller's responsibility)
1443+
driver.close()
1444+
```
1445+
1446+
**Ownership semantics:**
1447+
- `driver=None` (default): `Database` creates and owns the driver, `close()` closes it
1448+
- `driver=<Driver>`: `Database` uses but doesn't own it, `close()` only clears the reference
1449+
1450+
### Testing with Mock Driver
1451+
1452+
```python
1453+
from unittest.mock import MagicMock
1454+
1455+
def test_database_operations():
1456+
mock_driver = MagicMock()
1457+
mock_driver.databases.contains.return_value = True
1458+
1459+
db = Database(database="test_db", driver=mock_driver)
1460+
1461+
assert db.database_exists() is True
1462+
mock_driver.databases.contains.assert_called_with("test_db")
1463+
```
1464+
13841465
## See Also
13851466

13861467
- [Entities](entities.md) - Entity definition

packages/python/docs/api/generator.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ relation authoring sub contribution,
243243
relation review,
244244
relates reviewer,
245245
relates reviewed,
246-
owns score,
247-
owns timestamp;
246+
owns score @card(1), // Required attribute
247+
owns timestamp; // Optional (no @card = 0..1)
248248
249249
// Cardinality constraints on roles
250250
relation social-relation @abstract,
@@ -277,8 +277,8 @@ class Authoring(Contribution):
277277

278278
class Review(Relation):
279279
flags = TypeFlags(name="review")
280-
score: attributes.Score
281-
timestamp: attributes.Timestamp
280+
score: attributes.Score # Required (@card(1))
281+
timestamp: attributes.Timestamp | None = None # Optional (no @card)
282282
reviewer: Role[entities.User] = Role("reviewer", entities.User)
283283
reviewed: Role[entities.Publication] = Role("reviewed", entities.Publication)
284284
```
@@ -375,6 +375,8 @@ relation friendship,
375375

376376
## Cardinality Mapping
377377

378+
The following cardinality rules apply to attributes on both **entities** and **relations**:
379+
378380
| TypeQL | Python Type | Default |
379381
|--------|-------------|---------|
380382
| `@card(1)` or `@card(1..1)` | `Type` | Required |
@@ -385,6 +387,8 @@ relation friendship,
385387
| `@key` | `Type = Flag(Key)` | Key (implies required) |
386388
| `@unique` | `Type = Flag(Unique)` | Unique (implies required) |
387389

390+
**Inheritance:** Child types inherit cardinality constraints from parent types. A child can override inherited constraints by redeclaring the attribute with a different `@card`.
391+
388392
## Comments
389393

390394
The parser supports both `#` (shell-style) and `//` (C-style) comments:

packages/python/tests/unit/generator/test_generator.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,152 @@ def test_relation_with_owns(self) -> None:
255255

256256
assert "since: attributes.Since" in source
257257

258+
def test_relation_optional_attribute(self) -> None:
259+
"""Render relation with optional attribute (no @card constraint)."""
260+
schema = parse_tql_schema("""
261+
define
262+
attribute sequence_index, value integer;
263+
entity milestone,
264+
plays task_grouping:milestone;
265+
entity task,
266+
plays task_grouping:task;
267+
268+
define
269+
relation task_grouping,
270+
relates milestone @card(1),
271+
relates task @card(0..),
272+
owns sequence_index;
273+
""")
274+
attr_names = build_class_name_map(schema.attributes)
275+
entity_names = build_class_name_map(schema.entities)
276+
relation_names = build_class_name_map(schema.relations)
277+
source = render_relations(schema, attr_names, entity_names, relation_names)
278+
279+
# Without @card constraint, attribute should be optional
280+
assert "sequence_index: attributes.Sequence_index | None = None" in source
281+
282+
def test_relation_required_attribute(self) -> None:
283+
"""Render relation with required attribute (@card(1))."""
284+
schema = parse_tql_schema("""
285+
define
286+
attribute weight, value double;
287+
entity node,
288+
plays edge:endpoint;
289+
290+
define
291+
relation edge,
292+
relates endpoint,
293+
owns weight @card(1);
294+
""")
295+
attr_names = build_class_name_map(schema.attributes)
296+
entity_names = build_class_name_map(schema.entities)
297+
relation_names = build_class_name_map(schema.relations)
298+
source = render_relations(schema, attr_names, entity_names, relation_names)
299+
300+
# With @card(1), attribute should be required (no | None = None)
301+
assert "weight: attributes.Weight" in source
302+
assert "weight: attributes.Weight | None" not in source
303+
304+
def test_relation_key_attribute(self) -> None:
305+
"""Render relation with @key attribute."""
306+
schema = parse_tql_schema("""
307+
define
308+
attribute edge_id, value string;
309+
entity node,
310+
plays connection:endpoint;
311+
312+
define
313+
relation connection,
314+
relates endpoint,
315+
owns edge_id @key;
316+
""")
317+
attr_names = build_class_name_map(schema.attributes)
318+
entity_names = build_class_name_map(schema.entities)
319+
relation_names = build_class_name_map(schema.relations)
320+
source = render_relations(schema, attr_names, entity_names, relation_names)
321+
322+
# With @key, attribute should use Flag(Key)
323+
assert "edge_id: attributes.Edge_id = Flag(Key)" in source
324+
assert "from type_bridge import" in source
325+
assert "Flag" in source
326+
assert "Key" in source
327+
328+
def test_relation_multi_value_attribute(self) -> None:
329+
"""Render relation with multi-value attribute (@card(0..))."""
330+
schema = parse_tql_schema("""
331+
define
332+
attribute tag, value string;
333+
entity item,
334+
plays tagging:item;
335+
336+
define
337+
relation tagging,
338+
relates item,
339+
owns tag @card(0..);
340+
""")
341+
attr_names = build_class_name_map(schema.attributes)
342+
entity_names = build_class_name_map(schema.entities)
343+
relation_names = build_class_name_map(schema.relations)
344+
source = render_relations(schema, attr_names, entity_names, relation_names)
345+
346+
# With @card(0..), attribute should be a list
347+
assert "list[attributes.Tag]" in source
348+
assert "Card" in source
349+
350+
def test_relation_inherits_key_from_parent(self) -> None:
351+
"""Child relation inherits @key constraint from parent."""
352+
schema = parse_tql_schema("""
353+
define
354+
attribute rel_id, value string;
355+
entity node,
356+
plays base_rel:endpoint,
357+
plays child_rel:endpoint;
358+
359+
define
360+
relation base_rel @abstract,
361+
relates endpoint,
362+
owns rel_id @key;
363+
364+
define
365+
relation child_rel sub base_rel;
366+
""")
367+
attr_names = build_class_name_map(schema.attributes)
368+
entity_names = build_class_name_map(schema.entities)
369+
relation_names = build_class_name_map(schema.relations)
370+
source = render_relations(schema, attr_names, entity_names, relation_names)
371+
372+
# Child should inherit @key from parent
373+
assert "class Child_rel(Base_rel):" in source
374+
assert "rel_id: attributes.Rel_id = Flag(Key)" in source
375+
376+
def test_relation_inherits_cardinality_from_parent(self) -> None:
377+
"""Child relation inherits cardinality constraint from parent."""
378+
schema = parse_tql_schema("""
379+
define
380+
attribute weight, value double;
381+
entity node,
382+
plays base_edge:endpoint,
383+
plays weighted_edge:endpoint;
384+
385+
define
386+
relation base_edge @abstract,
387+
relates endpoint,
388+
owns weight @card(1);
389+
390+
define
391+
relation weighted_edge sub base_edge;
392+
""")
393+
attr_names = build_class_name_map(schema.attributes)
394+
entity_names = build_class_name_map(schema.entities)
395+
relation_names = build_class_name_map(schema.relations)
396+
source = render_relations(schema, attr_names, entity_names, relation_names)
397+
398+
# Child should inherit required cardinality from parent
399+
# The parent declares weight as required, child should not have it optional
400+
assert "class Weighted_edge(Base_edge):" in source
401+
# weight should NOT be in child since it's inherited from parent
402+
# But if it were re-declared, it should still be required
403+
258404

259405
class TestComingSoonAnnotationStubs:
260406
"""Tests for coming-soon annotation stubs (TODO comments in generated code)."""

packages/python/tests/unit/session/test_session_unit.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,104 @@ def test_connection_accepts_transaction_context(self):
275275
# Type checking - Connection should accept TransactionContext
276276
conn: Connection = ctx
277277
assert isinstance(conn, TransactionContext)
278+
279+
280+
class TestDriverInjection:
281+
"""Tests for external driver injection feature (issue #85)."""
282+
283+
def test_driver_none_by_default(self):
284+
"""Database should have no driver by default."""
285+
db = Database()
286+
assert db._driver is None
287+
assert db._owns_driver is True # Will own any driver it creates
288+
289+
def test_injected_driver_stored(self):
290+
"""Database should store injected driver."""
291+
mock_driver = MagicMock()
292+
db = Database(driver=mock_driver)
293+
assert db._driver is mock_driver
294+
assert db._owns_driver is False # Does not own injected driver
295+
296+
def test_connect_skips_when_driver_injected(self):
297+
"""connect() should be a no-op when driver is injected."""
298+
mock_driver = MagicMock()
299+
db = Database(driver=mock_driver)
300+
301+
# connect() should not modify the driver
302+
db.connect()
303+
assert db._driver is mock_driver
304+
assert db._owns_driver is False
305+
306+
def test_close_clears_reference_but_does_not_close_injected_driver(self):
307+
"""close() should clear reference but not close injected driver."""
308+
mock_driver = MagicMock()
309+
db = Database(driver=mock_driver)
310+
311+
db.close()
312+
313+
# Reference should be cleared
314+
assert db._driver is None
315+
# But close() should NOT have been called on the driver
316+
mock_driver.close.assert_not_called()
317+
318+
def test_close_closes_owned_driver(self):
319+
"""close() should close driver when Database owns it."""
320+
mock_driver = MagicMock()
321+
db = Database()
322+
# Simulate connect() creating a driver
323+
db._driver = mock_driver
324+
db._owns_driver = True
325+
326+
db.close()
327+
328+
# Driver should be closed
329+
mock_driver.close.assert_called_once()
330+
assert db._driver is None
331+
332+
def test_driver_property_returns_injected_driver(self):
333+
"""driver property should return injected driver without connecting."""
334+
mock_driver = MagicMock()
335+
db = Database(driver=mock_driver)
336+
337+
# Accessing driver property should return the injected driver
338+
assert db.driver is mock_driver
339+
# connect() should not have been called (no new driver created)
340+
assert db._owns_driver is False
341+
342+
def test_context_manager_with_injected_driver(self):
343+
"""Context manager should work with injected driver."""
344+
mock_driver = MagicMock()
345+
346+
with Database(driver=mock_driver) as db:
347+
assert db._driver is mock_driver
348+
349+
# After exit, reference cleared but driver not closed
350+
assert db._driver is None
351+
mock_driver.close.assert_not_called()
352+
353+
def test_multiple_databases_share_driver(self):
354+
"""Multiple Database instances can share the same driver."""
355+
mock_driver = MagicMock()
356+
357+
db1 = Database(database="db1", driver=mock_driver)
358+
db2 = Database(database="db2", driver=mock_driver)
359+
360+
assert db1._driver is mock_driver
361+
assert db2._driver is mock_driver
362+
assert db1._owns_driver is False
363+
assert db2._owns_driver is False
364+
365+
# Close both - driver should NOT be closed
366+
db1.close()
367+
db2.close()
368+
mock_driver.close.assert_not_called()
369+
370+
def test_database_exists_with_injected_driver(self):
371+
"""database_exists() should work with injected driver."""
372+
mock_driver = MagicMock()
373+
mock_driver.databases.contains.return_value = True
374+
375+
db = Database(database="test_db", driver=mock_driver)
376+
377+
assert db.database_exists() is True
378+
mock_driver.databases.contains.assert_called_with("test_db")

0 commit comments

Comments
 (0)